mirror of
https://github.com/amd/blis.git
synced 2026-05-12 01:59:59 +00:00
3m_sqp conjugate support added
1. 3m_sqp support for A matrix with conjugate_no_transpose and conjugate_transpose added. AMD-Internal: [CPUPL-1521] Change-Id: Ie6e5c49cf86f7d3b95d78705cf445e57f20b3d1f
This commit is contained in:
@@ -666,8 +666,7 @@ void zgemm_
|
||||
sqp_on = true;
|
||||
}
|
||||
#endif
|
||||
if( ( ( blis_transa == BLIS_TRANSPOSE ) || ( blis_transa == BLIS_NO_TRANSPOSE ) )
|
||||
&& ( blis_transb == BLIS_NO_TRANSPOSE) && (sqp_on == true))
|
||||
if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) )
|
||||
{
|
||||
//sqp algo is found better for n > 40
|
||||
if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS)
|
||||
|
||||
@@ -40,16 +40,31 @@
|
||||
#define BLIS_LOADFIRST 0
|
||||
#define MEM_ALLOC 1//malloc performs better than bli_malloc.
|
||||
|
||||
#define SET_TRANS(X,Y)\
|
||||
Y = BLIS_NO_TRANSPOSE;\
|
||||
if(bli_obj_has_trans( a ))\
|
||||
{\
|
||||
Y = BLIS_TRANSPOSE;\
|
||||
if(bli_obj_has_conj(a))\
|
||||
{\
|
||||
Y = BLIS_CONJ_TRANSPOSE;\
|
||||
}\
|
||||
}\
|
||||
else if(bli_obj_has_conj(a))\
|
||||
{\
|
||||
Y = BLIS_CONJ_NO_TRANSPOSE;\
|
||||
}
|
||||
|
||||
//Macro for 3m_sqp n loop
|
||||
#define BLI_SQP_ZGEMM_N(MX)\
|
||||
int j=0;\
|
||||
for(; j<=(n-nx); j+= nx)\
|
||||
{\
|
||||
status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, isTransA, MX, p_istart, kx, &mem_3m_sqp);\
|
||||
status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\
|
||||
}\
|
||||
if(j<n)\
|
||||
{\
|
||||
status = bli_sqp_zgemm_m8( m, n-j, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, isTransA, MX, p_istart, kx, &mem_3m_sqp);\
|
||||
status = bli_sqp_zgemm_m8( m, n-j, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\
|
||||
}
|
||||
|
||||
//Macro for sqp_dgemm n loop
|
||||
@@ -135,7 +150,7 @@ BLIS_INLINE err_t bli_sqp_zgemm( gint_t m,
|
||||
guint_t ldc,
|
||||
double alpha,
|
||||
double beta,
|
||||
bool isTransA,
|
||||
trans_t transa,
|
||||
dim_t nt);
|
||||
|
||||
BLIS_INLINE err_t bli_sqp_dgemm( gint_t m,
|
||||
@@ -181,7 +196,7 @@ err_t bli_gemm_sqp
|
||||
return BLIS_INVALID_ROW_STRIDE;
|
||||
}
|
||||
|
||||
if(bli_obj_has_conj(a) || bli_obj_has_conj(b))
|
||||
if(bli_obj_has_conj(b))
|
||||
{
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
@@ -216,6 +231,9 @@ err_t bli_gemm_sqp
|
||||
isTransA = true;
|
||||
}
|
||||
|
||||
trans_t transa = BLIS_NO_TRANSPOSE;
|
||||
SET_TRANS(a,transa)
|
||||
|
||||
dim_t nt = bli_thread_get_num_threads(); // get number of threads
|
||||
|
||||
double* ap = ( double* )bli_obj_buffer( a );
|
||||
@@ -237,7 +255,7 @@ err_t bli_gemm_sqp
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
//printf("zsqp ");
|
||||
return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, nt);
|
||||
return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, transa, nt);
|
||||
}
|
||||
else if(dt == BLIS_DOUBLE)
|
||||
{
|
||||
@@ -254,6 +272,8 @@ err_t bli_gemm_sqp
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
//printf("dsqp ");
|
||||
// dgemm case only transpose or no-transpose is handled.
|
||||
// conjugate_transpose and conjugate no transpose are not applicable.
|
||||
return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt);
|
||||
}
|
||||
|
||||
@@ -627,7 +647,7 @@ BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m,
|
||||
guint_t ldb,
|
||||
double* c,
|
||||
guint_t ldc,
|
||||
bool isTransA,
|
||||
trans_t transa,
|
||||
double alpha,
|
||||
double beta,
|
||||
gint_t mx,
|
||||
@@ -653,7 +673,7 @@ BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m,
|
||||
double* pas = as;
|
||||
|
||||
/* a matrix real and imag packing and compute. */
|
||||
bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, isTransA, mx, p);
|
||||
bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, transa, mx, p);
|
||||
|
||||
double* pcr = cr;
|
||||
double* pci = ci;
|
||||
@@ -874,7 +894,7 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m,
|
||||
guint_t ldc,
|
||||
double alpha,
|
||||
double beta,
|
||||
bool isTransA,
|
||||
trans_t transa,
|
||||
gint_t mx,
|
||||
gint_t* p_istart,
|
||||
gint_t kx,
|
||||
@@ -926,14 +946,14 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m,
|
||||
for(; p <= (k-kx); p += kx)
|
||||
{
|
||||
bli_sqp_zgemm_kx(m, n, kx, p, a, lda, k, c, ldc,
|
||||
isTransA, alpha, beta, mx, i, ar, ai, as,
|
||||
transa, alpha, beta, mx, i, ar, ai, as,
|
||||
br + p, bi + p, bs + p, cr, ci, w, a_aligned);
|
||||
}// k loop end
|
||||
|
||||
if(p<k)
|
||||
{
|
||||
bli_sqp_zgemm_kx(m, n, (k - p), p, a, lda, k, c, ldc,
|
||||
isTransA, alpha, beta, mx, i, ar, ai, as,
|
||||
transa, alpha, beta, mx, i, ar, ai, as,
|
||||
br + p, bi + p, bs + p, cr, ci, w, a_aligned);
|
||||
}
|
||||
#else//kloop
|
||||
@@ -946,7 +966,7 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m,
|
||||
double* pas = as;
|
||||
|
||||
/* a matrix real and imag packing and compute. */
|
||||
bli_3m_sqp_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA, mx, 0);
|
||||
bli_3m_sqp_packA_real_imag_sum(a, i, k, lda, par, pai, pas, transa, mx, 0);
|
||||
|
||||
double* pcr = cr;
|
||||
double* pci = ci;
|
||||
@@ -1035,7 +1055,7 @@ BLIS_INLINE err_t bli_sqp_zgemm(gint_t m,
|
||||
guint_t ldc,
|
||||
double alpha_real,
|
||||
double beta_real,
|
||||
bool isTransA,
|
||||
trans_t transa,
|
||||
dim_t nt)
|
||||
{
|
||||
gint_t istart = 0;
|
||||
@@ -1078,7 +1098,7 @@ BLIS_INLINE err_t bli_sqp_zgemm(gint_t m,
|
||||
kx = k;
|
||||
}
|
||||
// for tn case there is a bug in handling k parts. To be fixed.
|
||||
if(isTransA==true)
|
||||
if(transa!=BLIS_NO_TRANSPOSE)
|
||||
{
|
||||
kx = k;
|
||||
}
|
||||
|
||||
@@ -1412,19 +1412,18 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
|
||||
double *par,
|
||||
double *pai,
|
||||
double *pas,
|
||||
bool isTransA,
|
||||
trans_t transa,
|
||||
gint_t mx,
|
||||
gint_t p)
|
||||
{
|
||||
__m256d av0, av1, av2, av3;
|
||||
__m256d tv0, tv1, sum;
|
||||
__m256d tv0, tv1, sum, zerov;
|
||||
gint_t poffset = p;
|
||||
#if KLP
|
||||
//k = p + k;
|
||||
#endif
|
||||
if(mx==8)
|
||||
{
|
||||
if(isTransA==false)
|
||||
if(transa == BLIS_NO_TRANSPOSE)
|
||||
{
|
||||
pa = pa +i;
|
||||
#if KLP
|
||||
@@ -1432,7 +1431,6 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
//printf("packA from p_%d to p_%d \n", p, k);
|
||||
for (; p < k; p += 1)
|
||||
{
|
||||
//for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each.
|
||||
@@ -1462,12 +1460,51 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
|
||||
pa = pa + lda;
|
||||
}
|
||||
}
|
||||
else
|
||||
else if(transa == BLIS_CONJ_NO_TRANSPOSE)
|
||||
{
|
||||
zerov = _mm256_setzero_pd();
|
||||
pa = pa +i;
|
||||
#if KLP
|
||||
pa = pa + (p*lda);
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
for (; p < k; p += 1)
|
||||
{
|
||||
//for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each.
|
||||
av0 = _mm256_loadu_pd(pa);
|
||||
av1 = _mm256_loadu_pd(pa+4);
|
||||
av2 = _mm256_loadu_pd(pa+8);
|
||||
av3 = _mm256_loadu_pd(pa+12);
|
||||
|
||||
tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);
|
||||
tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);
|
||||
av0 = _mm256_unpacklo_pd(tv0, tv1);
|
||||
av1 = _mm256_unpackhi_pd(tv0, tv1);
|
||||
av1 = _mm256_sub_pd(zerov,av1);//negate imaginary component
|
||||
sum = _mm256_add_pd(av0, av1);
|
||||
_mm256_storeu_pd(par, av0); par += 4;
|
||||
_mm256_storeu_pd(pai, av1); pai += 4;
|
||||
_mm256_storeu_pd(pas, sum); pas += 4;
|
||||
|
||||
tv0 = _mm256_permute2f128_pd(av2, av3, 0x20);
|
||||
tv1 = _mm256_permute2f128_pd(av2, av3, 0x31);
|
||||
av2 = _mm256_unpacklo_pd(tv0, tv1);
|
||||
av3 = _mm256_unpackhi_pd(tv0, tv1);
|
||||
av3 = _mm256_sub_pd(zerov,av3);//negate imaginary component
|
||||
sum = _mm256_add_pd(av2, av3);
|
||||
_mm256_storeu_pd(par, av2); par += 4;
|
||||
_mm256_storeu_pd(pai, av3); pai += 4;
|
||||
_mm256_storeu_pd(pas, sum); pas += 4;
|
||||
|
||||
pa = pa + lda;
|
||||
}
|
||||
}
|
||||
else if(transa == BLIS_TRANSPOSE)
|
||||
{
|
||||
gint_t idx = (i/2) * lda;
|
||||
pa = pa + idx;
|
||||
#if KLP
|
||||
//pa = pa + p;
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
@@ -1527,18 +1564,79 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
|
||||
}
|
||||
}
|
||||
}
|
||||
} //mx==8
|
||||
else//mx==1
|
||||
{
|
||||
if(isTransA==false)
|
||||
else if(transa == BLIS_CONJ_TRANSPOSE)
|
||||
{
|
||||
pa = pa + i;
|
||||
gint_t idx = (i/2) * lda;
|
||||
pa = pa + idx;
|
||||
#if KLP
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
//A conjugate Transpose case:
|
||||
for (gint_t ii = 0; ii < BLIS_MX8 ; ii++)
|
||||
{
|
||||
gint_t idx = ii * lda;
|
||||
gint_t sidx;
|
||||
gint_t pidx = 0;
|
||||
gint_t max_k = (k*2) - 8;
|
||||
for (p = poffset; p <= max_k; p += 8)
|
||||
{
|
||||
double ar0_ = *(pa + idx + p);
|
||||
double ai0_ = -(*(pa + idx + p + 1));
|
||||
|
||||
double ar1_ = *(pa + idx + p + 2);
|
||||
double ai1_ = -(*(pa + idx + p + 3));
|
||||
|
||||
double ar2_ = *(pa + idx + p + 4);
|
||||
double ai2_ = -(*(pa + idx + p + 5));
|
||||
|
||||
double ar3_ = *(pa + idx + p + 6);
|
||||
double ai3_ = -(*(pa + idx + p + 7));
|
||||
|
||||
sidx = (pidx/2) * BLIS_MX8;
|
||||
*(par + sidx + ii) = ar0_;
|
||||
*(pai + sidx + ii) = ai0_;
|
||||
*(pas + sidx + ii) = ar0_ + ai0_;
|
||||
|
||||
sidx = ((pidx+2)/2) * BLIS_MX8;
|
||||
*(par + sidx + ii) = ar1_;
|
||||
*(pai + sidx + ii) = ai1_;
|
||||
*(pas + sidx + ii) = ar1_ + ai1_;
|
||||
|
||||
sidx = ((pidx+4)/2) * BLIS_MX8;
|
||||
*(par + sidx + ii) = ar2_;
|
||||
*(pai + sidx + ii) = ai2_;
|
||||
*(pas + sidx + ii) = ar2_ + ai2_;
|
||||
|
||||
sidx = ((pidx+6)/2) * BLIS_MX8;
|
||||
*(par + sidx + ii) = ar3_;
|
||||
*(pai + sidx + ii) = ai3_;
|
||||
*(pas + sidx + ii) = ar3_ + ai3_;
|
||||
pidx += 8;
|
||||
}
|
||||
|
||||
for (; p < (k*2); p += 2)
|
||||
{
|
||||
double ar_ = *(pa + idx + p);
|
||||
double ai_ = -(*(pa + idx + p + 1));
|
||||
gint_t sidx = (pidx/2) * BLIS_MX8;
|
||||
*(par + sidx + ii) = ar_;
|
||||
*(pai + sidx + ii) = ai_;
|
||||
*(pas + sidx + ii) = ar_ + ai_;
|
||||
pidx += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
} //mx==8
|
||||
else//mx==1
|
||||
{
|
||||
if(transa == BLIS_NO_TRANSPOSE)
|
||||
{
|
||||
pa = pa + i;
|
||||
#if KLP
|
||||
//pa = pa + (p*lda); done below.. not needed
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
//printf(" packAx1 from p_%d to p_%d ",p,k-1);
|
||||
//A No transpose case:
|
||||
for (; p < k; p += 1)
|
||||
{
|
||||
@@ -1554,12 +1652,33 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
else if(transa == BLIS_CONJ_NO_TRANSPOSE)
|
||||
{
|
||||
pa = pa + i;
|
||||
#if KLP
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
//A conjuate No transpose case:
|
||||
for (; p < k; p += 1)
|
||||
{
|
||||
gint_t idx = p * lda;
|
||||
for (gint_t ii = 0; ii < (mx*2) ; ii += 2)
|
||||
{ //real + imag : Rkernel needs 8 elements each.
|
||||
double ar_ = *(pa + idx + ii);
|
||||
double ai_ = -(*(pa + idx + ii + 1));// conjugate: negate imaginary component
|
||||
*par = ar_;
|
||||
*pai = ai_;
|
||||
*pas = ar_ + ai_;
|
||||
par++; pai++; pas++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(transa == BLIS_TRANSPOSE)
|
||||
{
|
||||
gint_t idx = (i/2) * lda;
|
||||
pa = pa + idx;
|
||||
#if KLP
|
||||
//pa = pa + p; done below.. not needed
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
@@ -1583,6 +1702,34 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(transa == BLIS_CONJ_TRANSPOSE)
|
||||
{
|
||||
gint_t idx = (i/2) * lda;
|
||||
pa = pa + idx;
|
||||
#if KLP
|
||||
#else
|
||||
p = 0;
|
||||
#endif
|
||||
//A Transpose case:
|
||||
for (gint_t ii = 0; ii < mx ; ii++)
|
||||
{
|
||||
gint_t idx = ii * lda;
|
||||
gint_t sidx;
|
||||
gint_t pidx = 0;
|
||||
for (p = poffset;p < (k*2); p += 2)
|
||||
{
|
||||
double ar0_ = *(pa + idx + p);
|
||||
double ai0_ = -(*(pa + idx + p + 1));
|
||||
|
||||
sidx = (pidx/2) * mx;
|
||||
*(par + sidx + ii) = ar0_;
|
||||
*(pai + sidx + ii) = ai0_;
|
||||
*(pas + sidx + ii) = ar0_ + ai0_;
|
||||
pidx += 2;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}//mx==1
|
||||
}
|
||||
|
||||
|
||||
@@ -62,4 +62,4 @@ void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT
|
||||
/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */
|
||||
void bli_3m_sqp_packC_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx);
|
||||
void bli_3m_sqp_packB_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx);
|
||||
void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx, gint_t p);
|
||||
void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, trans_t transa, gint_t mx, gint_t p);
|
||||
Reference in New Issue
Block a user