Optimized "bli_dotv_zen_int10" kernels

Details:
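- Fixed issues in the "bli_sdotv_zen_int10" and "bli_ddotv_zen_int10"
 kernels ("rho0" is now initialized to zero, and the scalar cleanup loop no
 longer indexes "x0"/"y0" by "i" after the pointers have already been
 advanced) and optimized them by using ten separate accumulators.
- Changed the zen2 cntx_init file to choose the "bli_sdotv_zen_int10" and
 "bli_ddotv_zen_int10" kernels for dotv API calls.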

Change-Id: Iee8d7519f3a22a2d41166390be6047e9cb37557f
Signed-off-by: Meghana Vankadari <Meghana.Vankadari@amd.com>
AMD-Internal: [CPUPL-824]
This commit is contained in:
Meghana
2020-04-10 12:20:10 +05:30
parent b5fe75e104
commit e56cf63a3f
2 changed files with 84 additions and 66 deletions

View File

@@ -89,8 +89,8 @@ void bli_cntx_init_zen2( cntx_t* cntx )
BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10,
// dotv
BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int,
BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int,
BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10,
BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10,
// dotxv
BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int,
BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int,

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc.
Copyright (C) 2016 - 2020, Advanced Micro Devices, Inc.
Copyright (C) 2018, The University of Texas at Austin
Redistribution and use in source and binary forms, with or without
@@ -73,11 +73,11 @@ void bli_sdotv_zen_int10
float* restrict x0;
float* restrict y0;
float rho0;
float rho0 = 0.0;
__m256 xv[10];
__m256 yv[10];
v8sf_t rhov[2];
v8sf_t rhov[10];
// If the vector dimension is zero, return early.
if ( bli_zero_dim1( n ) )
@@ -96,8 +96,16 @@ void bli_sdotv_zen_int10
{
rhov[0].v = _mm256_setzero_ps();
rhov[1].v = _mm256_setzero_ps();
rhov[2].v = _mm256_setzero_ps();
rhov[3].v = _mm256_setzero_ps();
rhov[4].v = _mm256_setzero_ps();
rhov[5].v = _mm256_setzero_ps();
rhov[6].v = _mm256_setzero_ps();
rhov[7].v = _mm256_setzero_ps();
rhov[8].v = _mm256_setzero_ps();
rhov[9].v = _mm256_setzero_ps();
for ( i = 0; (i + 79) < n; i += 80 )
for (i=0 ; (i + 79) < n; i += 80 )
{
// 80 elements will be processed per loop; 10 FMAs will run per loop.
xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
@@ -124,19 +132,25 @@ void bli_sdotv_zen_int10
rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[1].v );
rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v );
rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v );
rhov[5].v = _mm256_fmadd_ps( xv[5], yv[5], rhov[5].v );
rhov[6].v = _mm256_fmadd_ps( xv[6], yv[6], rhov[6].v );
rhov[7].v = _mm256_fmadd_ps( xv[7], yv[7], rhov[7].v );
rhov[8].v = _mm256_fmadd_ps( xv[8], yv[8], rhov[8].v );
rhov[9].v = _mm256_fmadd_ps( xv[9], yv[9], rhov[9].v );
x0 += 10*n_elem_per_reg;
y0 += 10*n_elem_per_reg;
}
rhov[0].v += rhov[5].v;
rhov[1].v += rhov[6].v;
rhov[2].v += rhov[7].v;
rhov[3].v += rhov[8].v;
rhov[4].v += rhov[9].v;
for ( ; (i + 39) < n; i += 40 )
{
xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
@@ -153,35 +167,18 @@ void bli_sdotv_zen_int10
rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[0].v );
rhov[2].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[3].v );
rhov[4].v = _mm256_fmadd_ps( xv[4], yv[4], rhov[4].v );
x0 += 5*n_elem_per_reg;
y0 += 5*n_elem_per_reg;
}
for ( ; (i + 31) < n; i += 32 )
{
xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg );
xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg );
xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg );
yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg );
yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg );
yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg );
yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg );
rhov[0].v = _mm256_fmadd_ps( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[1], yv[1], rhov[1].v );
rhov[0].v = _mm256_fmadd_ps( xv[2], yv[2], rhov[0].v );
rhov[1].v = _mm256_fmadd_ps( xv[3], yv[3], rhov[1].v );
x0 += 4*n_elem_per_reg;
y0 += 4*n_elem_per_reg;
}
rhov[0].v += rhov[2].v;
rhov[1].v += rhov[3].v;
rhov[0].v += rhov[4].v;
for ( ; (i + 15) < n; i += 16 )
{
xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
@@ -197,6 +194,8 @@ void bli_sdotv_zen_int10
y0 += 2*n_elem_per_reg;
}
rhov[0].v += rhov[1].v;
for ( ; (i + 7) < n; i += 8 )
{
xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );
@@ -211,19 +210,15 @@ void bli_sdotv_zen_int10
for ( ; (i + 0) < n; i += 1 )
{
rhov[0].f[0] += x0[i] * y0[i];
rho0 += (*x0) * (*y0);
x0 += 1;
y0 += 1;
}
v8sf_t onev;
onev.v = _mm256_set1_ps( 1.0f );
rhov[0].v = _mm256_dp_ps( rhov[0].v, onev.v, 0xf1 );
rhov[1].v = _mm256_dp_ps( rhov[1].v, onev.v, 0xf1 );
// Manually add the results from above to finish the sum.
rho0 += rhov[0].f[0] + rhov[0].f[4];
rho0 += rhov[1].f[0] + rhov[1].f[4];
rho0 += rhov[0].f[0] + rhov[0].f[1] +
rhov[0].f[2] + rhov[0].f[3] +
rhov[0].f[4] + rhov[0].f[5] +
rhov[0].f[6] + rhov[0].f[7];
// Issue vzeroupper instruction to clear upper lanes of ymm registers.
// This avoids a performance penalty caused by false dependencies when
@@ -269,11 +264,11 @@ void bli_ddotv_zen_int10
double* restrict x0;
double* restrict y0;
double rho0;
double rho0 = 0.0;
__m256d xv[10];
__m256d yv[10];
v4df_t rhov[2];
v4df_t rhov[10];
// If the vector dimension is zero, return early.
if ( bli_zero_dim1( n ) )
@@ -292,6 +287,14 @@ void bli_ddotv_zen_int10
{
rhov[0].v = _mm256_setzero_pd();
rhov[1].v = _mm256_setzero_pd();
rhov[2].v = _mm256_setzero_pd();
rhov[3].v = _mm256_setzero_pd();
rhov[4].v = _mm256_setzero_pd();
rhov[5].v = _mm256_setzero_pd();
rhov[6].v = _mm256_setzero_pd();
rhov[7].v = _mm256_setzero_pd();
rhov[8].v = _mm256_setzero_pd();
rhov[9].v = _mm256_setzero_pd();
for ( i = 0; (i + 39) < n; i += 40 )
{
@@ -320,19 +323,25 @@ void bli_ddotv_zen_int10
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[1].v );
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v );
rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v );
rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v );
rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v );
rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v );
rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v );
x0 += 10*n_elem_per_reg;
y0 += 10*n_elem_per_reg;
}
rhov[0].v += rhov[5].v;
rhov[1].v += rhov[6].v;
rhov[2].v += rhov[7].v;
rhov[3].v += rhov[8].v;
rhov[4].v += rhov[9].v;
for ( ; (i + 19) < n; i += 20 )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
@@ -349,14 +358,16 @@ void bli_ddotv_zen_int10
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[0].v );
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v );
x0 += 5*n_elem_per_reg;
y0 += 5*n_elem_per_reg;
}
rhov[0].v += rhov[4].v;
for ( ; (i + 15) < n; i += 16 )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
@@ -371,13 +382,16 @@ void bli_ddotv_zen_int10
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[0].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[1].v );
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
x0 += 4*n_elem_per_reg;
y0 += 4*n_elem_per_reg;
}
rhov[0].v += rhov[2].v;
rhov[1].v += rhov[3].v;
for ( ; (i + 7) < n; i += 8 )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
@@ -393,6 +407,8 @@ void bli_ddotv_zen_int10
y0 += 2*n_elem_per_reg;
}
rhov[0].v += rhov[1].v;
for ( ; (i + 3) < n; i += 4 )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
@@ -407,12 +423,14 @@ void bli_ddotv_zen_int10
for ( ; (i + 0) < n; i += 1 )
{
rhov[0].d[0] += x0[i] * y0[i];
rho0 += (*x0) * (*y0);
x0 += 1;
y0 += 1;
}
// Manually add the results from above to finish the sum.
rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3];
rho0 += rhov[1].d[0] + rhov[1].d[1] + rhov[1].d[2] + rhov[1].d[3];
// Issue vzeroupper instruction to clear upper lanes of ymm registers.
// This avoids a performance penalty caused by false dependencies when