diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c
index 2a80636b0..195a14ee5 100644
--- a/frame/compat/bla_scal_amd.c
+++ b/frame/compat/bla_scal_amd.c
@@ -355,6 +355,176 @@ void dscal_
   dscal_blis_impl( n, alpha, x, incx );
 }
 
+void zdscal_blis_impl
+     (
+       const f77_int* n,
+       const double*  alpha,
+       dcomplex* x, const f77_int* incx
+     )
+{
+    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
+    AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', (void *) alpha, *n, *incx );
+    dim_t n0;
+    dcomplex* x0;
+    inc_t incx0;
+    /* Initialize BLIS. */
+    //bli_init_auto();
+
+    /* Convert/typecast negative values of n to zero. */
+    if ( *n < 0 ) n0 = ( dim_t )0;
+    else          n0 = ( dim_t )(*n);
+
+    if (*n == 0 || alpha == NULL) {
+        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
+        return;
+    }
+
+    /* If the input increments are negative, adjust the pointers so we can
+       use positive increments instead. */
+    if ( *incx < 0 )
+    {
+        /* The semantics of negative stride in BLAS are that the vector
+           operand be traversed in reverse order. (Another way to think
+           of this is that negative strides effectively reverse the order
+           of the vector, but without any explicit data movements.) This
+           is also how BLIS interprets negative strides. The difference
+           is that with BLAS, the caller *always* passes in the 0th (i.e.,
+           top-most or left-most) element of the vector, even when the
+           stride is negative. By contrast, in BLIS, negative strides are
+           used *relative* to the vector address as it is given. Thus, in
+           BLIS, if this backwards traversal is desired, the caller *must*
+           pass in the address to the (n-1)th (i.e., the bottom-most or
+           right-most) element along with a negative stride. */
+
+        x0    = (x) + (n0-1)*(-*incx);
+        incx0 = ( inc_t )(*incx);
+    }
+    else
+    {
+        x0    = (x);
+        incx0 = ( inc_t )(*incx);
+    }
+
+    // This function is invoked on all architectures including 'generic'.
+    // Non-AVX platforms will use the kernels derived from the context.
+    if ( bli_cpuid_is_avx_supported() == TRUE )
+    {
+#ifdef BLIS_ENABLE_OPENMP
+        // For sizes up to 10000 the optimal number of threads is 1, but to
+        // avoid the overhead of the OpenMP calls that case is handled outside
+        // the parallel region by calling zdscalv directly, so as to get
+        // maximum performance.
+        if ( n0 <= 10000 )
+        {
+            bli_zdscalv_zen_int10
+            (
+              BLIS_NO_CONJUGATE,
+              n0,
+              (double*) alpha,
+              x0, incx0,
+              NULL
+            );
+        }
+        else
+        {
+            rntm_t rntm_local;
+            bli_rntm_init_from_global( &rntm_local );
+            dim_t nt = bli_rntm_num_threads( &rntm_local );
+
+#ifdef AOCL_DYNAMIC
+            dim_t nt_ideal;
+
+            if      ( n0 <= 20000 )   nt_ideal = 4;
+            else if ( n0 <= 1000000 ) nt_ideal = 8;
+            else if ( n0 <= 2500000 ) nt_ideal = 12;
+            else if ( n0 <= 5000000 ) nt_ideal = 32;
+            else                      nt_ideal = 64;
+
+            nt = bli_min( nt_ideal, nt );
+#endif
+            dim_t n_elem_per_thread = n0 / nt;
+            dim_t n_elem_rem        = n0 % nt;
+
+            #pragma omp parallel num_threads( nt )
+            {
+                // The following conditions handle the optimal distribution
+                // of the load among the threads.
+                // Say we have n0 = 50 & nt = 4.
+                // Then each thread gets 12 ( n0 / nt ) elements, with 2
+                // elements remaining. Each remaining element is given to
+                // one of the last threads.
+                // So t0, t1, t2 and t3 get 12, 12, 13 and 13 elements,
+                // respectively.
+                dim_t t_id = omp_get_thread_num();
+                dim_t npt, offset;
+
+                if ( t_id < ( nt - n_elem_rem ) )
+                {
+                    npt    = n_elem_per_thread;
+                    offset = t_id * npt * incx0;
+                }
+                else
+                {
+                    npt    = n_elem_per_thread + 1;
+                    offset = ( ( t_id * n_elem_per_thread ) +
+                               ( t_id - ( nt - n_elem_rem ) ) ) * incx0;
+                }
+
+                bli_zdscalv_zen_int10
+                (
+                  BLIS_NO_CONJUGATE,
+                  npt,
+                  (double *) alpha,
+                  x0 + offset, incx0,
+                  NULL
+                );
+            }
+        }
+#else
+        // Default call to zdscalv for single-threaded work.
+        bli_zdscalv_zen_int10
+        (
+          BLIS_NO_CONJUGATE,
+          n0,
+          (double *) alpha,
+          x0, incx0,
+          NULL
+        );
+#endif
+    }
+    else
+    {
+        // Sub-optimal implementation for zdscal: cast alpha to the double
+        // complex domain and perform the scaling through the zscalv BLIS
+        // interface.
+        dcomplex alpha_cast;
+        PASTEMAC2(d,z,copys)( *alpha, alpha_cast );
+
+        /* Call BLIS interface. */ \
+        PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) \
+        ( \
+          BLIS_NO_CONJUGATE, \
+          n0, \
+          &alpha_cast, \
+          x0, incx0, \
+          NULL, \
+          NULL \
+        ); \
+    }
+
+    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
+}
+
+void zdscal_
+     (
+       const f77_int* n,
+       const double*  alpha,
+       dcomplex* x, const f77_int* incx
+     )
+{
+    zdscal_blis_impl( n, alpha, x, incx );
+}
+
 
 INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv )
 #endif
diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h
index 49c79cb8a..9e1b4d70f 100644
--- a/frame/include/bli_gentfunc_macro_defs.h
+++ b/frame/include/bli_gentfunc_macro_defs.h
@@ -159,8 +159,8 @@ GENTFUNCR2( dcomplex, double, z, d, blasname, blisname ) \
 GENTFUNCSCAL( scomplex, scomplex, c, , blasname, blisname ) \
 GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \
-GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \
-GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname )
+GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname )
+// GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname )
 
 
 #define INSERT_GENTFUNCSCAL_BLAS( blasname, blisname ) \
diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c
index 7146e8687..9cea24680 100644
--- a/kernels/zen/1/bli_scalv_zen_int10.c
+++ b/kernels/zen/1/bli_scalv_zen_int10.c
@@ -578,3 +578,225 @@ void bli_dscalv_zen_int10
 	}
 }
 
+void bli_zdscalv_zen_int10
+     (
+       conj_t             conjalpha,
+       dim_t              n,
+       double*   restrict alpha,
+       dcomplex* restrict x, inc_t incx,
+       cntx_t*   restrict cntx
+     )
+{
+    dim_t i = 0;
+    const dim_t n_elem_per_reg = 4;    // number of double elements per ymm register
+
+    double* restrict x0 = (double*) x;
+    const double alphac = *alpha;
+
+    if ( incx == 1 )
+    {
+        __m256d alphav;
+        __m256d xv[15];
+
+        alphav = _mm256_broadcast_sd( alpha );
+
+        for ( ; ( i + 29 ) < n; i += 30 )
+        {
+            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
+            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
+            xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
+            xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
+            xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
+            xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
+            xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
+            xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );
+            xv[8] = _mm256_loadu_pd( x0 + 8 * n_elem_per_reg );
+            xv[9] = _mm256_loadu_pd( x0 + 9 * n_elem_per_reg );
+            xv[10] = _mm256_loadu_pd( x0 + 10 * n_elem_per_reg );
+            xv[11] = _mm256_loadu_pd( x0 + 11 * n_elem_per_reg );
+            xv[12] = _mm256_loadu_pd( x0 + 12 * n_elem_per_reg );
+            xv[13] = _mm256_loadu_pd( x0 + 13 * n_elem_per_reg );
+            xv[14] = _mm256_loadu_pd( x0 + 14 * n_elem_per_reg );
+
+            xv[0] = _mm256_mul_pd( alphav, xv[0] );
+            xv[1] = _mm256_mul_pd( alphav, xv[1] );
+            xv[2] = _mm256_mul_pd( alphav, xv[2] );
+            xv[3] = _mm256_mul_pd( alphav, xv[3] );
+            xv[4] = _mm256_mul_pd( alphav, xv[4] );
+            xv[5] = _mm256_mul_pd( alphav, xv[5] );
+            xv[6] = _mm256_mul_pd( alphav, xv[6] );
+            xv[7] = _mm256_mul_pd( alphav, xv[7] );
+            xv[8] = _mm256_mul_pd( alphav, xv[8] );
+            xv[9] = _mm256_mul_pd( alphav, xv[9] );
+            xv[10] = _mm256_mul_pd( alphav, xv[10] );
+            xv[11] = _mm256_mul_pd( alphav, xv[11] );
+            xv[12] = _mm256_mul_pd( alphav, xv[12] );
+            xv[13] = _mm256_mul_pd( alphav, xv[13] );
+            xv[14] = _mm256_mul_pd( alphav, xv[14] );
+
+            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
+            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
+            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
+            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
+            _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
+            _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
+            _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
+            _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );
+            _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), xv[8] );
+            _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), xv[9] );
+            _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), xv[10] );
+            _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), xv[11] );
+            _mm256_storeu_pd( (x0 + 12*n_elem_per_reg), xv[12] );
+            _mm256_storeu_pd( (x0 + 13*n_elem_per_reg), xv[13] );
+            _mm256_storeu_pd( (x0 + 14*n_elem_per_reg), xv[14] );
+
+            x0 += 15 * n_elem_per_reg;
+        }
+
+        for ( ; ( i + 23 ) < n; i += 24 )
+        {
+            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
+            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
+            xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
+            xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
+            xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
+            xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
+            xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
+            xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );
+            xv[8] = _mm256_loadu_pd( x0 + 8 * n_elem_per_reg );
+            xv[9] = _mm256_loadu_pd( x0 + 9 * n_elem_per_reg );
+            xv[10] = _mm256_loadu_pd( x0 + 10 * n_elem_per_reg );
+            xv[11] = _mm256_loadu_pd( x0 + 11 * n_elem_per_reg );
+
+            xv[0] = _mm256_mul_pd( alphav, xv[0] );
+            xv[1] = _mm256_mul_pd( alphav, xv[1] );
+            xv[2] = _mm256_mul_pd( alphav, xv[2] );
+            xv[3] = _mm256_mul_pd( alphav, xv[3] );
+            xv[4] = _mm256_mul_pd( alphav, xv[4] );
+            xv[5] = _mm256_mul_pd( alphav, xv[5] );
+            xv[6] = _mm256_mul_pd( alphav, xv[6] );
+            xv[7] = _mm256_mul_pd( alphav, xv[7] );
+            xv[8] = _mm256_mul_pd( alphav, xv[8] );
+            xv[9] = _mm256_mul_pd( alphav, xv[9] );
+            xv[10] = _mm256_mul_pd( alphav, xv[10] );
+            xv[11] = _mm256_mul_pd( alphav, xv[11] );
+
+            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
+            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
+            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
+            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
+            _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
+            _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
+            _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
+            _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );
+            _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), xv[8] );
+            _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), xv[9] );
+            _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), xv[10] );
+            _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), xv[11] );
+
+            x0 += 12 * n_elem_per_reg;
+        }
+
+        for ( ; ( i + 15 ) < n; i += 16 )
+        {
+            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
+            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
+            xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
+            xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
+            xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg );
+            xv[5] = _mm256_loadu_pd( x0 + 5 * n_elem_per_reg );
+            xv[6] = _mm256_loadu_pd( x0 + 6 * n_elem_per_reg );
+            xv[7] = _mm256_loadu_pd( x0 + 7 * n_elem_per_reg );
+
+            xv[0] = _mm256_mul_pd( alphav, xv[0] );
+            xv[1] = _mm256_mul_pd( alphav, xv[1] );
+            xv[2] = _mm256_mul_pd( alphav, xv[2] );
+            xv[3] = _mm256_mul_pd( alphav, xv[3] );
+            xv[4] = _mm256_mul_pd( alphav, xv[4] );
+            xv[5] = _mm256_mul_pd( alphav, xv[5] );
+            xv[6] = _mm256_mul_pd( alphav, xv[6] );
+            xv[7] = _mm256_mul_pd( alphav, xv[7] );
+
+            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
+            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
+            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
+            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
+            _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] );
+            _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] );
+            _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] );
+            _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] );
+
+            x0 += 8 * n_elem_per_reg;
+        }
+
+        for ( ; ( i + 7 ) < n; i += 8 )
+        {
+            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
+            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
+            xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg );
+            xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg );
+
+            xv[0] = _mm256_mul_pd( alphav, xv[0] );
+            xv[1] = _mm256_mul_pd( alphav, xv[1] );
+            xv[2] = _mm256_mul_pd( alphav, xv[2] );
+            xv[3] = _mm256_mul_pd( alphav, xv[3] );
+
+            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
+            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
+            _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] );
+            _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] );
+
+            x0 += 4 * n_elem_per_reg;
+        }
+
+        for ( ; ( i + 3 ) < n; i += 4 )
+        {
+            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
+            xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg );
+
+            xv[0] = _mm256_mul_pd( alphav, xv[0] );
+            xv[1] = _mm256_mul_pd( alphav, xv[1] );
+
+            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
+            _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] );
+
+            x0 += 2 * n_elem_per_reg;
+        }
+
+        for ( ; ( i + 1 ) < n; i += 2 )
+        {
+            xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg );
+
+            xv[0] = _mm256_mul_pd( alphav, xv[0] );
+
+            _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] );
+
+            x0 += 1 * n_elem_per_reg;
+        }
+
+        for ( ; i < n; i++ )
+        {
+            ( *x0 )         *= alphac;
+            ( *( x0 + 1 ) ) *= alphac;
+
+            x0 += 2 * incx;
+        }
+
+        // Issue the vzeroupper instruction to clear the upper lanes of the ymm
+        // registers. This avoids a performance penalty caused by false
+        // dependencies when transitioning from AVX to SSE instructions (which
+        // may occur in code executed after this kernel if BLIS is compiled
+        // with -mfpmath=sse).
+        _mm256_zeroupper();
+    }
+    else
+    {
+        for ( ; i < n; ++i )
+        {
+            ( *x0 )         *= alphac;
+            ( *( x0 + 1 ) ) *= alphac;
+
+            x0 += 2 * incx;
+        }
+    }
+}
\ No newline at end of file
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h
index 210052d06..193fcd0fb 100644
--- a/kernels/zen/bli_kernels_zen.h
+++ b/kernels/zen/bli_kernels_zen.h
@@ -385,4 +385,13 @@ void bli_dnorm2fv_unb_var1_avx
        double* x, inc_t incx,
        double* norm,
        cntx_t* cntx
+     );
+
+void bli_zdscalv_zen_int10
+     (
+       conj_t             conjalpha,
+       dim_t              n,
+       double*   restrict alpha,
+       dcomplex* restrict x, inc_t incx,
+       cntx_t*   restrict cntx
      );
\ No newline at end of file
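
Note (illustration only, not part of the patch): the per-thread partitioning used in the OpenMP path of zdscal_blis_impl can be checked in isolation. The standalone C sketch below reproduces the count/offset arithmetic for the n0 = 50, nt = 4 case described in the comments; the variable names mirror the patch, but the program itself is just a reviewer-side check.

    /* Sketch: per-thread element counts and offsets as computed in the
       OpenMP path of zdscal_blis_impl (unit stride, as in the comment). */
    #include <stdio.h>

    int main( void )
    {
        long n0    = 50;   /* vector length from the comment's example */
        long nt    = 4;    /* number of threads from the comment's example */
        long incx0 = 1;    /* unit stride */

        long n_elem_per_thread = n0 / nt;   /* 12 */
        long n_elem_rem        = n0 % nt;   /*  2 */

        for ( long t_id = 0; t_id < nt; ++t_id )
        {
            long npt, offset;

            if ( t_id < ( nt - n_elem_rem ) )
            {
                /* The first (nt - rem) threads get the base share. */
                npt    = n_elem_per_thread;
                offset = t_id * npt * incx0;
            }
            else
            {
                /* The last rem threads get one extra element each. */
                npt    = n_elem_per_thread + 1;
                offset = ( ( t_id * n_elem_per_thread )
                         + ( t_id - ( nt - n_elem_rem ) ) ) * incx0;
            }

            printf( "t%ld: %ld elements at offset %ld\n", t_id, npt, offset );
        }

        /* Prints 12 @ 0, 12 @ 12, 13 @ 24, 13 @ 37: all 50 elements covered. */
        return 0;
    }

A related design note on the kernel: because alpha is real, bli_zdscalv_zen_int10 can treat the n dcomplex elements as 2n doubles, broadcast alpha across all four lanes of a ymm register, and scale real and imaginary parts identically. That is why n_elem_per_reg is 4 doubles and the widest loop advances i by 30 complex elements (15 registers x 4 doubles) per iteration.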