From e56cf63a3f72259bc68003c33eaee9a00dd024b3 Mon Sep 17 00:00:00 2001 From: Meghana Date: Fri, 10 Apr 2020 12:20:10 +0530 Subject: [PATCH] Optimized "bli_dotv_zen_int10" kernels Details: - Fixed issues in "bli_dotv_zen_int10" kernels and optimized them. - Changed cntx_init file to choose "bli_dotv_zen_int10" kernel for dotv API call. Change-Id: Iee8d7519f3a22a2d41166390be6047e9cb37557f Signed-off-by: Meghana Vankadari AMD-Internal: [CPUPL-824] --- config/zen2/bli_cntx_init_zen2.c | 4 +- kernels/zen/1/bli_dotv_zen_int10.c | 146 ++++++++++++++++------------- 2 files changed, 84 insertions(+), 66 deletions(-) diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index f7c5a8346..c85628eb9 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -89,8 +89,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, diff --git a/kernels/zen/1/bli_dotv_zen_int10.c b/kernels/zen/1/bli_dotv_zen_int10.c index e4b980362..7a67c1247 100644 --- a/kernels/zen/1/bli_dotv_zen_int10.c +++ b/kernels/zen/1/bli_dotv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2020, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -73,11 +73,11 @@ void bli_sdotv_zen_int10 float* restrict x0; float* restrict y0; - float rho0; + float rho0 = 0.0; __m256 xv[10]; __m256 yv[10]; - v8sf_t rhov[2]; + v8sf_t rhov[10]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) ) @@ -96,8 +96,16 @@ void bli_sdotv_zen_int10 { rhov[0].v = _mm256_setzero_ps(); rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + rhov[4].v = _mm256_setzero_ps(); + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + rhov[8].v = _mm256_setzero_ps(); + rhov[9].v = _mm256_setzero_ps(); - for ( i = 0; (i + 79) < n; i += 80 ) + for (i=0 ; (i + 79) < n; i += 80 ) { // 80 elements will be processed per loop; 10 FMAs will run per loop. xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -124,19 +132,25 @@ void bli_sdotv_zen_int10 rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[7].v ); + rhov[8].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[8].v ); + rhov[9].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[9].v ); x0 += 10*n_elem_per_reg; y0 += 10*n_elem_per_reg; } + rhov[0].v += rhov[5].v; + rhov[1].v += rhov[6].v; + rhov[2].v += rhov[7].v; + rhov[3].v += rhov[8].v; + rhov[4].v += rhov[9].v; + for ( ; (i + 39) < n; i += 40 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -153,35 +167,18 @@ void bli_sdotv_zen_int10 rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[0].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v ); x0 += 5*n_elem_per_reg; y0 += 5*n_elem_per_reg; } - for ( ; (i + 31) < n; i += 32 ) - { - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); - - rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v ); - - x0 += 4*n_elem_per_reg; - y0 += 4*n_elem_per_reg; - } - + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + rhov[0].v += rhov[4].v; + for ( ; (i + 15) < n; i += 16 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -197,6 +194,8 @@ void bli_sdotv_zen_int10 y0 += 2*n_elem_per_reg; } + rhov[0].v += rhov[1].v; + for ( ; (i + 7) < n; i += 8 ) { xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -211,19 +210,15 @@ void bli_sdotv_zen_int10 for ( ; (i + 0) < n; i += 1 ) { - rhov[0].f[0] += x0[i] * y0[i]; + rho0 += (*x0) * (*y0); + x0 += 1; + y0 += 1; } - v8sf_t onev; - - onev.v = _mm256_set1_ps( 1.0f ); - - rhov[0].v = _mm256_dp_ps( rhov[0].v, onev.v, 0xf1 ); - rhov[1].v = _mm256_dp_ps( rhov[1].v, onev.v, 0xf1 ); - - // Manually add the results from above to finish the sum. - rho0 += rhov[0].f[0] + rhov[0].f[4]; - rho0 += rhov[1].f[0] + rhov[1].f[4]; + rho0 += rhov[0].f[0] + rhov[0].f[1] + + rhov[0].f[2] + rhov[0].f[3] + + rhov[0].f[4] + rhov[0].f[5] + + rhov[0].f[6] + rhov[0].f[7]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when @@ -269,11 +264,11 @@ void bli_ddotv_zen_int10 double* restrict x0; double* restrict y0; - double rho0; + double rho0 = 0.0; __m256d xv[10]; __m256d yv[10]; - v4df_t rhov[2]; + v4df_t rhov[10]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) ) @@ -292,6 +287,14 @@ void bli_ddotv_zen_int10 { rhov[0].v = _mm256_setzero_pd(); rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + rhov[8].v = _mm256_setzero_pd(); + rhov[9].v = _mm256_setzero_pd(); for ( i = 0; (i + 39) < n; i += 40 ) { @@ -320,19 +323,25 @@ void bli_ddotv_zen_int10 rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v ); + rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v ); + rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v ); x0 += 10*n_elem_per_reg; y0 += 10*n_elem_per_reg; } + rhov[0].v += rhov[5].v; + rhov[1].v += rhov[6].v; + rhov[2].v += rhov[7].v; + rhov[3].v += rhov[8].v; + rhov[4].v += rhov[9].v; + for ( ; (i + 19) < n; i += 20 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -349,14 +358,16 @@ void bli_ddotv_zen_int10 rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[0].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); + rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v ); x0 += 5*n_elem_per_reg; y0 += 5*n_elem_per_reg; } + rhov[0].v += rhov[4].v; + for ( ; (i + 15) < n; i += 16 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -371,13 +382,16 @@ void bli_ddotv_zen_int10 rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v ); - rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v ); - rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v ); x0 += 4*n_elem_per_reg; y0 += 4*n_elem_per_reg; } + rhov[0].v += rhov[2].v; + rhov[1].v += rhov[3].v; + for ( ; (i + 7) < n; i += 8 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -393,6 +407,8 @@ void bli_ddotv_zen_int10 y0 += 2*n_elem_per_reg; } + rhov[0].v += rhov[1].v; + for ( ; (i + 3) < n; i += 4 ) { xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); @@ -407,12 +423,14 @@ void bli_ddotv_zen_int10 for ( ; (i + 0) < n; i += 1 ) { - rhov[0].d[0] += x0[i] * y0[i]; + rho0 += (*x0) * (*y0); + + x0 += 1; + y0 += 1; } // Manually add the results from above to finish the sum. rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3]; - rho0 += rhov[1].d[0] + rhov[1].d[1] + rhov[1].d[2] + rhov[1].d[3]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when