From d116780616bc13cc59d7ac2dad9fddb923ebb02b Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 17 Dec 2021 02:34:52 -0600 Subject: [PATCH] Optimized dher2 implementation - Implemented her2 framework calls for transposed and non-transposed kernel variants. - The dher2 kernel operates over 4 columns at a time. It computes the 4x4 triangular part of the matrix first, and the remainder part is computed in chunks of 4x4 tiles up to m rows. - Remainder cases (m < 4) are handled serially. AMD-Internal: [CPUPL-1968] Change-Id: I12ae97b2ad673a7fd9b733c607f27b1089142313 --- frame/2/her2/bli_her2_unf_var1.c | 214 ++++++++++++++++++++++++++++++- frame/2/her2/bli_her2_unf_var4.c | 189 ++++++++++++++++++++++++++- 2 files changed, 401 insertions(+), 2 deletions(-) diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index a0aec48f7..299e3d161 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,5 +158,217 @@ void PASTEMAC(ch,varname) \ } \ } -INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) + +#ifdef BLIS_CONFIG_EPYC + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. 
+ * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); +/* bli_dher2_unf_var1: double-precision (d) implementation of her2_unf_var1. Iteration i updates c10t (row i of the effective lower triangle, n_behind = i elements, stride cs_ct) through the axpy2v kernel and then adds alpha0_chi1_psi1 twice to the diagonal element gamma11. When incx == 1, incy == 1 and rs_ct == 1, iterations with n_behind >= 3 call bli_dher2_trans_zen_int_4 instead and advance i by 4. */ +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the double-precision axpy2v kernel function pointer. 
*/ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + {/* NOTE(review): the kernel receives row i (c10t) with length n_behind + 1 and i then advances by 4; how many rows the kernel covers, and whether that stays within m, is not visible here - confirm against the kernel. */ + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Scalar path, taken only while n_behind < 3. Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); both terms are added as alpha0_chi1_psi1. */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + {/* Fallback path for any other stride combination: one row per iteration, no kernel hand-off. */ + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. 
*/ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); both terms are added as alpha0_chi1_psi1. */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} +/* The double variant is written out by hand above, so GENTFUNC is invoked only for float, scomplex and dcomplex. */ +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) +#else +INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) +#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index 3dea31d53..e39c7224c 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,5 +166,192 @@ void PASTEMAC(ch,varname) \ } \ } -INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) +#ifdef BLIS_CONFIG_EPYC +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. 
+ * a is triangular matrix, x and y are vectors. NOTE(review): the kernel declared below is bli_dher2_zen_int_4, presumably the non-transposed variant - confirm. + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); +/* bli_dher2_unf_var4: double-precision (d) implementation of her2_unf_var4. Iteration i updates c21 (the n_ahead = m - i - 1 elements below the diagonal in column i of the effective lower triangle, stride rs_ct) through the axpy2v kernel with x2 and y2, and then adds alpha0_chi1_psi1 twice to the diagonal element gamma11. When incx == 1, incy == 1 and rs_ct == 1, iterations with n_ahead >= 3 call bli_dher2_zen_int_4 instead and advance i by 4. */ +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* No conjugation toggling is performed in this double-precision + * routine; conjx/conjy reach the axpy2v kernel unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* conjh is not used in this routine: alpha0 is a plain copy of + * *alpha for both uplo cases, and the scalar products below use + * *chi1 and *psi1 directly, so no effective conjugation needs + * to be derived here. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the double-precision axpy2v kernel function pointer. 
*/ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + c21 = c + (i+1)*rs_ct + (i )*cs_ct; + + if((n_ahead >= 3)) + { /* At least 4 columns remain (m - i >= 4): call the kernel on the trailing part starting at gamma11, then advance i by 4. */ + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Scalar path for the last columns (n_ahead < 3). Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * psi1 for the diagonal update; + * chi1 and psi1 are + used as stored, no conjugation is applied to them in + this routine. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + /* gamma11 = gamma11 + 2 * alpha * chi1 * psi1, added as the same product twice. */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { /* Fallback path for any other stride combination: one column per iteration, no kernel hand-off. */ + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; /* rs_ct need not be 1 on this path, so the row offset must be scaled by it. */ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; /* must agree with the rs_ct stride passed to kfp_2v below. */ + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * psi1 for the diagonal update; + * chi1 and psi1 are + used as stored, no conjugation is applied to them in + this routine. 
*/ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + /* gamma11 = gamma11 + 2 * alpha * chi1 * psi1, added as the same product twice. */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) +#else +INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) +#endif