Vectorized and parallelized zdscal routine

- Implemented an optimized intrinsics kernel for zdscalv for the cases where AVX2 is supported.
- Also added multithreaded support for the same kernel.
- The optimal number of threads is chosen based on the input size (see the sketch after this list).
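
As a rough guide, the size-based heuristic introduced in this change can be read as a small pure function. The sketch below is illustrative only: the helper name zdscal_ideal_nt is hypothetical, and the thresholds simply mirror the AOCL_DYNAMIC branch in the diff further down.

    // Number of threads zdscal will aim for, given the vector length n0 and
    // the thread count nt configured in the BLIS runtime.
    static dim_t zdscal_ideal_nt( dim_t n0, dim_t nt )
    {
        dim_t nt_ideal;
        if      ( n0 <= 10000   ) nt_ideal = 1;  // small sizes skip OpenMP entirely
        else if ( n0 <= 20000   ) nt_ideal = 4;
        else if ( n0 <= 1000000 ) nt_ideal = 8;
        else if ( n0 <= 2500000 ) nt_ideal = 12;
        else if ( n0 <= 5000000 ) nt_ideal = 32;
        else                      nt_ideal = 64;
        return bli_min( nt_ideal, nt );          // never exceed the configured thread count
    }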

AMD-Internal: [CPUPL-2602]
Change-Id: I4d05c3b1cc365a7770703286a89c6dce3875c067
Arnav Sharma
2022-09-29 23:32:10 +05:30
committed by Arnav Sharma
parent 9c292b79e2
commit 90f915d3a9
4 changed files with 403 additions and 2 deletions


@@ -355,6 +355,176 @@ void dscal_
dscal_blis_impl( n, alpha, x, incx );
}
void zdscal_blis_impl
(
const f77_int* n,
const double* alpha,
dcomplex* x, const f77_int* incx
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', (void *) alpha, *n, *incx );
dim_t n0;
dcomplex* x0;
inc_t incx0;
/* Initialize BLIS. */
//bli_init_auto();
/* Convert/typecast negative values of n to zero. */
if ( *n < 0 ) n0 = ( dim_t )0;
else n0 = ( dim_t )(*n);
if (*n == 0 || alpha == NULL) {
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
/* If the input increments are negative, adjust the pointers so we can
use positive increments instead. */
if ( *incx < 0 )
{
/* The semantics of negative stride in BLAS are that the vector
operand be traversed in reverse order. (Another way to think
of this is that negative strides effectively reverse the order
of the vector, but without any explicit data movements.) This
is also how BLIS interprets negative strides. The differences
is that with BLAS, the caller *always* passes in the 0th (i.e.,
top-most or left-most) element of the vector, even when the
stride is negative. By contrast, in BLIS, negative strides are
used *relative* to the vector address as it is given. Thus, in
BLIS, if this backwards traversal is desired, the caller *must*
pass in the address to the (n-1)th (i.e., the bottom-most or
right-most) element along with a negative stride. */
x0 = (x) + (n0-1)*(-*incx);
incx0 = ( inc_t )(*incx);
}
else
{
x0 = (x);
incx0 = ( inc_t )(*incx);
}
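// Worked illustration with hypothetical values: for n0 = 4 and *incx = -2, a
// BLAS caller passes the address of x[0]; the adjustment above gives
// x0 = x + (4-1)*2 = x + 6 and incx0 = -2, so the routine touches
// x[6], x[4], x[2] and x[0], exactly the elements the BLAS semantics require.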
// This function is invoked on all architectures including generic.
// Non-AVX platforms will use the kernels derived from the context.
if ( bli_cpuid_is_avx_supported() == TRUE )
{
#ifdef BLIS_ENABLE_OPENMP
// For sizes up to 10000 the optimal number of threads is one, so to avoid
// the overhead of the OpenMP runtime in that range, zdscalv is called
// directly (outside any parallel region) to get maximum performance.
if ( n0 <= 10000 )
{
bli_zdscalv_zen_int10
(
BLIS_NO_CONJUGATE,
n0,
(double*) alpha,
x0, incx0,
NULL
);
}
else
{
rntm_t rntm_local;
bli_rntm_init_from_global( &rntm_local );
dim_t nt = bli_rntm_num_threads( &rntm_local );
#ifdef AOCL_DYNAMIC
dim_t nt_ideal;
if ( n0 <= 20000 ) nt_ideal = 4;
else if ( n0 <= 1000000 ) nt_ideal = 8;
else if ( n0 <= 2500000 ) nt_ideal = 12;
else if ( n0 <= 5000000 ) nt_ideal = 32;
else nt_ideal = 64;
nt = bli_min( nt_ideal, nt );
#endif
dim_t n_elem_per_thread = n0 / nt;
dim_t n_elem_rem = n0 % nt;
#pragma omp parallel num_threads( nt )
{
// The following conditions distribute the load evenly among the threads.
// Say we have n0 = 50 and nt = 4: each thread gets 12 ( n0 / nt )
// elements, with 2 elements remaining. Each remaining element is
// assigned to one of the last threads, so t0, t1, t2 and t3 get
// 12, 12, 13 and 13 elements, respectively.
dim_t t_id = omp_get_thread_num();
dim_t npt, offset;
if ( t_id < ( nt - n_elem_rem ) )
{
npt = n_elem_per_thread;
offset = t_id * npt * incx0;
}
else
{
npt = n_elem_per_thread + 1;
offset = ( ( t_id * n_elem_per_thread ) +
( t_id - ( nt - n_elem_rem ) ) ) * incx0;
}
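// Continuing the n0 = 50, nt = 4 example: t2 takes the else branch above with
// npt = 13 and offset = ( 2*12 + 0 ) * incx0 = 24 (elements 24..36), and t3
// gets npt = 13 and offset = ( 3*12 + 1 ) * incx0 = 37 (elements 37..49), so
// all 50 elements are covered exactly once.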
bli_zdscalv_zen_int10
(
BLIS_NO_CONJUGATE,
npt,
(double *) alpha,
x0 + offset, incx0,
NULL
);
}
}
#else
// When OpenMP is not enabled, fall back to a single-threaded call to zdscalv.
bli_zdscalv_zen_int10
(
BLIS_NO_CONJUGATE,
n0,
(double *) alpha,
x0, incx0,
NULL
);
#endif
}
else
{
// Sub-optimal fallback for zdscal: cast alpha to the double-complex
// domain and call the typed zscalv interface.
dcomplex alpha_cast;
PASTEMAC2(d,z,copys)( *alpha, alpha_cast );
/* Call BLIS interface. */
PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF)
(
BLIS_NO_CONJUGATE,
n0,
&alpha_cast,
x0, incx0,
NULL,
NULL
);
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
}
void zdscal_
(
const f77_int* n,
const double* alpha,
dcomplex* x, const f77_int* incx
)
{
zdscal_blis_impl( n, alpha, x, incx );
}
INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv )
#endif
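
For reference, a minimal caller of the BLAS-level wrapper above might look like the sketch below. It assumes only that blis.h is on the include path and provides f77_int, dcomplex and the zdscal_ prototype, as it does for the other BLAS wrappers in this file.

    #include "blis.h"

    int main( void )
    {
        f77_int  n = 4, incx = 1;
        double   alpha = 2.0;
        dcomplex x[4] = { {1.0, 1.0}, {2.0, -1.0}, {0.5, 0.0}, {-3.0, 4.0} };

        // Scales both the real and imaginary part of every element by alpha.
        zdscal_( &n, &alpha, x, &incx );

        return 0;
    }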


@@ -159,8 +159,8 @@ GENTFUNCR2( dcomplex, double, z, d, blasname, blisname )
\
GENTFUNCSCAL( scomplex, scomplex, c, , blasname, blisname ) \
GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \
GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \
GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname )
GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname )
// GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname )
#define INSERT_GENTFUNCSCAL_BLAS( blasname, blisname ) \


@@ -578,3 +578,225 @@ void bli_dscalv_zen_int10
}
}
void bli_zdscalv_zen_int10
(
conj_t conjalpha,
dim_t n,
double* restrict alpha,
dcomplex* restrict x, inc_t incx,
cntx_t* restrict cntx
)
{
dim_t i = 0;
const dim_t n_elem_per_reg = 4; // number of doubles per ymm register (= 2 dcomplex elements)
double* restrict x0 = (double*) x;
const double alphac = *alpha;
if ( incx == 1 )
{
__m256d alphav;
__m256d xv[15];
alphav = _mm256_broadcast_sd( alpha );
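// Each iteration of the loop below processes 15 ymm registers = 60 doubles,
// i.e. 30 dcomplex elements; since alpha is real, broadcasting it across all
// four lanes lets a single multiply scale both the real and imaginary parts.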
for ( ; ( i + 29 ) < n; i += 30 )
{
xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );
xv[8] = _mm256_loadu_pd( x0 + 8 * n_elem_per_reg );
xv[9] = _mm256_loadu_pd( x0 + 9 * n_elem_per_reg );
xv[10] = _mm256_loadu_pd( x0 + 10 * n_elem_per_reg );
xv[11] = _mm256_loadu_pd( x0 + 11 * n_elem_per_reg );
xv[12] = _mm256_loadu_pd( x0 + 12 * n_elem_per_reg );
xv[13] = _mm256_loadu_pd( x0 + 13 * n_elem_per_reg );
xv[14] = _mm256_loadu_pd( x0 + 14 * n_elem_per_reg );
xv[0] = _mm256_mul_pd( alphav, xv[0] );
xv[1] = _mm256_mul_pd( alphav, xv[1] );
xv[2] = _mm256_mul_pd( alphav, xv[2] );
xv[3] = _mm256_mul_pd( alphav, xv[3] );
xv[4] = _mm256_mul_pd( alphav, xv[4] );
xv[5] = _mm256_mul_pd( alphav, xv[5] );
xv[6] = _mm256_mul_pd( alphav, xv[6] );
xv[7] = _mm256_mul_pd( alphav, xv[7] );
xv[8] = _mm256_mul_pd( alphav, xv[8] );
xv[9] = _mm256_mul_pd( alphav, xv[9] );
xv[10] = _mm256_mul_pd( alphav, xv[10] );
xv[11] = _mm256_mul_pd( alphav, xv[11] );
xv[12] = _mm256_mul_pd( alphav, xv[12] );
xv[13] = _mm256_mul_pd( alphav, xv[13] );
xv[14] = _mm256_mul_pd( alphav, xv[14] );
_mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
_mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
_mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
_mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
_mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
_mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
_mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
_mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );
_mm256_storeu_pd( (x0 + 8*n_elem_per_reg), xv[8] );
_mm256_storeu_pd( (x0 + 9*n_elem_per_reg), xv[9] );
_mm256_storeu_pd( (x0 + 10*n_elem_per_reg), xv[10] );
_mm256_storeu_pd( (x0 + 11*n_elem_per_reg), xv[11] );
_mm256_storeu_pd( (x0 + 12*n_elem_per_reg), xv[12] );
_mm256_storeu_pd( (x0 + 13*n_elem_per_reg), xv[13] );
_mm256_storeu_pd( (x0 + 14*n_elem_per_reg), xv[14] );
x0 += 15 * n_elem_per_reg;
}
for ( ; ( i + 23 ) < n; i += 24 )
{
xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );
xv[8] = _mm256_loadu_pd( x0 + 8 * n_elem_per_reg );
xv[9] = _mm256_loadu_pd( x0 + 9 * n_elem_per_reg );
xv[10] = _mm256_loadu_pd( x0 + 10 * n_elem_per_reg );
xv[11] = _mm256_loadu_pd( x0 + 11 * n_elem_per_reg );
xv[0] = _mm256_mul_pd( alphav, xv[0] );
xv[1] = _mm256_mul_pd( alphav, xv[1] );
xv[2] = _mm256_mul_pd( alphav, xv[2] );
xv[3] = _mm256_mul_pd( alphav, xv[3] );
xv[4] = _mm256_mul_pd( alphav, xv[4] );
xv[5] = _mm256_mul_pd( alphav, xv[5] );
xv[6] = _mm256_mul_pd( alphav, xv[6] );
xv[7] = _mm256_mul_pd( alphav, xv[7] );
xv[8] = _mm256_mul_pd( alphav, xv[8] );
xv[9] = _mm256_mul_pd( alphav, xv[9] );
xv[10] = _mm256_mul_pd( alphav, xv[10] );
xv[11] = _mm256_mul_pd( alphav, xv[11] );
_mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
_mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
_mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
_mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
_mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
_mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
_mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
_mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );
_mm256_storeu_pd( (x0 + 8*n_elem_per_reg), xv[8] );
_mm256_storeu_pd( (x0 + 9*n_elem_per_reg), xv[9] );
_mm256_storeu_pd( (x0 + 10*n_elem_per_reg), xv[10] );
_mm256_storeu_pd( (x0 + 11*n_elem_per_reg), xv[11] );
x0 += 12 * n_elem_per_reg;
}
for ( ; ( i + 15 ) < n; i += 16 )
{
xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );
xv[0] = _mm256_mul_pd( alphav, xv[0] );
xv[1] = _mm256_mul_pd( alphav, xv[1] );
xv[2] = _mm256_mul_pd( alphav, xv[2] );
xv[3] = _mm256_mul_pd( alphav, xv[3] );
xv[4] = _mm256_mul_pd( alphav, xv[4] );
xv[5] = _mm256_mul_pd( alphav, xv[5] );
xv[6] = _mm256_mul_pd( alphav, xv[6] );
xv[7] = _mm256_mul_pd( alphav, xv[7] );
_mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
_mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
_mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
_mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
_mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
_mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
_mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
_mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );
x0 += 8 * n_elem_per_reg;
}
for ( ; ( i + 7 ) < n; i += 8 )
{
xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
xv[0] = _mm256_mul_pd( alphav, xv[0] );
xv[1] = _mm256_mul_pd( alphav, xv[1] );
xv[2] = _mm256_mul_pd( alphav, xv[2] );
xv[3] = _mm256_mul_pd( alphav, xv[3] );
_mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
_mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
_mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
_mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
x0 += 4 * n_elem_per_reg;
}
for ( ; ( i + 3 ) < n; i += 4 )
{
xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
xv[0] = _mm256_mul_pd( alphav, xv[0] );
xv[1] = _mm256_mul_pd( alphav, xv[1] );
_mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
_mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
x0 += 2 * n_elem_per_reg;
}
for ( ; ( i + 1 ) < n; i += 2 )
{
xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
xv[0] = _mm256_mul_pd( alphav, xv[0] );
_mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
x0 += 1 * n_elem_per_reg;
}
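// Scale any remaining elements (fewer than one full register) with scalar
// code, updating the real and imaginary parts separately.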
for ( ; i < n; i++ )
{
( *x0 ) *= alphac;
( *( x0 + 1 ) ) *= alphac;
x0 += 2 * incx;
}
// Issue the vzeroupper instruction to clear the upper lanes of the ymm
// registers. This avoids a performance penalty caused by false dependencies
// when transitioning from AVX to SSE instructions (which may occur in code
// executed after this kernel returns, e.g. if BLIS is compiled with
// -mfpmath=sse).
_mm256_zeroupper();
}
else
{
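// For non-unit increments, fall back to a scalar loop that scales the real
// and imaginary parts of each element.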
for ( ; i < n; ++i )
{
( *x0 ) *= alphac;
( *( x0 + 1 ) ) *= alphac;
x0 += 2 * incx;
}
}
}
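
Functionally, the kernel above reduces to the scalar reference below (a sketch for clarity; the name zdscal_ref is illustrative). Note that conjalpha has no effect here because alpha is real.

    // Reference semantics of bli_zdscalv_zen_int10: multiply the real and
    // imaginary parts of each of the n dcomplex elements by the real scalar
    // alpha, walking the vector with stride incx.
    static void zdscal_ref( dim_t n, double alpha, dcomplex* x, inc_t incx )
    {
        for ( dim_t i = 0; i < n; ++i )
        {
            x[ i * incx ].real *= alpha;
            x[ i * incx ].imag *= alpha;
        }
    }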


@@ -385,4 +385,13 @@ void bli_dnorm2fv_unb_var1_avx
double* x, inc_t incx,
double* norm,
cntx_t* cntx
);
void bli_zdscalv_zen_int10
(
conj_t conjalpha,
dim_t n,
double* restrict alpha,
dcomplex* restrict x, inc_t incx,
cntx_t* restrict cntx
);