diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 677dd0c47..25496f4bd 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,6 +52,8 @@ void PASTEMAC(ch,varname) \ cntx_t* cntx \ ) \ { \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ \ const num_t dt = PASTEMAC(ch,type); \ \ diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index b1afbba09..563ec8069 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -177,8 +177,8 @@ void bli_dgemv_unf_var2 NULL ); - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 5; + /* Fusing factor. */ + b_fuse = 4; for ( i = 0; i < n_iter; i += f ) { @@ -189,7 +189,7 @@ void bli_dgemv_unf_var2 y1 = y + (0 )*incy; /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_5 + bli_daxpyf_zen_int_16x4 ( conja, conjx, diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 96884c3df..ab58842c7 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -149,6 +149,9 @@ void PASTEF77(ch,blasname) \ trans_t blis_transb; \ dim_t m0, n0, k0; \ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ \ /* Initialize BLIS. 
*/ \ bli_init_auto(); \ @@ -184,6 +187,71 @@ void PASTEF77(ch,blasname) \ const inc_t cs_b = *ldb; \ const inc_t rs_c = 1; \ const inc_t cs_c = *ldc; \ +\ + if( n0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + BLIS_NO_TRANSPOSE, \ + bli_extract_conj(blis_transb), \ + m0, k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a,\ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*) beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transa, \ + bli_extract_conj(blis_transb), \ + k0, m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*)beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + return; \ + } \ + else if( m0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transb)) \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + n0, k0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + k0, n0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + return; \ + } \ \ const num_t dt = PASTEMAC(ch,type); \ \ @@ -192,9 +260,6 @@ void PASTEF77(ch,blasname) \ obj_t bo = BLIS_OBJECT_INITIALIZER; \ obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ obj_t co = BLIS_OBJECT_INITIALIZER; \ -\ - dim_t m0_a, n0_a; \ - dim_t m0_b, n0_b; \ \ bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index da9df9188..f77038919 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -48,9 +48,16 @@ typedef union typedef union { __m256d v; + __m128d xmm[2]; double d[4] __attribute__((aligned(64))); } v4df_t; +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + void bli_saxpyf_zen_int_5 ( @@ -597,6 +604,737 @@ void bli_daxpyf_zen_int_5 } } +// ----------------------------------------------------------------------------- + +static void bli_daxpyf_zen_int_16x2 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 2; + + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + + double* restrict a0; + double* restrict a1; + + double* restrict y0; + + v4df_t chi0v, chi1v; + + v4df_t a00v, a01v; + + v4df_t a10v, a11v; + + v4df_t a20v, a21v; + + v4df_t a30v, a31v; + + v4df_t y0v, y1v, y2v, y3v; + + double chi0, chi1; + + v2df_t a40v, a41v; + + v2df_t y4v; + // If either dimension is zero, or if alpha is zero, return early. 
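+ // (Note: b_n is the number of fused columns; this kernel computes
+ // y := y + alpha * ( chi0*A(:,0) + chi1*A(:,1) ), i.e. two fused
+ // axpyv operations over an m x 2 panel of A in one sweep over y.)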
+ if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { +#ifdef BLIS_CONFIG_EPYC + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + bli_daxpyv_zen_int10 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#else + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#endif + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + + + // Scale each chi scalar by alpha. + bli_dscals( *alpha, chi0 ); + bli_dscals( *alpha, chi1 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0 ); + chi1v.v = _mm256_broadcast_sd( &chi1 ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + for ( i = 0; (i + 15) < m; i += 16 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + _mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v ); + + y0 += n_iter_unroll * n_elem_per_reg; + a0 += n_iter_unroll * n_elem_per_reg; + a1 += n_iter_unroll * n_elem_per_reg; + } + + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. 
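+ // 12-element tail: three ymm accumulators, with three 4-double
+ // loads from each of the two columns of A.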
+ y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + + y0 += 3 * n_elem_per_reg; + a0 += 3 * n_elem_per_reg; + a1 += 3 * n_elem_per_reg; + } + for ( ; (i + 7) < m; i += 8 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += 2 * n_elem_per_reg; + a0 += 2 * n_elem_per_reg; + a1 += 2 * n_elem_per_reg; + } + + for ( ; (i + 3) < m; i += 4 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg; + a0 += n_elem_per_reg; + a1 += n_elem_per_reg; + } + + for ( ; (i + 1) < m; i += 2 ) + { + // Load the input values. + y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg ); + + a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg ); + + a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v ); + + // Store the output. + _mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v ); + + y0 += 2; + a0 += 2; + a1 += 2; + } + + // If there are leftover iterations, perform them with scalar code. 
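+ // At most one element can remain at this point, since the vector
+ // loops above all consume an even number of rows.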
+ for ( ; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + y0 += 1; + } + } + else + { + for ( i = 0; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + y0 += incy; + } + + } +} + +// ----------------------------------------------------------------------------- +void bli_daxpyf_zen_int_16x4 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 4; + + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + + double* restrict a0; + double* restrict a1; + double* restrict a2; + double* restrict a3; + + double* restrict y0; + + v4df_t chi0v, chi1v, chi2v, chi3v; + + v4df_t a00v, a01v, a02v, a03v; + + v4df_t a10v, a11v, a12v, a13v; + + v4df_t a20v, a21v, a22v, a23v; + + v4df_t a30v, a31v, a32v, a33v; + + v4df_t y0v, y1v, y2v, y3v; + + double chi0, chi1, chi2, chi3; + + v2df_t y4v; + + v2df_t a40v, a41v, a42v, a43v; + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { +#ifdef BLIS_CONFIG_EPYC + if(b_n & 2) + { + bli_daxpyf_zen_int_16x2( conja, + conjx, + m, 2, + alpha, a, inca, lda, + x, incx, + y, incy, + cntx + ); + b_n -= 2; + a += 2*lda; + x += 2 * incx; + } + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + bli_daxpyv_zen_int10 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#else + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + double* a1 = a + (0 )*inca + (i )*lda; + double* chi1 = x + (i )*incx; + double* y1 = y + (0 )*incy; + double alpha_chi1; + + bli_dcopycjs( conjx, *chi1, alpha_chi1 ); + bli_dscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#endif + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + + // Scale each chi scalar by alpha. + bli_dscals( *alpha, chi0 ); + bli_dscals( *alpha, chi1 ); + bli_dscals( *alpha, chi2 ); + bli_dscals( *alpha, chi3 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0 ); + chi1v.v = _mm256_broadcast_sd( &chi1 ); + chi2v.v = _mm256_broadcast_sd( &chi2 ); + chi3v.v = _mm256_broadcast_sd( &chi3 ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + for ( i = 0; (i + 15) < m; i += 16 ) + { + // Load the input values. 
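+ // Main loop: each iteration updates 16 elements of y using four
+ // ymm accumulators, four 4-double loads per column of A, and
+ // sixteen FMAs.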
+ y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + a30v.v = _mm256_loadu_pd( a0 + 3*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + a31v.v = _mm256_loadu_pd( a1 + 3*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg ); + a32v.v = _mm256_loadu_pd( a2 + 3*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg ); + a33v.v = _mm256_loadu_pd( a3 + 3*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a30v.v, chi0v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a31v.v, chi1v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a32v.v, chi2v.v, y3v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v ); + y3v.v = _mm256_fmadd_pd( a33v.v, chi3v.v, y3v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + _mm256_storeu_pd( (double *)(y0 + 3*n_elem_per_reg), y3v.v ); + + y0 += n_iter_unroll * n_elem_per_reg; + a0 += n_iter_unroll * n_elem_per_reg; + a1 += n_iter_unroll * n_elem_per_reg; + a2 += n_iter_unroll * n_elem_per_reg; + a3 += n_iter_unroll * n_elem_per_reg; + } + + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. 
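+ // 12-element tail: three ymm accumulators, with three 4-double
+ // loads per column of A and twelve FMAs in total.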
+ y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + a20v.v = _mm256_loadu_pd( a0 + 2*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + a21v.v = _mm256_loadu_pd( a1 + 2*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + a22v.v = _mm256_loadu_pd( a2 + 2*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + a23v.v = _mm256_loadu_pd( a3 + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a20v.v, chi0v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a21v.v, chi1v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a22v.v, chi2v.v, y2v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + y2v.v = _mm256_fmadd_pd( a23v.v, chi3v.v, y2v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 2*n_elem_per_reg), y2v.v ); + + y0 += 3 * n_elem_per_reg; + a0 += 3 * n_elem_per_reg; + a1 += 3 * n_elem_per_reg; + a2 += 3 * n_elem_per_reg; + a3 += 3 * n_elem_per_reg; + } + + for ( ; (i + 7) < m; i += 8 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += 2 * n_elem_per_reg; + a0 += 2 * n_elem_per_reg; + a1 += 2 * n_elem_per_reg; + a2 += 2 * n_elem_per_reg; + a3 += 2 * n_elem_per_reg; + } + + + for ( ; (i + 3) < m; i += 4) + { + // Load the input values. 
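+ // 4-element tail: a single ymm accumulator, one load per column
+ // of A, and four FMAs.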
+ y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); + + a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); + + a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); + + a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg; + a0 += n_elem_per_reg; + a1 += n_elem_per_reg; + a2 += n_elem_per_reg; + a3 += n_elem_per_reg; + } +#if 1 + for ( ; (i + 1) < m; i += 2) + { + + // Load the input values. + y4v.v = _mm_loadu_pd( y0 + 0*n_elem_per_reg ); + + a40v.v = _mm_loadu_pd( a0 + 0*n_elem_per_reg ); + + a41v.v = _mm_loadu_pd( a1 + 0*n_elem_per_reg ); + + a42v.v = _mm_loadu_pd( a2 + 0*n_elem_per_reg ); + + a43v.v = _mm_loadu_pd( a3 + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + y4v.v = _mm_fmadd_pd( a40v.v, chi0v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a41v.v, chi1v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a42v.v, chi2v.xmm[0], y4v.v ); + + y4v.v = _mm_fmadd_pd( a43v.v, chi3v.xmm[0], y4v.v ); + + // Store the output. + _mm_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y4v.v ); + + y0 += 2; + a0 += 2; + a1 += 2; + a2 += 2; + a3 += 2; + } +#endif + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + + y0 += 1; + } + } + else + { + for ( i = 0; (i + 0) < m ; ++i ) + { + double y0c = *y0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + y0c += chi0 * a0c; + y0c += chi1 * a1c; + y0c += chi2 * a2c; + y0c += chi3 * a3c; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + + y0 += incy; + } + + } +} // ----------------------------------------------------------------------------- diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index fa3db9d64..a16914352 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,6 +95,7 @@ SETV_KER_PROT(double, d, setv_zen_int) // axpyf (intrinsics) AXPYF_KER_PROT( float, s, axpyf_zen_int_8 ) AXPYF_KER_PROT( double, d, axpyf_zen_int_8 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int_16x4 ) AXPYF_KER_PROT( float, s, axpyf_zen_int_5 ) AXPYF_KER_PROT( double, d, axpyf_zen_int_5 )
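
For reference, below is a minimal scalar sketch of the fused AXPYF operation that bli_daxpyf_zen_int_16x4 (and the static 16x2 helper) vectorize in the patch above. This is an editorial illustration, not part of the patch: the name ref_daxpyf, the main driver, and the plain int parameters are stand-ins for the BLIS dim_t/inc_t typedefs, and the conja/conjx arguments are omitted because conjugation is a no-op for real doubles.

    // Scalar reference for AXPYF: y := y + alpha * A * x, where A is an
    // m x b_n column panel, i.e. b_n fused AXPYV operations sharing one
    // sweep over y.
    #include <stdio.h>

    static void ref_daxpyf( int m, int b_n, double alpha,
                            const double* a, int inca, int lda,
                            const double* x, int incx,
                            double* y, int incy )
    {
        // For each of the b_n columns, accumulate alpha*chi_j * A(:,j)
        // into y.
        for ( int j = 0; j < b_n; ++j )
        {
            const double alpha_chi = alpha * x[ j*incx ];
            for ( int i = 0; i < m; ++i )
                y[ i*incy ] += alpha_chi * a[ i*inca + j*lda ];
        }
    }

    int main( void )
    {
        // 5x4 column-major A, so inca = 1 and lda = m = 5.
        double a[20], x[4] = { 1.0, 2.0, 3.0, 4.0 }, y[5] = { 0.0 };
        for ( int k = 0; k < 20; ++k ) a[k] = (double)k;

        ref_daxpyf( 5, 4, 1.0, a, 1, 5, x, 1, y, 1 );

        for ( int i = 0; i < 5; ++i ) printf( "%g\n", y[i] );
        return 0;
    }

Under the bla_gemm.c change above, a DGEMM call with n == 1 and a non-transposed A now reaches this operation through gemv_unf_var2, whose updated loop invokes bli_daxpyf_zen_int_16x4 on panels of b_fuse = 4 columns at a time.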