Mirror of https://github.com/amd/blis.git
Fixed x86_64 kernel bugs and other minor issues.
Details:
- Fixed bugs in trmv_l and trsv_u caused by backwards iteration producing unaligned subpartitions. We already go out of our way a bit to handle edge cases in the first iteration of the blocked variants; this was simply the unblocked-fused extension of that idea.
- Fixed control tree handling in her/her2/syr/syr2, which did not account for how the choice of variant must be altered for upper-stored matrices (given that only lower-stored algorithms are explicitly implemented).
- Added bli_determine_blocksize_dim_f() and bli_determine_blocksize_dim_b() macros to provide inlined versions of bli_determine_blocksize_[fb]() for use by unblocked-fused variants.
- Integrated the new blocksize_dim macros into the gemv/hemv unf (unblocked-fused) variants for consistency with the trmv/trsv bugfix (both now use the same macros).
- Modified bli_obj_vector_inc() so that 1 is returned if the object is a vector of length 1 (i.e., 1 x 1). This fixes a bug whereby, under certain conditions (e.g. dotv_opt_var1), an invalid increment was returned; it was invalid only because the code expected 1 (for the purpose of performing contiguous vector loads) but got a value greater than 1, since the column stride of the object (e.g. rho) was inflated for alignment purposes (unnecessarily, given that the object holds only one element).
- Replaced some old invocations of set0 with set0s.
- Added an alpha parameter to the gemmtrsm ukernels for x86_64 and updated their use accordingly.
- Fixed an increment bug in the cleanup loop of the gemm ukernel for x86_64.
- Added a safeguard to the test modules so that testing a problem with a zero dimension does not result in a failure.
- Tweaked the handling of zero dimensions in the level-2 and level-3 operations' internal back-ends to correctly handle cases where the output operand still needs to be scaled (e.g. by beta, in the case of gemm with k = 0).
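For reference, the operation the lower gemmtrsm ukernels compute once alpha is threaded through is b11 := inv(a11) * ( alpha * b11 - a10 * b01 ), which is what the assembly changes below implement. The scalar C sketch that follows is illustrative only, not the BLIS code: the name gemmtrsm_l_ref, the fixed 4x4 sizes, the assumed micro-panel layouts, the omission of the duplicated bd01 buffer, and the explicit division by the diagonal (the optimized kernels consume a pre-inverted diagonal) are all simplifying assumptions.

#define MR 4
#define NR 4

// Illustrative scalar version of the double-precision lower gemmtrsm ukernel:
//   b11 := inv(a11) * ( alpha * b11 - a10 * b01 )
// Assumptions (not necessarily how BLIS packs data): a10 is an MR x k panel
// stored column-major, b01 is a k x NR panel stored row-major, a11 is an
// MR x MR lower-triangular block stored column-major, and b11 is an MR x NR
// block stored row-major with row stride NR.
void gemmtrsm_l_ref( int k, double alpha,
                     const double* a10, const double* a11,
                     const double* b01, double*       b11 )
{
    // b11 := alpha * b11 - a10 * b01
    for ( int i = 0; i < MR; ++i )
        for ( int j = 0; j < NR; ++j )
        {
            double ab = 0.0;
            for ( int p = 0; p < k; ++p )
                ab += a10[ p*MR + i ] * b01[ p*NR + j ];
            b11[ i*NR + j ] = alpha * b11[ i*NR + j ] - ab;
        }

    // b11 := inv(a11) * b11 via forward substitution (a11 lower triangular)
    for ( int j = 0; j < NR; ++j )
        for ( int i = 0; i < MR; ++i )
        {
            double x = b11[ i*NR + j ];
            for ( int l = 0; l < i; ++l )
                x -= a11[ l*MR + i ] * b11[ l*NR + j ];
            b11[ i*NR + j ] = x / a11[ i*MR + i ];
        }
}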
@@ -272,13 +272,12 @@ void bli_dgemmtrsm_l_opt_d4x2(
"movl %10, %%eax \n\t" // load address of alpha
"movddup (%%eax), %%xmm7 \n\t" // load alpha and duplicate
" \n\t"
"movapd 0 * 16(%%ebx), %%xmm4 \n\t" // load xmm4 = ( beta00 beta01 )
"movapd 1 * 16(%%ebx), %%xmm5 \n\t" // load xmm5 = ( beta10 beta11 )
"movapd 2 * 16(%%ebx), %%xmm6 \n\t" // load xmm6 = ( beta20 beta21 )
"mulpd %%xmm7, %%xmm4 \n\t" // xmm4 *= alpha
"mulpd %%xmm7, %%xmm5 \n\t" // xmm5 *= alpha
"mulpd %%xmm7, %%xmm6 \n\t" // xmm6 *= alpha
//"movapd 3 * 16(%%ebx), %%xmm7 \n\t" // load xmm7 = ( beta30 beta31 )
"movapd 0 * 16(%%ebx), %%xmm4 \n\t"
"movapd 1 * 16(%%ebx), %%xmm5 \n\t"
"mulpd %%xmm7, %%xmm4 \n\t" // xmm4 = alpha * ( beta00 beta01 )
"mulpd %%xmm7, %%xmm5 \n\t" // xmm5 = alpha * ( beta10 beta11 )
"movapd 2 * 16(%%ebx), %%xmm6 \n\t"
"mulpd %%xmm7, %%xmm6 \n\t" // xmm6 = alpha * ( beta20 beta21 )
"mulpd 3 * 16(%%ebx), %%xmm7 \n\t" // xmm7 = alpha * ( beta30 beta31 )
" \n\t"
"subpd %%xmm0, %%xmm4 \n\t" // xmm4 -= xmm0

@@ -117,11 +117,11 @@ void PASTEMAC3(chx,chy,chr,varname)( \
\
if ( bli_zero_dim1( n ) ) \
{ \
PASTEMAC(chr,set0)( *rho_cast ); \
PASTEMAC(chr,set0s)( *rho_cast ); \
return; \
} \
\
PASTEMAC(chr,set0)( dotxy ); \
PASTEMAC(chr,set0s)( dotxy ); \
\
chi1 = x_cast; \
psi1 = y_cast; \
@@ -216,7 +216,7 @@ void bli_ddddotv_opt_var1(

if ( bli_zero_dim1( n ) )
{
PASTEMAC(d,set0)( *rho_cast );
PASTEMAC(d,set0s)( *rho_cast );
return;
}

@@ -238,7 +238,7 @@ void bli_ddddotv_opt_var1(
x1 = x_cast;
y1 = y_cast;

PASTEMAC(d,set0)( rho1 );
PASTEMAC(d,set0s)( rho1 );

if ( n_pre == 1 )
{

@@ -133,7 +133,7 @@ void bli_ddddotaxpyv_opt_var1(

if ( bli_zero_dim1( n ) )
{
PASTEMAC(d,set0)( *rho_cast );
PASTEMAC(d,set0s)( *rho_cast );
return;
}

@@ -153,7 +153,7 @@ void bli_ddddotaxpyv_opt_var1(
stepy = 2 * incy;
stepz = 2 * incz;

PASTEMAC(d,set0)( rho1c );
PASTEMAC(d,set0s)( rho1c );

alpha1c = *alpha_cast;

@@ -233,11 +233,11 @@ void bli_ddddotxaxpyf_opt_var1(
PASTEMAC2(d,d,scals)( *alpha_cast, chi2 );
PASTEMAC2(d,d,scals)( *alpha_cast, chi3 );

PASTEMAC(d,set0)( rho0 );
PASTEMAC(d,set0)( rho1 );
PASTEMAC(d,set0)( rho2 );
PASTEMAC(d,set0)( rho3 );
PASTEMAC(d,set0)( zeta1 );
PASTEMAC(d,set0s)( rho0 );
PASTEMAC(d,set0s)( rho1 );
PASTEMAC(d,set0s)( rho2 );
PASTEMAC(d,set0s)( rho3 );
PASTEMAC(d,set0s)( zeta1 );

if ( m_pre == 1 )
{

@@ -267,10 +267,10 @@ void bli_ddddotxf_opt_var1(
x3 = x_cast + 3*ldx;
y0 = y_cast;

PASTEMAC(d,set0)( rho0 );
PASTEMAC(d,set0)( rho1 );
PASTEMAC(d,set0)( rho2 );
PASTEMAC(d,set0)( rho3 );
PASTEMAC(d,set0s)( rho0 );
PASTEMAC(d,set0s)( rho1 );
PASTEMAC(d,set0s)( rho2 );
PASTEMAC(d,set0s)( rho3 );

if ( n_pre == 1 )
{

@@ -281,8 +281,8 @@ void bli_dgemm_opt_d4x4(
"movaps -5 * 16(%%rax), %%xmm1 \n\t"
" \n\t"
" \n\t"
"addq $4 * 4 * 8, %%rax \n\t" // a += 4 (1 x mr)
"addq $4 * 4 * 8, %%rbx \n\t" // b += 4 (1 x nr)
"addq $4 * 1 * 8, %%rax \n\t" // a += 4 (1 x mr)
"addq $4 * 1 * 8, %%rbx \n\t" // b += 4 (1 x nr)
" \n\t"
" \n\t"
"decq %%rsi \n\t" // i -= 1;

@@ -36,6 +36,7 @@

void bli_sgemmtrsm_l_opt_d4x4(
dim_t k,
float* restrict alpha,
float* restrict a10,
float* restrict a11,
float* restrict bd01,
@@ -51,6 +52,7 @@ void bli_sgemmtrsm_l_opt_d4x4(

void bli_dgemmtrsm_l_opt_d4x4(
dim_t k,
double* restrict alpha,
double* restrict a10,
double* restrict a11,
double* restrict bd01,
@@ -334,14 +336,26 @@ void bli_dgemmtrsm_l_opt_d4x4(
" \n\t" // xmm2: ( ab20 ab21 ) xmm6: ( ab22 ab23 )
" \n\t" // xmm3: ( ab30 ab31 ) xmm7: ( ab32 ab33 )
" \n\t"
"movapd 0 * 16(%%rbx), %%xmm8 \n\t"
"movq %10, %%rax \n\t" // load address of alpha
"movddup (%%rax), %%xmm15 \n\t" // load alpha and duplicate
" \n\t"
"movapd 0 * 16(%%rbx), %%xmm8 \n\t"
"movapd 1 * 16(%%rbx), %%xmm12 \n\t"
"mulpd %%xmm15, %%xmm8 \n\t" // xmm8 = alpha * ( beta00 beta01 )
"mulpd %%xmm15, %%xmm12 \n\t" // xmm12 = alpha * ( beta02 beta03 )
"movapd 2 * 16(%%rbx), %%xmm9 \n\t"
"movapd 3 * 16(%%rbx), %%xmm13 \n\t"
"mulpd %%xmm15, %%xmm9 \n\t" // xmm9 = alpha * ( beta10 beta11 )
"mulpd %%xmm15, %%xmm13 \n\t" // xmm13 = alpha * ( beta12 beta13 )
"movapd 4 * 16(%%rbx), %%xmm10 \n\t"
"movapd 5 * 16(%%rbx), %%xmm14 \n\t"
"mulpd %%xmm15, %%xmm10 \n\t" // xmm10 = alpha * ( beta20 beta21 )
"mulpd %%xmm15, %%xmm14 \n\t" // xmm14 = alpha * ( beta22 beta23 )
"movapd 6 * 16(%%rbx), %%xmm11 \n\t"
"movapd 7 * 16(%%rbx), %%xmm15 \n\t"
"mulpd %%xmm15, %%xmm11 \n\t" // xmm11 = alpha * ( beta30 beta31 )
"mulpd 7 * 16(%%rbx), %%xmm15 \n\t" // xmm15 = alpha * ( beta32 beta33 )
" \n\t"
" \n\t" // (Now scaled by alpha:)
" \n\t" // xmm8: ( beta00 beta01 ) xmm12: ( beta02 beta03 )
" \n\t" // xmm9: ( beta10 beta11 ) xmm13: ( beta12 beta13 )
" \n\t" // xmm10: ( beta20 beta21 ) xmm14: ( beta22 beta23 )
@@ -491,7 +505,8 @@ void bli_dgemmtrsm_l_opt_d4x4(
"m" (b11),
"m" (c11),
"m" (rs_c),
"m" (cs_c)
"m" (cs_c),
"m" (alpha)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"xmm0", "xmm1", "xmm2", "xmm3",
@@ -505,6 +520,7 @@ void bli_dgemmtrsm_l_opt_d4x4(

void bli_cgemmtrsm_l_opt_d4x4(
dim_t k,
scomplex* restrict alpha,
scomplex* restrict a10,
scomplex* restrict a11,
scomplex* restrict bd01,
@@ -520,6 +536,7 @@ void bli_cgemmtrsm_l_opt_d4x4(

void bli_zgemmtrsm_l_opt_d4x4(
dim_t k,
dcomplex* restrict alpha,
dcomplex* restrict a10,
dcomplex* restrict a11,
dcomplex* restrict bd01,

@@ -38,6 +38,7 @@
\
void PASTEMAC(ch,varname)( \
dim_t k, \
ctype* restrict alpha, \
ctype* restrict a10, \
ctype* restrict a11, \
ctype* restrict bd01, \

@@ -36,6 +36,7 @@

void bli_sgemmtrsm_u_opt_d4x4(
dim_t k,
float* restrict alpha,
float* restrict a12,
float* restrict a11,
float* restrict bd21,
@@ -51,6 +52,7 @@ void bli_sgemmtrsm_u_opt_d4x4(

void bli_dgemmtrsm_u_opt_d4x4(
dim_t k,
double* restrict alpha,
double* restrict a12,
double* restrict a11,
double* restrict bd21,
@@ -282,8 +284,8 @@ void bli_dgemmtrsm_u_opt_d4x4(
"movaps -5 * 16(%%rax), %%xmm1 \n\t"
" \n\t"
" \n\t"
"addq $4 * 4 * 8, %%rax \n\t" // a += 4 (1 x mr)
"addq $4 * 4 * 8, %%rbx \n\t" // b += 4 (1 x nr)
"addq $4 * 1 * 8, %%rax \n\t" // a += 4 (1 x mr)
"addq $4 * 1 * 8, %%rbx \n\t" // b += 4 (1 x nr)
" \n\t"
" \n\t"
"decq %%rsi \n\t" // i -= 1;
@@ -334,14 +336,26 @@ void bli_dgemmtrsm_u_opt_d4x4(
" \n\t" // xmm2: ( ab20 ab21 ) xmm6: ( ab22 ab23 )
" \n\t" // xmm3: ( ab30 ab31 ) xmm7: ( ab32 ab33 )
" \n\t"
"movq %10, %%rax \n\t" // load address of alpha
"movddup (%%rax), %%xmm15 \n\t" // load alpha and duplicate
" \n\t"
"movapd 0 * 16(%%rbx), %%xmm8 \n\t"
"movapd 1 * 16(%%rbx), %%xmm12 \n\t"
"mulpd %%xmm15, %%xmm8 \n\t" // xmm8 = alpha * ( beta00 beta01 )
"mulpd %%xmm15, %%xmm12 \n\t" // xmm12 = alpha * ( beta02 beta03 )
"movapd 2 * 16(%%rbx), %%xmm9 \n\t"
"movapd 3 * 16(%%rbx), %%xmm13 \n\t"
"mulpd %%xmm15, %%xmm9 \n\t" // xmm9 = alpha * ( beta10 beta11 )
"mulpd %%xmm15, %%xmm13 \n\t" // xmm13 = alpha * ( beta12 beta13 )
"movapd 4 * 16(%%rbx), %%xmm10 \n\t"
"movapd 5 * 16(%%rbx), %%xmm14 \n\t"
"mulpd %%xmm15, %%xmm10 \n\t" // xmm10 = alpha * ( beta20 beta21 )
"mulpd %%xmm15, %%xmm14 \n\t" // xmm14 = alpha * ( beta22 beta23 )
"movapd 6 * 16(%%rbx), %%xmm11 \n\t"
"movapd 7 * 16(%%rbx), %%xmm15 \n\t"
"mulpd %%xmm15, %%xmm11 \n\t" // xmm11 = alpha * ( beta30 beta31 )
"mulpd 7 * 16(%%rbx), %%xmm15 \n\t" // xmm15 = alpha * ( beta32 beta33 )
" \n\t"
" \n\t" // (Now scaled by alpha:)
" \n\t" // xmm8: ( beta00 beta01 ) xmm12: ( beta02 beta03 )
" \n\t" // xmm9: ( beta10 beta11 ) xmm13: ( beta12 beta13 )
" \n\t" // xmm10: ( beta20 beta21 ) xmm14: ( beta22 beta23 )
@@ -494,7 +508,8 @@ void bli_dgemmtrsm_u_opt_d4x4(
"m" (b11),
"m" (c11),
"m" (rs_c),
"m" (cs_c)
"m" (cs_c),
"m" (alpha)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"xmm0", "xmm1", "xmm2", "xmm3",
@@ -508,6 +523,7 @@ void bli_dgemmtrsm_u_opt_d4x4(

void bli_cgemmtrsm_u_opt_d4x4(
dim_t k,
scomplex* restrict alpha,
scomplex* restrict a12,
scomplex* restrict a11,
scomplex* restrict bd21,
@@ -523,6 +539,7 @@ void bli_cgemmtrsm_u_opt_d4x4(

void bli_zgemmtrsm_u_opt_d4x4(
dim_t k,
dcomplex* restrict alpha,
dcomplex* restrict a12,
dcomplex* restrict a11,
dcomplex* restrict bd21,

@@ -38,6 +38,7 @@
\
void PASTEMAC(ch,varname)( \
dim_t k, \
ctype* restrict alpha, \
ctype* restrict a12, \
ctype* restrict a11, \
ctype* restrict bd21, \