DDOTV Optimization for ZEN3 Architecture

- Reduced the blocking size of 'bli_ddotv_zen_int10'
  kernel from 40 elements to 20 elements for better
  utilization of vector registers

- Replaced redundant 'for' loops in 'bli_ddotv_zen_int10'
  kernel with 'if' conditions to handle remainder
  iterations, since only a single iteration is needed when
  the remainder is less than the primary unroll factor.

- Added a conditional check to invoke the vectorized
  DDOTV kernels directly (fast-path), without incurring
  any additional framework overhead.

- The fast-path is taken when the input size is ideal
  for single-threaded execution. Thus, we avoid the
  call to bli_nthreads_l1() function to set the ideal
  number of threads.

- Updated gtestsuite ukr tests for 'bli_ddotv_zen_int10'
  kernel.

AMD-Internal: [CPUPL-4877]
Change-Id: If43f0fcff1c5b1563ad233005717398b5b6fb8f2
This commit is contained in:
Hari Govind S
2025-02-03 13:04:22 +05:30
parent bec9406996
commit 3d2653f1ab
3 changed files with 114 additions and 114 deletions

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2018 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -323,6 +323,13 @@ double ddot_blis_impl
cntx_t *cntx = NULL;
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
// Setting the threshold to invoke the fast-path
// The fast-path is intended to directly call the kernel
// in case the criteria for single threaded execution is met.
dim_t fast_path_thresh = 0;
#endif
// Query the architecture ID
arch_t arch_id_local = bli_arch_query_id();
@@ -330,19 +337,36 @@ double ddot_blis_impl
switch (arch_id_local)
{
case BLIS_ARCH_ZEN5:
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
// AVX-512 Kernel
dotv_ker_ptr = bli_ddotv_zen_int_avx512;
break;
#if defined(BLIS_KERNELS_ZEN5)
// AVX-512 Kernel
dotv_ker_ptr = bli_ddotv_zen_int_avx512;
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
fast_path_thresh = 6600;
#endif
break;
#endif
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
// AVX-512 Kernel
dotv_ker_ptr = bli_ddotv_zen_int_avx512;
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
fast_path_thresh = 5600;
#endif
break;
#endif
case BLIS_ARCH_ZEN:
case BLIS_ARCH_ZEN2:
case BLIS_ARCH_ZEN3:
// AVX2 Kernel
dotv_ker_ptr = bli_ddotv_zen_int10;
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
fast_path_thresh = 2500;
#endif
break;
default:
@@ -355,6 +379,33 @@ double ddot_blis_impl
}
#ifdef BLIS_ENABLE_OPENMP
#ifdef AOCL_DYNAMIC
/*
If the input size does not exceed 'fast_path_thresh', the OpenMP and
'bli_nthreads_l1' overheads are avoided by invoking the
function directly. This ensures that performance of ddotv
does not drop for single thread when OpenMP is enabled.
*/
if (n_elem <= fast_path_thresh)
{
dotv_ker_ptr
(
BLIS_NO_CONJUGATE,
BLIS_NO_CONJUGATE,
n_elem,
x0, incx0,
y0, incy0,
&rho,
cntx
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return rho;
}
#endif
/*
Initializing the number of threads to one
to avoid compiler warnings
@@ -395,9 +446,9 @@ double ddot_blis_impl
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return rho;
#ifdef BLIS_ENABLE_OPENMP
#ifdef BLIS_ENABLE_OPENMP
}
/*

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -166,9 +166,8 @@ INSTANTIATE_TEST_SUITE_P(
// Tests for bli_ddotv_zen_int10 (AVX2) kernel.
/**
* Loops:
* L40 - Main loop, handles 40 elements
* L20 - handles 20 elements
* L16 - handles 16 elements
* L20 - Main loop, handles 20 elements
* L12 - handles 12 elements
* L8 - handles 8 elements
* L4 - handles 4 elements
* LScalar - leftover loop
@@ -188,26 +187,23 @@ INSTANTIATE_TEST_SUITE_P(
// m: size of vector.
::testing::Values(
// testing each loop individually.
gtint_t(80), // L40, executed twice
gtint_t(40), // L40
gtint_t(40), // L20, executed twice
gtint_t(20), // L20
gtint_t(16), // L16
gtint_t(12), // L12
gtint_t( 8), // L8
gtint_t( 4), // L4
gtint_t( 2), // LScalar
gtint_t( 1), // LScalar
// testing entire set of loops starting from loop m to n.
gtint_t(73), // L40 through LScalar, excludes L16
gtint_t(33), // L20 through LScalar, excludes L16
gtint_t(13), // L8 through LScalar
gtint_t( 5), // L4 through LScalar
gtint_t(25), // L20 + L4 + LScalar
gtint_t( 9), // L8 + LScalar
gtint_t( 5), // L4 + LScalar
// testing few combinations including L16.
gtint_t(77), // L40 + L20 + L16 + LScalar
gtint_t(76), // L40 + L20 + L16
gtint_t(57), // L40 + L16 + LScalar
gtint_t(37) // L20 + L16 + LScalar
// testing few combinations including L12.
gtint_t(37), // L20 + L12 + L4 + LScalar
gtint_t(36), // L20 + L12 + L4
gtint_t(17) // L12 + L4 + LScalar
),
// incx: stride of x vector.
::testing::Values(

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2016 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2018, The University of Texas at Austin
Redistribution and use in source and binary forms, with or without
@@ -277,9 +277,10 @@ void bli_ddotv_zen_int10
double rho0 = 0.0;
__m256d xv[10];
__m256d yv[10];
v4df_t rhov[10];
__m256d xv[5];
__m256d yv[5];
__m256d rhov[5];
v4df_t rh;
// If the vector dimension is zero, or if alpha is zero, return early.
if ( bli_zero_dim1( n ) )
@@ -296,64 +297,13 @@ void bli_ddotv_zen_int10
if ( incx == 1 && incy == 1 )
{
rhov[0].v = _mm256_setzero_pd();
rhov[1].v = _mm256_setzero_pd();
rhov[2].v = _mm256_setzero_pd();
rhov[3].v = _mm256_setzero_pd();
rhov[4].v = _mm256_setzero_pd();
rhov[5].v = _mm256_setzero_pd();
rhov[6].v = _mm256_setzero_pd();
rhov[7].v = _mm256_setzero_pd();
rhov[8].v = _mm256_setzero_pd();
rhov[9].v = _mm256_setzero_pd();
rhov[0] = _mm256_setzero_pd();
rhov[1] = _mm256_setzero_pd();
rhov[2] = _mm256_setzero_pd();
rhov[3] = _mm256_setzero_pd();
rhov[4] = _mm256_setzero_pd();
for ( i = 0; (i + 39) < n; i += 40 )
{
// 80 elements will be processed per loop; 10 FMAs will run per loop.
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg );
xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg );
xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg );
xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg );
xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg );
xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg );
xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg );
xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg );
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg );
yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg );
yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg );
yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg );
yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg );
yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg );
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v );
rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v );
rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v );
rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v );
rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v );
rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v );
x0 += 10*n_elem_per_reg;
y0 += 10*n_elem_per_reg;
}
rhov[0].v += rhov[5].v;
rhov[1].v += rhov[6].v;
rhov[2].v += rhov[7].v;
rhov[3].v += rhov[8].v;
rhov[4].v += rhov[9].v;
for ( ; (i + 19) < n; i += 20 )
for ( i = 0; (i + 19) < n; i += 20 )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
@@ -367,43 +317,41 @@ void bli_ddotv_zen_int10
yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg );
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v );
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
rhov[1] = _mm256_fmadd_pd( xv[1], yv[1], rhov[1] );
rhov[2] = _mm256_fmadd_pd( xv[2], yv[2], rhov[2] );
rhov[3] = _mm256_fmadd_pd( xv[3], yv[3], rhov[3] );
rhov[4] = _mm256_fmadd_pd( xv[4], yv[4], rhov[4] );
x0 += 5*n_elem_per_reg;
y0 += 5*n_elem_per_reg;
}
rhov[0].v += rhov[4].v;
rhov[0] = _mm256_add_pd( rhov[3], rhov[0]) ;
rhov[1] = _mm256_add_pd( rhov[4], rhov[1]) ;
for ( ; (i + 15) < n; i += 16 )
if ( (i + 11) < n )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg );
xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg );
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
rhov[1] = _mm256_fmadd_pd( xv[1], yv[1], rhov[1] );
rhov[2] = _mm256_fmadd_pd( xv[2], yv[2], rhov[2] );
x0 += 4*n_elem_per_reg;
y0 += 4*n_elem_per_reg;
x0 += 3*n_elem_per_reg;
y0 += 3*n_elem_per_reg;
i += 3*n_elem_per_reg;
}
rhov[0].v += rhov[2].v;
rhov[1].v += rhov[3].v;
rhov[0] = _mm256_add_pd( rhov[2], rhov[0]) ;
for ( ; (i + 7) < n; i += 8 )
if ( (i + 7) < n )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
@@ -411,40 +359,45 @@ void bli_ddotv_zen_int10
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
rhov[1] = _mm256_fmadd_pd( xv[1], yv[1], rhov[1] );
x0 += 2*n_elem_per_reg;
y0 += 2*n_elem_per_reg;
i += 2*n_elem_per_reg;
}
rhov[0].v += rhov[1].v;
rhov[0] = _mm256_add_pd( rhov[1], rhov[0]) ;
for ( ; (i + 3) < n; i += 4 )
if ( (i + 3) < n )
{
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
x0 += 1*n_elem_per_reg;
y0 += 1*n_elem_per_reg;
x0 += n_elem_per_reg;
y0 += n_elem_per_reg;
i += n_elem_per_reg;
}
if(i < n)
if( i < n )
{
__m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[(n - i)]);
xv[0] = _mm256_maskload_pd( x0, maskVec );
yv[0] = _mm256_maskload_pd( y0, maskVec );
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
i = n;
}
// Manually add the results from above to finish the sum.
rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3];
// Perform horizontal addition of the elements in the vector.
rh.v = _mm256_hadd_pd( rhov[0], rhov[0] );
// Manually add the first and third element from above vector to finish the sum.
rho0 += rh.d[0] + rh.d[2];
// Issue vzeroupper instruction to clear upper lanes of ymm registers.
// This avoids a performance penalty caused by false dependencies when