Vectorized and parallelized zdscal routine
- Implemented an optimized intrinsic kernel for zdscalv for the cases where AVX2 is supported.
- Added multithreaded support for the same.
- The optimal number of threads is calculated based on the input size.

AMD-Internal: [CPUPL-2602]
Change-Id: I4d05c3b1cc365a7770703286a89c6dce3875c067
Committed by: Arnav Sharma
Parent: 9c292b79e2
Commit: 90f915d3a9
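For context, here is a minimal sketch of how the new routine is reached through the BLAS compatibility interface. It assumes BLIS is built with its BLAS layer enabled (so blis.h declares zdscal_ together with the f77_int and dcomplex types); the vector length and values are illustrative only. zdscal computes x := alpha * x, where alpha is real and x is double-complex, so both the real and imaginary parts of every element are scaled.

    #include "blis.h"   // f77_int, dcomplex, and the zdscal_ prototype (BLAS layer)

    int main( void )
    {
        // A vector of 8 double-complex elements with unit stride (illustrative).
        dcomplex x[8];
        for ( int i = 0; i < 8; i++ ) { x[i].real = i; x[i].imag = -i; }

        f77_int n     = 8;
        f77_int incx  = 1;
        double  alpha = 0.5;

        // Scales real and imaginary parts of every element of x by alpha.
        zdscal_( &n, &alpha, x, &incx );

        return 0;
    }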
@@ -355,6 +355,176 @@ void dscal_
    dscal_blis_impl( n, alpha, x, incx );
}

void zdscal_blis_impl
     (
       const f77_int* n,
       const double*  alpha,
       dcomplex* x, const f77_int* incx
     )
{
    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
    AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', (void *) alpha, *n, *incx );
    dim_t n0;
    dcomplex* x0;
    inc_t incx0;

    /* Initialize BLIS. */
    //bli_init_auto();

    /* Convert/typecast negative values of n to zero. */
    if ( *n < 0 ) n0 = ( dim_t )0;
    else          n0 = ( dim_t )(*n);

    if (*n == 0 || alpha == NULL) {
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
        return;
    }

    /* If the input increments are negative, adjust the pointers so we can
       use positive increments instead. */
    if ( *incx < 0 )
    {
        /* The semantics of negative stride in BLAS are that the vector
           operand be traversed in reverse order. (Another way to think
           of this is that negative strides effectively reverse the order
           of the vector, but without any explicit data movements.) This
           is also how BLIS interprets negative strides. The difference
           is that with BLAS, the caller *always* passes in the 0th (i.e.,
           top-most or left-most) element of the vector, even when the
           stride is negative. By contrast, in BLIS, negative strides are
           used *relative* to the vector address as it is given. Thus, in
           BLIS, if this backwards traversal is desired, the caller *must*
           pass in the address to the (n-1)th (i.e., the bottom-most or
           right-most) element along with a negative stride. */

        x0    = (x) + (n0-1)*(-*incx);
        incx0 = ( inc_t )(*incx);
    }
    else
    {
        x0    = (x);
        incx0 = ( inc_t )(*incx);
    }
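    /* Illustrative example: with *n = 4 and *incx = -1, the BLAS caller
       passes the address of the first (0th) element, so the adjustment
       above yields x0 = x + (4-1)*1 = x + 3 (the last element), and
       incx0 = -1 then walks the vector backwards from there. */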

    // This function is invoked on all architectures including 'generic'.
    // Non-AVX platforms will use the kernels derived from the context.
    if ( bli_cpuid_is_avx_supported() == TRUE )
    {
#ifdef BLIS_ENABLE_OPENMP
        // For sizes up to 10000 the optimal number of threads is 1. To avoid
        // the overhead of the OpenMP runtime for these sizes, the zdscalv
        // kernel is called directly here, single-threaded.
        if ( n0 <= 10000 )
        {
            bli_zdscalv_zen_int10
            (
              BLIS_NO_CONJUGATE,
              n0,
              (double*) alpha,
              x0, incx0,
              NULL
            );
        }
        else
        {
            rntm_t rntm_local;
            bli_rntm_init_from_global( &rntm_local );
            dim_t nt = bli_rntm_num_threads( &rntm_local );

#ifdef AOCL_DYNAMIC
            // Pick an ideal thread count from the problem size, then cap it
            // by the number of threads available from the runtime.
            dim_t nt_ideal;

            if      ( n0 <= 20000   ) nt_ideal = 4;
            else if ( n0 <= 1000000 ) nt_ideal = 8;
            else if ( n0 <= 2500000 ) nt_ideal = 12;
            else if ( n0 <= 5000000 ) nt_ideal = 32;
            else                      nt_ideal = 64;

            nt = bli_min( nt_ideal, nt );
#endif
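            // Illustrative example: for n0 = 1,500,000 with 16 runtime
            // threads available, the brackets above select nt_ideal = 12
            // and nt = bli_min( 12, 16 ) = 12.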
            dim_t n_elem_per_thread = n0 / nt;
            dim_t n_elem_rem        = n0 % nt;

            #pragma omp parallel num_threads( nt )
            {
                // The following conditions distribute the load among the
                // threads: the first ( nt - n_elem_rem ) threads get
                // n_elem_per_thread elements each, and the remaining threads
                // get one element more.
                // Say we have n0 = 50 and nt = 4. Each thread gets
                // 12 ( n0 / nt ) elements, with 2 elements left over; the
                // leftovers go to the last threads, so t0, t1, t2 and t3 get
                // 12, 12, 13 and 13 elements, respectively.
                // (See the standalone sketch after this file's diff.)
                dim_t t_id = omp_get_thread_num();
                dim_t npt, offset;

                if ( t_id < ( nt - n_elem_rem ) )
                {
                    npt    = n_elem_per_thread;
                    offset = t_id * npt * incx0;
                }
                else
                {
                    npt    = n_elem_per_thread + 1;
                    offset = ( ( t_id * n_elem_per_thread ) +
                               ( t_id - ( nt - n_elem_rem ) ) ) * incx0;
                }

                bli_zdscalv_zen_int10
                (
                  BLIS_NO_CONJUGATE,
                  npt,
                  (double *) alpha,
                  x0 + offset, incx0,
                  NULL
                );
            }
        }
#else
        // Without OpenMP support, default to a single-threaded call to the
        // zdscalv kernel.
        bli_zdscalv_zen_int10
        (
          BLIS_NO_CONJUGATE,
          n0,
          (double *) alpha,
          x0, incx0,
          NULL
        );
#endif
    }
    else
    {
        // Sub-optimal path for zdscal: cast alpha to the double complex
        // domain and call the zscalv implementation.
        dcomplex alpha_cast;
        PASTEMAC2(d,z,copys)( *alpha, alpha_cast );

        /* Call BLIS interface. */
        PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF)
        (
          BLIS_NO_CONJUGATE,
          n0,
          &alpha_cast,
          x0, incx0,
          NULL,
          NULL
        );
    }

    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
}

void zdscal_
     (
       const f77_int* n,
       const double* alpha,
       dcomplex* x, const f77_int* incx
     )
{
    zdscal_blis_impl( n, alpha, x, incx );
}


INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv )

#endif
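The thread partitioning above can be exercised in isolation. Below is a minimal, self-contained sketch (not part of the diff) that reproduces the npt/offset arithmetic for the n0 = 50, nt = 4 case from the comment and checks that the chunks tile the vector exactly once; incx0 is assumed to be 1 for simplicity.

    #include <assert.h>
    #include <stdio.h>

    int main( void )
    {
        const long n0 = 50, nt = 4, incx0 = 1;     // sizes from the comment above
        const long n_elem_per_thread = n0 / nt;    // 12
        const long n_elem_rem        = n0 % nt;    // 2

        long total = 0;
        for ( long t_id = 0; t_id < nt; t_id++ )
        {
            long npt, offset;
            if ( t_id < ( nt - n_elem_rem ) )
            {
                npt    = n_elem_per_thread;
                offset = t_id * npt * incx0;
            }
            else
            {
                npt    = n_elem_per_thread + 1;
                offset = ( ( t_id * n_elem_per_thread ) +
                           ( t_id - ( nt - n_elem_rem ) ) ) * incx0;
            }
            // Expected: t0 -> (12, 0), t1 -> (12, 12), t2 -> (13, 24), t3 -> (13, 37).
            printf( "t%ld: npt = %ld, offset = %ld\n", t_id, npt, offset );
            total += npt;
        }
        assert( total == n0 );   // every element is scaled exactly once
        return 0;
    }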

@@ -159,8 +159,8 @@ GENTFUNCR2( dcomplex, double, z, d, blasname, blisname )
\
GENTFUNCSCAL( scomplex, scomplex, c, , blasname, blisname ) \
GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \
-GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \
-GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname )
+GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname )
+// GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname )


#define INSERT_GENTFUNCSCAL_BLAS( blasname, blisname ) \

@@ -578,3 +578,225 @@ void bli_dscalv_zen_int10
    }
}

void bli_zdscalv_zen_int10
     (
       conj_t             conjalpha,
       dim_t              n,
       double*   restrict alpha,
       dcomplex* restrict x, inc_t incx,
       cntx_t*   restrict cntx
     )
{
    dim_t i = 0;
    const dim_t n_elem_per_reg = 4;    // number of doubles per ymm register
                                       // (i.e., two dcomplex elements)

    double* restrict x0     = (double*) x;
    const double     alphac = *alpha;

    if ( incx == 1 )
    {
        __m256d alphav;
        __m256d xv[15];

        alphav = _mm256_broadcast_sd( alpha );

        // Process 30 dcomplex elements (15 ymm registers of 4 doubles each)
        // per iteration.
        for ( ; ( i + 29 ) < n; i += 30 )
        {
            xv[0]  = _mm256_loadu_pd( x0 + 0  * n_elem_per_reg );
            xv[1]  = _mm256_loadu_pd( x0 + 1  * n_elem_per_reg );
            xv[2]  = _mm256_loadu_pd( x0 + 2  * n_elem_per_reg );
            xv[3]  = _mm256_loadu_pd( x0 + 3  * n_elem_per_reg );
            xv[4]  = _mm256_loadu_pd( x0 + 4  * n_elem_per_reg );
            xv[5]  = _mm256_loadu_pd( x0 + 5  * n_elem_per_reg );
            xv[6]  = _mm256_loadu_pd( x0 + 6  * n_elem_per_reg );
            xv[7]  = _mm256_loadu_pd( x0 + 7  * n_elem_per_reg );
            xv[8]  = _mm256_loadu_pd( x0 + 8  * n_elem_per_reg );
            xv[9]  = _mm256_loadu_pd( x0 + 9  * n_elem_per_reg );
            xv[10] = _mm256_loadu_pd( x0 + 10 * n_elem_per_reg );
            xv[11] = _mm256_loadu_pd( x0 + 11 * n_elem_per_reg );
            xv[12] = _mm256_loadu_pd( x0 + 12 * n_elem_per_reg );
            xv[13] = _mm256_loadu_pd( x0 + 13 * n_elem_per_reg );
            xv[14] = _mm256_loadu_pd( x0 + 14 * n_elem_per_reg );

            xv[0]  = _mm256_mul_pd( alphav, xv[0] );
            xv[1]  = _mm256_mul_pd( alphav, xv[1] );
            xv[2]  = _mm256_mul_pd( alphav, xv[2] );
            xv[3]  = _mm256_mul_pd( alphav, xv[3] );
            xv[4]  = _mm256_mul_pd( alphav, xv[4] );
            xv[5]  = _mm256_mul_pd( alphav, xv[5] );
            xv[6]  = _mm256_mul_pd( alphav, xv[6] );
            xv[7]  = _mm256_mul_pd( alphav, xv[7] );
            xv[8]  = _mm256_mul_pd( alphav, xv[8] );
            xv[9]  = _mm256_mul_pd( alphav, xv[9] );
            xv[10] = _mm256_mul_pd( alphav, xv[10] );
            xv[11] = _mm256_mul_pd( alphav, xv[11] );
            xv[12] = _mm256_mul_pd( alphav, xv[12] );
            xv[13] = _mm256_mul_pd( alphav, xv[13] );
            xv[14] = _mm256_mul_pd( alphav, xv[14] );

            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg),  xv[0] );
            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg),  xv[1] );
            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg),  xv[2] );
            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg),  xv[3] );
            _mm256_storeu_pd( (x0 + 4*n_elem_per_reg),  xv[4] );
            _mm256_storeu_pd( (x0 + 5*n_elem_per_reg),  xv[5] );
            _mm256_storeu_pd( (x0 + 6*n_elem_per_reg),  xv[6] );
            _mm256_storeu_pd( (x0 + 7*n_elem_per_reg),  xv[7] );
            _mm256_storeu_pd( (x0 + 8*n_elem_per_reg),  xv[8] );
            _mm256_storeu_pd( (x0 + 9*n_elem_per_reg),  xv[9] );
            _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), xv[10] );
            _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), xv[11] );
            _mm256_storeu_pd( (x0 + 12*n_elem_per_reg), xv[12] );
            _mm256_storeu_pd( (x0 + 13*n_elem_per_reg), xv[13] );
            _mm256_storeu_pd( (x0 + 14*n_elem_per_reg), xv[14] );

            x0 += 15 * n_elem_per_reg;
        }

        // 24 dcomplex elements (12 registers) per iteration.
        for ( ; ( i + 23 ) < n; i += 24 )
        {
            xv[0]  = _mm256_loadu_pd( x0 + 0  * n_elem_per_reg );
            xv[1]  = _mm256_loadu_pd( x0 + 1  * n_elem_per_reg );
            xv[2]  = _mm256_loadu_pd( x0 + 2  * n_elem_per_reg );
            xv[3]  = _mm256_loadu_pd( x0 + 3  * n_elem_per_reg );
            xv[4]  = _mm256_loadu_pd( x0 + 4  * n_elem_per_reg );
            xv[5]  = _mm256_loadu_pd( x0 + 5  * n_elem_per_reg );
            xv[6]  = _mm256_loadu_pd( x0 + 6  * n_elem_per_reg );
            xv[7]  = _mm256_loadu_pd( x0 + 7  * n_elem_per_reg );
            xv[8]  = _mm256_loadu_pd( x0 + 8  * n_elem_per_reg );
            xv[9]  = _mm256_loadu_pd( x0 + 9  * n_elem_per_reg );
            xv[10] = _mm256_loadu_pd( x0 + 10 * n_elem_per_reg );
            xv[11] = _mm256_loadu_pd( x0 + 11 * n_elem_per_reg );

            xv[0]  = _mm256_mul_pd( alphav, xv[0] );
            xv[1]  = _mm256_mul_pd( alphav, xv[1] );
            xv[2]  = _mm256_mul_pd( alphav, xv[2] );
            xv[3]  = _mm256_mul_pd( alphav, xv[3] );
            xv[4]  = _mm256_mul_pd( alphav, xv[4] );
            xv[5]  = _mm256_mul_pd( alphav, xv[5] );
            xv[6]  = _mm256_mul_pd( alphav, xv[6] );
            xv[7]  = _mm256_mul_pd( alphav, xv[7] );
            xv[8]  = _mm256_mul_pd( alphav, xv[8] );
            xv[9]  = _mm256_mul_pd( alphav, xv[9] );
            xv[10] = _mm256_mul_pd( alphav, xv[10] );
            xv[11] = _mm256_mul_pd( alphav, xv[11] );

            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg),  xv[0] );
            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg),  xv[1] );
            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg),  xv[2] );
            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg),  xv[3] );
            _mm256_storeu_pd( (x0 + 4*n_elem_per_reg),  xv[4] );
            _mm256_storeu_pd( (x0 + 5*n_elem_per_reg),  xv[5] );
            _mm256_storeu_pd( (x0 + 6*n_elem_per_reg),  xv[6] );
            _mm256_storeu_pd( (x0 + 7*n_elem_per_reg),  xv[7] );
            _mm256_storeu_pd( (x0 + 8*n_elem_per_reg),  xv[8] );
            _mm256_storeu_pd( (x0 + 9*n_elem_per_reg),  xv[9] );
            _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), xv[10] );
            _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), xv[11] );

            x0 += 12 * n_elem_per_reg;
        }

        // 16 dcomplex elements (8 registers) per iteration.
        for ( ; ( i + 15 ) < n; i += 16 )
        {
            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
            xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
            xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
            xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
            xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
            xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
            xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );

            xv[0] = _mm256_mul_pd( alphav, xv[0] );
            xv[1] = _mm256_mul_pd( alphav, xv[1] );
            xv[2] = _mm256_mul_pd( alphav, xv[2] );
            xv[3] = _mm256_mul_pd( alphav, xv[3] );
            xv[4] = _mm256_mul_pd( alphav, xv[4] );
            xv[5] = _mm256_mul_pd( alphav, xv[5] );
            xv[6] = _mm256_mul_pd( alphav, xv[6] );
            xv[7] = _mm256_mul_pd( alphav, xv[7] );

            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
            _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
            _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
            _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
            _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );

            x0 += 8 * n_elem_per_reg;
        }

        // 8 dcomplex elements (4 registers) per iteration.
        for ( ; ( i + 7 ) < n; i += 8 )
        {
            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
            xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
            xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );

            xv[0] = _mm256_mul_pd( alphav, xv[0] );
            xv[1] = _mm256_mul_pd( alphav, xv[1] );
            xv[2] = _mm256_mul_pd( alphav, xv[2] );
            xv[3] = _mm256_mul_pd( alphav, xv[3] );

            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );

            x0 += 4 * n_elem_per_reg;
        }

        // 4 dcomplex elements (2 registers) per iteration.
        for ( ; ( i + 3 ) < n; i += 4 )
        {
            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );

            xv[0] = _mm256_mul_pd( alphav, xv[0] );
            xv[1] = _mm256_mul_pd( alphav, xv[1] );

            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );

            x0 += 2 * n_elem_per_reg;
        }

        // 2 dcomplex elements (1 register) per iteration.
        for ( ; ( i + 1 ) < n; i += 2 )
        {
            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );

            xv[0] = _mm256_mul_pd( alphav, xv[0] );

            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );

            x0 += 1 * n_elem_per_reg;
        }

        // Scalar cleanup: scale the real and imaginary parts of any
        // remaining elements one dcomplex element at a time.
        for ( ; i < n; i++ )
        {
            ( *x0 )         *= alphac;
            ( *( x0 + 1 ) ) *= alphac;

            x0 += 2 * incx;
        }

        // Issue vzeroupper instruction to clear upper lanes of ymm registers.
        // This avoids a performance penalty caused by false dependencies when
        // transitioning from AVX to SSE instructions (which may occur if BLIS
        // is compiled with -mfpmath=sse).
        _mm256_zeroupper();
    }
    else
    {
        // General-stride case: scale the real and imaginary parts of each
        // element one at a time.
        for ( ; i < n; ++i )
        {
            ( *x0 )         *= alphac;
            ( *( x0 + 1 ) ) *= alphac;

            x0 += 2 * incx;
        }
    }
}
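As a correctness reference for this kernel, the same operation can be written as a plain scalar loop. The following is a minimal sketch, assuming only the BLIS dcomplex type (with real/imag fields) and the signed dim_t/inc_t types; the helper name zdscalv_ref is hypothetical and not part of the diff.

    // Hypothetical reference routine: scales both parts of each element by
    // alpha, which is what bli_zdscalv_zen_int10 computes for any stride.
    static void zdscalv_ref( dim_t n, double alpha, dcomplex* x, inc_t incx )
    {
        for ( dim_t i = 0; i < n; i++ )
        {
            x[ i * incx ].real *= alpha;
            x[ i * incx ].imag *= alpha;
        }
    }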

@@ -385,4 +385,13 @@ void bli_dnorm2fv_unb_var1_avx
       double* x, inc_t incx,
       double* norm,
       cntx_t* cntx
     );

void bli_zdscalv_zen_int10
     (
       conj_t             conjalpha,
       dim_t              n,
       double*   restrict alpha,
       dcomplex* restrict x, inc_t incx,
       cntx_t*   restrict cntx
     );