From a8bc55c37322f6cf6d6f7c92ae1607faf660de39 Mon Sep 17 00:00:00 2001
From: "S, HariharaSudhan"
Date: Tue, 29 Mar 2022 18:05:59 +0530
Subject: [PATCH] Multithreaded SGEMV var 1 with smart threading

- Implemented an OpenMP-based standalone SGEMV kernel for row-major
  (var 1) multithreaded scenarios
- Smart threading is enabled when AOCL DYNAMIC is defined
- The number of threads is decided based on the input dimensions using
  smart threading

AMD-Internal: [CPUPL-1984]

Change-Id: I9b191e965ba7468e95aabcce21b35a533017502e
---
 frame/2/gemv/bli_gemv_unf_var1_amd.c | 128 ++++++++-
 kernels/zen/2/bli_gemv_zen_int_4.c   | 395 +++++++++++++++++++++++++++
 2 files changed, 522 insertions(+), 1 deletion(-)

diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c
index 7228c12f7..8295f3927 100644
--- a/frame/2/gemv/bli_gemv_unf_var1_amd.c
+++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c
@@ -332,6 +332,92 @@ void bli_dgemv_unf_var1
     AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
 }
 
+// Computes the optimal number of threads for the given input sizes and
+// fuse factor, returning it through nt
+void bli_sgemv_var1_smart_threading
+     (
+       dim_t m, dim_t n,
+       dim_t fuse,
+       dim_t* nt, dim_t nt_max
+     )
+{
+    // Calculate the amount of data processed per iteration
+    dim_t  n_per_loop    = n / fuse;
+    double data_per_iter = n_per_loop * m;
+
+    // Cast to double so that integer division does not truncate the ratio
+    double m_n_ratio = (double)m / (double)n;
+
+    // When n is less than the fuse factor
+    if(n_per_loop < 1)
+    {
+        *nt = 1;
+        return;
+    }
+
+    // There are two cases: when m < n (m/n <= 0.6), thread spawning is
+    // less aggressive than in the m > n and m == n cases
+    if(m_n_ratio <= 0.6)
+    {
+        // The boundaries are in units of data processed per iteration;
+        // this is the variable x in the equations below
+        const double lower_boundary  = 50000;
+        const double higher_boundary = 500000;
+
+        if(data_per_iter < lower_boundary)
+        {
+            double coeff_x  = 0.9148;
+            double constant = -1.6252;
+            // Number of threads = 0.9148 * log(x) - 1.6252
+            *nt = ceil(coeff_x * log(data_per_iter) + constant);
+        }
+        else if(data_per_iter < higher_boundary)
+        {
+            double coeff_x  = 10.23;
+            double constant = -82.332;
+            // Number of threads = 10.23 * log(x) - 82.332
+            *nt = ceil(coeff_x * log(data_per_iter) + constant);
+        }
+        else
+        {
+            // When the amount of data to be processed lies above both
+            // boundaries, spawn the maximum number of threads set
+            *nt = nt_max;
+        }
+    }
+    else
+    {
+        // The boundaries are in units of data processed per iteration;
+        // this is the variable x in the equations below
+        const double lower_boundary  = 50000;
+        const double higher_boundary = 360000;
+
+        if(data_per_iter < lower_boundary)
+        {
+            double coeff_x2 = -2E-09;
+            double coeff_x  = 0.0002;
+            double constant = 1.0234;
+            // Number of threads = -2E-09*x^2 + 0.0002 * x + 1.0234
+            *nt = ceil(coeff_x2 * (data_per_iter * data_per_iter) + coeff_x * data_per_iter + constant);
+        }
+        else if(data_per_iter < higher_boundary)
+        {
+            double coeff_x  = 16.917;
+            double constant = -164.82;
+            // Number of threads = 16.917 * log(x) - 164.82
+            *nt = ceil(coeff_x * log(data_per_iter) + constant);
+        }
+        else
+        {
+            // When the amount of data to be processed lies above both
+            // boundaries, spawn the maximum number of threads set
+            *nt = nt_max;
+        }
+    }
+
+    // If the calculated number of threads exceeds the user-provided
+    // maximum, fall back to the user-provided value
+    if(*nt > nt_max)
+        *nt = nt_max;
+}
+
 void bli_sgemv_unf_var1
      (
        trans_t transa,
@@ -407,7 +493,46 @@ void bli_sgemv_unf_var1
         return;
     }
 
-    /* Query the context for the kernel function pointer and fusing factor. */
+// If both multithreading and OpenMP are enabled, GEMV will multithread
+#if defined(BLIS_ENABLE_MULTITHREADING) && defined(BLIS_ENABLE_OPENMP)
+    dim_t nt, nt_max;
+
+    rntm_t rntm_obj;
+
+    b_fuse = 4;
+
+    // Initialize a local runtime with global settings.
+    bli_rntm_init_from_global( &rntm_obj );
+
+    // Query the total number of threads from the rntm_t object.
+    nt_max = bli_rntm_num_threads( &rntm_obj );
+
+    // Set the thread count to the maximum number of threads provided
+    nt = nt_max;
+
+    // Enable smart threading when AOCL dynamic is enabled
+    #ifdef AOCL_DYNAMIC
+    bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max);
+    #endif
+
+    // Pass the input parameters along with the number of threads to be used
+    bli_multi_sgemv_4x2
+    (
+      conja,
+      conjx,
+      n_elem,
+      n_iter,
+      alpha,
+      a, cs_at, rs_at,
+      x, incx,
+      beta,
+      y, incy,
+      cntx,
+      nt
+    );
+
+#else
+
     b_fuse = 8;
 
     for ( i = 0; i < n_iter; i += f )
@@ -434,6 +559,7 @@ void bli_sgemv_unf_var1
         );
     }
 
+#endif
 }
 
 INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 )
diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c
index b3c92b551..74904605e 100644
--- a/kernels/zen/2/bli_gemv_zen_int_4.c
+++ b/kernels/zen/2/bli_gemv_zen_int_4.c
@@ -35,6 +35,24 @@
 #include "immintrin.h"
 #include "blis.h"
 
+/* Union data structure to access AVX registers.
+   One 256-bit AVX register holds 8 SP elements. */
+typedef union
+{
+    __m256 v;
+    float  f[8] __attribute__((aligned(64)));
+} v8sf_t;
+
+/* Union data structure to access AVX registers.
+   One 128-bit AVX register holds 4 SP elements. */
+typedef union
+{
+    __m128 v;
+    float  f[4] __attribute__((aligned(64)));
+} v4sf_t;
+
 /*
   This implementation uses 512 bits of cache line efficiently for
   column stored matrix and vectors.
@@ -477,3 +495,380 @@ void bli_cgemv_zen_int_4x4
     }
 }
 
+/*
+  Performs multithreaded GEMV for the float datatype.
+  All parameters are similar to the single-threaded GEMV except
+  n_threads, which specifies the number of threads to be used.
+*/
+void bli_multi_sgemv_4x2
+     (
+       conj_t           conjat,
+       conj_t           conjx,
+       dim_t            m,
+       dim_t            b_n,
+       float*  restrict alpha,
+       float*  restrict a, inc_t inca, inc_t lda,
+       float*  restrict x, inc_t incx,
+       float*  restrict beta,
+       float*  restrict y, inc_t incy,
+       cntx_t* restrict cntx,
+       dim_t            n_threads
+     )
+{
+    const dim_t b_fuse          = 4;
+    const dim_t n_elem_per_reg  = 8;
+    dim_t       total_iteration = 0;
+
+    // If the b_n dimension is zero, y is empty and there is no computation.
+    if (bli_zero_dim1(b_n))
+        return;
+
+    // If the m dimension is zero, or if alpha is zero, the computation
+    // simplifies to updating y.
+    if (bli_zero_dim1(m) || PASTEMAC(s, eq0)(*alpha))
+    {
+        bli_sscalv_zen_int10(
+            BLIS_NO_CONJUGATE,
+            b_n,
+            beta,
+            y, incy,
+            cntx);
+        return;
+    }
+
+    // If b_n is less than the fusing factor, perform the entire
+    // operation as a loop over dotxv.
+    if (b_n < b_fuse)
+    {
+        for (dim_t i = 0; i < b_n; ++i)
+        {
+            float *a1   = a + (0) * inca + (i) * lda;
+            float *x1   = x + (0) * incx;
+            float *psi1 = y + (i) * incy;
+
+            bli_sdotxv_zen_int(
+                conjat,
+                conjx,
+                m,
+                alpha,
+                a1, inca,
+                x1, incx,
+                beta,
+                psi1,
+                cntx);
+        }
+        return;
+    }
+
+    // Calculate the total number of multithreaded iterations
+    total_iteration = b_n / b_fuse;
+
+#pragma omp parallel for num_threads(n_threads)
+    for (dim_t j = 0; j < total_iteration; j++)
+    {
+        float *A1 = a + (b_fuse * j) * lda;
+        float *x1 = x;
+        float *y1 = y + (b_fuse * j) * incy;
+
+        // Intermediate variables to hold the completed dot products
+        float rho0[4] = {0, 0, 0, 0};
+
+        // If vectorization is possible, perform the computation with
+        // vector instructions.
+        if (inca == 1 && incx == 1)
+        {
+            const dim_t n_iter_unroll = 2;
+
+            // Use the unrolling factor and the number of elements per register
+            // to compute the number of vectorized and leftover iterations.
+            dim_t l, unroll_inc, m_viter[2], m_left = m;
+
+            unroll_inc = n_elem_per_reg * n_iter_unroll;
+
+            m_viter[0] = m_left / unroll_inc;
+            m_left     = m_left % unroll_inc;
+
+            m_viter[1] = m_left / n_elem_per_reg;
+            m_left     = m_left % n_elem_per_reg;
+
+            // Set up pointers for x and the b_fuse columns of A (rows of A^T).
+            float *restrict x0 = x1;
+            float *restrict av[4];
+
+            av[0] = A1 + 0 * lda;
+            av[1] = A1 + 1 * lda;
+            av[2] = A1 + 2 * lda;
+            av[3] = A1 + 3 * lda;
+
+            // Initialize the b_fuse rho vector accumulators to zero.
+            v8sf_t rhov[4];
+
+            rhov[0].v = _mm256_setzero_ps();
+            rhov[1].v = _mm256_setzero_ps();
+            rhov[2].v = _mm256_setzero_ps();
+            rhov[3].v = _mm256_setzero_ps();
+
+            v8sf_t xv[2];
+            v8sf_t a_vec[8];
+
+            // Each FMA operation is broken down into a mul and an add
+            // to reduce backend stalls
+            for (l = 0; l < m_viter[0]; ++l)
+            {
+                xv[0].v = _mm256_loadu_ps(x0);
+                x0 += n_elem_per_reg;
+                xv[1].v = _mm256_loadu_ps(x0);
+                x0 += n_elem_per_reg;
+
+                a_vec[0].v = _mm256_loadu_ps(av[0]);
+                a_vec[4].v = _mm256_loadu_ps(av[0] + n_elem_per_reg);
+
+                // perform: rho?v += a?v * x0v;
+                a_vec[0].v = _mm256_mul_ps(a_vec[0].v, xv[0].v);
+                rhov[0].v  = _mm256_fmadd_ps(a_vec[4].v, xv[1].v, rhov[0].v);
+                rhov[0].v  = _mm256_add_ps(a_vec[0].v, rhov[0].v);
+
+                a_vec[1].v = _mm256_loadu_ps(av[1]);
+                a_vec[5].v = _mm256_loadu_ps(av[1] + n_elem_per_reg);
+
+                a_vec[1].v = _mm256_mul_ps(a_vec[1].v, xv[0].v);
+                rhov[1].v  = _mm256_fmadd_ps(a_vec[5].v, xv[1].v, rhov[1].v);
+                rhov[1].v  = _mm256_add_ps(a_vec[1].v, rhov[1].v);
+
+                a_vec[2].v = _mm256_loadu_ps(av[2]);
+                a_vec[6].v = _mm256_loadu_ps(av[2] + n_elem_per_reg);
+
+                a_vec[2].v = _mm256_mul_ps(a_vec[2].v, xv[0].v);
+                rhov[2].v  = _mm256_fmadd_ps(a_vec[6].v, xv[1].v, rhov[2].v);
+                rhov[2].v  = _mm256_add_ps(a_vec[2].v, rhov[2].v);
+
+                a_vec[3].v = _mm256_loadu_ps(av[3]);
+                a_vec[7].v = _mm256_loadu_ps(av[3] + n_elem_per_reg);
+
+                a_vec[3].v = _mm256_mul_ps(a_vec[3].v, xv[0].v);
+                rhov[3].v  = _mm256_fmadd_ps(a_vec[7].v, xv[1].v, rhov[3].v);
+                rhov[3].v  = _mm256_add_ps(a_vec[3].v, rhov[3].v);
+
+                av[0] += unroll_inc;
+                av[1] += unroll_inc;
+                av[2] += unroll_inc;
+                av[3] += unroll_inc;
+            }
+
+            for (l = 0; l < m_viter[1]; ++l)
+            {
+                // Load the input values.
+                xv[0].v = _mm256_loadu_ps(x0);
+                x0 += n_elem_per_reg;
+
+                a_vec[0].v = _mm256_loadu_ps(av[0]);
+                a_vec[1].v = _mm256_loadu_ps(av[1]);
+
+                rhov[0].v = _mm256_fmadd_ps(a_vec[0].v, xv[0].v, rhov[0].v);
+                rhov[1].v = _mm256_fmadd_ps(a_vec[1].v, xv[0].v, rhov[1].v);
+
+                av[0] += n_elem_per_reg;
+                av[1] += n_elem_per_reg;
+
+                a_vec[2].v = _mm256_loadu_ps(av[2]);
+                a_vec[3].v = _mm256_loadu_ps(av[3]);
+
+                rhov[2].v = _mm256_fmadd_ps(a_vec[2].v, xv[0].v, rhov[2].v);
+                rhov[3].v = _mm256_fmadd_ps(a_vec[3].v, xv[0].v, rhov[3].v);
+
+                av[2] += n_elem_per_reg;
+                av[3] += n_elem_per_reg;
+            }
+
+            // Sum the elements within each rho?v vector using hadd.
+            rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[1].v);
+            rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[3].v);
+            rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[0].v);
+            rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[2].v);
+
+            // Manually add the results from above to finish the sum.
+            rho0[0] = rhov[0].f[0] + rhov[0].f[4];
+            rho0[1] = rhov[0].f[1] + rhov[0].f[5];
+            rho0[2] = rhov[2].f[0] + rhov[2].f[4];
+            rho0[3] = rhov[2].f[1] + rhov[2].f[5];
+
+            // If more than 4 elements are left over, process 4 of them with SSE
+            if (m_left > 4)
+            {
+                v4sf_t xv128, a_vec128[4], rhov128[4];
+
+                rhov128[0].v = _mm_set1_ps(0);
+                rhov128[1].v = _mm_set1_ps(0);
+                rhov128[2].v = _mm_set1_ps(0);
+                rhov128[3].v = _mm_set1_ps(0);
+
+                // Load the input values.
+                xv128.v = _mm_loadu_ps(x0);
+                x0 += 4;
+                m_left -= 4;
+
+                a_vec128[0].v = _mm_loadu_ps(av[0]);
+                a_vec128[1].v = _mm_loadu_ps(av[1]);
+
+                // perform: rho?v += a?v * x0v;
+                rhov128[0].v = _mm_fmadd_ps(a_vec128[0].v, xv128.v, rhov128[0].v);
+                rhov128[1].v = _mm_fmadd_ps(a_vec128[1].v, xv128.v, rhov128[1].v);
+                rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[1].v);
+                rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[0].v);
+
+                a_vec128[2].v = _mm_loadu_ps(av[2]);
+                a_vec128[3].v = _mm_loadu_ps(av[3]);
+
+                rhov128[2].v = _mm_fmadd_ps(a_vec128[2].v, xv128.v, rhov128[2].v);
+                rhov128[3].v = _mm_fmadd_ps(a_vec128[3].v, xv128.v, rhov128[3].v);
+                rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[3].v);
+                rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[2].v);
+
+                rho0[0] += rhov128[0].f[0];
+                rho0[1] += rhov128[0].f[1];
+                rho0[2] += rhov128[2].f[0];
+                rho0[3] += rhov128[2].f[1];
+
+                av[0] += 4;
+                av[1] += 4;
+                av[2] += 4;
+                av[3] += 4;
+            }
+
+            // If there are leftover iterations, perform them with scalar code.
+            for (l = 0; l < m_left; ++l)
+            {
+                rho0[0] += *(av[0]) * (*x0);
+                rho0[1] += *(av[1]) * (*x0);
+                rho0[2] += *(av[2]) * (*x0);
+                rho0[3] += *(av[3]) * (*x0);
+
+                x0 += incx;
+                av[0] += inca;
+                av[1] += inca;
+                av[2] += inca;
+                av[3] += inca;
+            }
+        }
+        else
+        {
+            // When vectorization is not possible, use scalar code.
+
+            // Initialize pointers for x and the b_fuse columns of A (rows of A^T).
+            float *restrict x0 = x1;
+            float *restrict a0 = A1 + 0 * lda;
+            float *restrict a1 = A1 + 1 * lda;
+            float *restrict a2 = A1 + 2 * lda;
+            float *restrict a3 = A1 + 3 * lda;
+
+            for (dim_t l = 0; l < m; ++l)
+            {
+                const float x0c = *x0;
+
+                const float a0c = *a0;
+                const float a1c = *a1;
+                const float a2c = *a2;
+                const float a3c = *a3;
+
+                rho0[0] += a0c * x0c;
+                rho0[1] += a1c * x0c;
+                rho0[2] += a2c * x0c;
+                rho0[3] += a3c * x0c;
+
+                x0 += incx;
+                a0 += inca;
+                a1 += inca;
+                a2 += inca;
+                a3 += inca;
+            }
+        }
+
+        v4sf_t rho0v, y0v;
+
+        rho0v.v = _mm_loadu_ps(rho0);
+
+        // Broadcast the alpha scalar.
+        v4sf_t alphav;
+        alphav.v = _mm_broadcast_ss(alpha);
+
+        // We know at this point that alpha is nonzero; however, beta may still
+        // be zero. If beta is indeed zero, we must overwrite y rather than scale
+        // by beta (in case y contains NaN or Inf).
+        if (PASTEMAC(s, eq0)(*beta))
+        {
+            // Apply alpha to the accumulated dot product in rho:
+            //   y := alpha * rho
+            y0v.v = _mm_mul_ps(alphav.v, rho0v.v);
+        }
+        else
+        {
+            // Broadcast the beta scalar.
+            v4sf_t betav;
+            betav.v = _mm_broadcast_ss(beta);
+
+            if (incy == 1)
+            {
+                // Load y contiguously.
+                y0v.v = _mm_loadu_ps(y1);
+            }
+            else
+            {
+                // Load y element by element.
+                y0v.f[0] = *(y1 + 0 * incy);
+                y0v.f[1] = *(y1 + 1 * incy);
+                y0v.f[2] = *(y1 + 2 * incy);
+                y0v.f[3] = *(y1 + 3 * incy);
+            }
+
+            // Apply beta to y and alpha to the accumulated dot product in rho:
+            //   y := beta * y + alpha * rho
+            y0v.v = _mm_mul_ps(betav.v, y0v.v);
+            y0v.v = _mm_fmadd_ps(alphav.v, rho0v.v, y0v.v);
+        }
+
+        // Store the output.
+        if (incy == 1)
+        {
+            _mm_storeu_ps(y1, y0v.v);
+        }
+        else
+        {
+            *(y1 + 0 * incy) = y0v.f[0];
+            *(y1 + 1 * incy) = y0v.f[1];
+            *(y1 + 2 * incy) = y0v.f[2];
+            *(y1 + 3 * incy) = y0v.f[3];
+        }
+    }
+
+    // Compute the columns left over when b_n is not a multiple of b_fuse.
+    dim_t start = total_iteration * b_fuse;
+    dim_t new_fuse = 8, f;
+
+    // Leftover corner cases are completed using the fused dotxf kernel
+    for (dim_t i = start; i < b_n; i += f)
+    {
+        f = bli_determine_blocksize_dim_f(i, b_n, new_fuse);
+
+        float *A1 = a + (i) * lda + (0) * inca;
+        float *x1 = x + (0) * incx;
+        float *y1 = y + (i) * incy;
+
+        /* y1 = beta * y1 + alpha * A1 * x; */
+        bli_sdotxf_zen_int_8(
+            conjat,
+            conjx,
+            m,
+            f,
+            alpha,
+            A1, inca, lda,
+            x1, incx,
+            beta,
+            y1, incy,
+            cntx);
+    }
+}
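
For intuition, the piecewise model implemented by bli_sgemv_var1_smart_threading can be exercised standalone. The sketch below is not part of the patch: the driver, the sample sizes, and the final clamp to at least one thread are illustrative additions; the formulas themselves are copied from the function above.

#include <math.h>
#include <stdio.h>

/* Standalone reproduction of the smart-threading model from
   bli_sgemv_var1_smart_threading(); BLIS types simplified to long. */
static long smart_threads(long m, long n, long fuse, long nt_max)
{
    long   n_per_loop    = n / fuse;
    double data_per_iter = (double)n_per_loop * (double)m;
    double m_n_ratio     = (double)m / (double)n;
    double nt;

    if (n_per_loop < 1) return 1;

    if (m_n_ratio <= 0.6)
    {
        if      (data_per_iter < 50000.0)  nt = ceil(0.9148 * log(data_per_iter) - 1.6252);
        else if (data_per_iter < 500000.0) nt = ceil(10.23  * log(data_per_iter) - 82.332);
        else                               nt = (double)nt_max;
    }
    else
    {
        if      (data_per_iter < 50000.0)  nt = ceil(-2E-09 * data_per_iter * data_per_iter
                                                     + 0.0002 * data_per_iter + 1.0234);
        else if (data_per_iter < 360000.0) nt = ceil(16.917 * log(data_per_iter) - 164.82);
        else                               nt = (double)nt_max;
    }

    if (nt < 1.0) nt = 1.0;  /* safety clamp, not in the original model */
    return nt > (double)nt_max ? nt_max : (long)nt;
}

int main(void)
{
    /* Sample shapes are illustrative only; fuse = 4 and nt_max = 16. */
    long sizes[][2] = { {128, 128}, {1024, 1024}, {512, 4096}, {8192, 256} };
    for (int i = 0; i < 4; i++)
        printf("m=%5ld n=%5ld -> nt=%ld\n", sizes[i][0], sizes[i][1],
               smart_threads(sizes[i][0], sizes[i][1], 4, 16));
    return 0;
}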
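
The two-step hadd reduction in the kernel's epilogue is easy to misread: after _mm256_hadd_ps(rho0, rho1) followed by a second hadd of the result with itself, element 0 plus element 4 holds the full sum of rho0, and element 1 plus element 5 the full sum of rho1. A small self-contained check of that lane arithmetic (not from the patch; compile with -mavx):

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    /* Two accumulators with known sums: 0+1+...+7 = 28 and 8+...+15 = 92. */
    __m256 r0 = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);
    __m256 r1 = _mm256_setr_ps(8, 9, 10, 11, 12, 13, 14, 15);

    __m256 t = _mm256_hadd_ps(r0, r1); /* pairwise sums, interleaved per 128-bit lane */
    t = _mm256_hadd_ps(t, t);          /* pairwise again: quarter-sums per lane */

    float f[8];
    _mm256_storeu_ps(f, t);

    printf("sum(r0) = %g (expect 28)\n", f[0] + f[4]);
    printf("sum(r1) = %g (expect 92)\n", f[1] + f[5]);
    return 0;
}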
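
From the caller's side nothing changes: the new path is reached through the usual BLIS typed API. A hedged usage sketch follows; the sizes are arbitrary, and whether a given call actually dispatches to bli_sgemv_unf_var1 and then bli_multi_sgemv_4x2 depends on the storage/transpose combination and on the build being configured with BLIS_ENABLE_MULTITHREADING and BLIS_ENABLE_OPENMP. The thread budget queried via bli_rntm_num_threads() is typically set through BLIS_NUM_THREADS.

#include <stdlib.h>
#include "blis.h"

int main(void)
{
    dim_t m = 2000, n = 2000;
    float alpha = 1.0f, beta = 0.0f;

    /* Row-major A: row stride = n, column stride = 1. */
    float* a = calloc((size_t)m * (size_t)n, sizeof(float));
    float* x = calloc((size_t)n, sizeof(float));
    float* y = calloc((size_t)m, sizeof(float));

    /* ... fill a and x ... */

    /* y := beta * y + alpha * A * x through the typed API. */
    bli_sgemv(BLIS_NO_TRANSPOSE, BLIS_NO_CONJUGATE, m, n,
              &alpha, a, n, 1, x, 1, &beta, y, 1);

    free(a); free(x); free(y);
    return 0;
}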