mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
DDOTV Optimization for ZEN3 Architecture
- Reduced the blocking size of 'bli_ddotv_zen_int10' kernel from 40 elements to 20 elements for better utilization of vector registers - Replaced redundant 'for' loops in 'bli_ddotv_zen_int10' kernel with 'if' conditions to handle reminder iterations. As only a single iteration is used when reminder is less than the primary unroll factor. - Added a conditional check to invoke the vectorized DDOTV kernels directly(fast-path), without incurring any additional framework overhead. - The fast-path is taken when the input size is ideal for single-threaded execution. Thus, we avoid the call to bli_nthreads_l1() function to set the ideal number of threads. - Updated getestsuite ukr tests for 'bli_ddotv_zen_int10' kernel. AMD-Internal: [CPUPL-4877] Change-Id: If43f0fcff1c5b1563ad233005717398b5b6fb8f2
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2018 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2018 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -323,6 +323,13 @@ double ddot_blis_impl
|
||||
|
||||
cntx_t *cntx = NULL;
|
||||
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
// Setting the threshold to invoke the fast-path
|
||||
// The fast-path is intended to directly call the kernel
|
||||
// in case the criteria for single threaded execution is met.
|
||||
dim_t fast_path_thresh = 0;
|
||||
#endif
|
||||
|
||||
// Query the architecture ID
|
||||
arch_t arch_id_local = bli_arch_query_id();
|
||||
|
||||
@@ -330,19 +337,36 @@ double ddot_blis_impl
|
||||
switch (arch_id_local)
|
||||
{
|
||||
case BLIS_ARCH_ZEN5:
|
||||
case BLIS_ARCH_ZEN4:
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
|
||||
// AVX-512 Kernel
|
||||
dotv_ker_ptr = bli_ddotv_zen_int_avx512;
|
||||
break;
|
||||
#if defined(BLIS_KERNELS_ZEN5)
|
||||
// AVX-512 Kernel
|
||||
dotv_ker_ptr = bli_ddotv_zen_int_avx512;
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
fast_path_thresh = 6600;
|
||||
#endif
|
||||
break;
|
||||
#endif
|
||||
|
||||
case BLIS_ARCH_ZEN4:
|
||||
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
// AVX-512 Kernel
|
||||
dotv_ker_ptr = bli_ddotv_zen_int_avx512;
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
fast_path_thresh = 5600;
|
||||
#endif
|
||||
break;
|
||||
#endif
|
||||
|
||||
case BLIS_ARCH_ZEN:
|
||||
case BLIS_ARCH_ZEN2:
|
||||
case BLIS_ARCH_ZEN3:
|
||||
|
||||
// AVX2 Kernel
|
||||
dotv_ker_ptr = bli_ddotv_zen_int10;
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
fast_path_thresh = 2500;
|
||||
#endif
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -355,6 +379,33 @@ double ddot_blis_impl
|
||||
}
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
#ifdef AOCL_DYNAMIC
|
||||
|
||||
/*
|
||||
If the input size is less than ST_THRESH, the OpenMP and
|
||||
'bli_nthreads_l1' overheads are avoided by invoking the
|
||||
function directly. This ensures that performance of ddotv
|
||||
does not drop for single thread when OpenMP is enabled.
|
||||
*/
|
||||
if (n_elem <= fast_path_thresh)
|
||||
{
|
||||
dotv_ker_ptr
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
BLIS_NO_CONJUGATE,
|
||||
n_elem,
|
||||
x0, incx0,
|
||||
y0, incy0,
|
||||
&rho,
|
||||
cntx
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
|
||||
return rho;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
/*
|
||||
Initializing the number of thread to one
|
||||
to avoid compiler warnings
|
||||
@@ -395,9 +446,9 @@ double ddot_blis_impl
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
|
||||
|
||||
return rho;
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -166,9 +166,8 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
// Tests for bli_ddotv_zen_int10 (AVX2) kernel.
|
||||
/**
|
||||
* Loops:
|
||||
* L40 - Main loop, handles 40 elements
|
||||
* L20 - handles 20 elements
|
||||
* L16 - handles 16 elements
|
||||
* L20 - Main loop, handles 20 elements
|
||||
* L12 - handles 12 elements
|
||||
* L8 - handles 8 elements
|
||||
* L4 - handles 4 elements
|
||||
* LScalar - leftover loop
|
||||
@@ -188,26 +187,23 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
// m: size of vector.
|
||||
::testing::Values(
|
||||
// testing each loop individually.
|
||||
gtint_t(80), // L40, executed twice
|
||||
gtint_t(40), // L40
|
||||
gtint_t(40), // L20, executed twice
|
||||
gtint_t(20), // L20
|
||||
gtint_t(16), // L16
|
||||
gtint_t(12), // L12
|
||||
gtint_t( 8), // L8
|
||||
gtint_t( 4), // L4
|
||||
gtint_t( 2), // LScalar
|
||||
gtint_t( 1), // LScalar
|
||||
|
||||
// testing entire set of loops starting from loop m to n.
|
||||
gtint_t(73), // L40 through LScalar, excludes L16
|
||||
gtint_t(33), // L20 through LScalar, excludes L16
|
||||
gtint_t(13), // L8 through LScalar
|
||||
gtint_t( 5), // L4 through LScalar
|
||||
gtint_t(25), // L20 + L4 + LScalar
|
||||
gtint_t( 9), // L8 + LScalar
|
||||
gtint_t( 5), // L4 + LScalar
|
||||
|
||||
// testing few combinations including L16.
|
||||
gtint_t(77), // L40 + L20 + L16 + LScalar
|
||||
gtint_t(76), // L40 + L20 + L16
|
||||
gtint_t(57), // L40 + L16 + LScalar
|
||||
gtint_t(37) // L20 + L16 + LScalar
|
||||
// testing few combinations including L12.
|
||||
gtint_t(37), // L20 + L12 + L4 + LScalar
|
||||
gtint_t(36), // L20 + L12 + L4
|
||||
gtint_t(17) // L12 + L4 + LScalar
|
||||
),
|
||||
// incx: stride of x vector.
|
||||
::testing::Values(
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2016 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2018, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
@@ -277,9 +277,10 @@ void bli_ddotv_zen_int10
|
||||
|
||||
double rho0 = 0.0;
|
||||
|
||||
__m256d xv[10];
|
||||
__m256d yv[10];
|
||||
v4df_t rhov[10];
|
||||
__m256d xv[5];
|
||||
__m256d yv[5];
|
||||
__m256d rhov[5];
|
||||
v4df_t rh;
|
||||
|
||||
// If the vector dimension is zero, or if alpha is zero, return early.
|
||||
if ( bli_zero_dim1( n ) )
|
||||
@@ -296,64 +297,13 @@ void bli_ddotv_zen_int10
|
||||
|
||||
if ( incx == 1 && incy == 1 )
|
||||
{
|
||||
rhov[0].v = _mm256_setzero_pd();
|
||||
rhov[1].v = _mm256_setzero_pd();
|
||||
rhov[2].v = _mm256_setzero_pd();
|
||||
rhov[3].v = _mm256_setzero_pd();
|
||||
rhov[4].v = _mm256_setzero_pd();
|
||||
rhov[5].v = _mm256_setzero_pd();
|
||||
rhov[6].v = _mm256_setzero_pd();
|
||||
rhov[7].v = _mm256_setzero_pd();
|
||||
rhov[8].v = _mm256_setzero_pd();
|
||||
rhov[9].v = _mm256_setzero_pd();
|
||||
rhov[0] = _mm256_setzero_pd();
|
||||
rhov[1] = _mm256_setzero_pd();
|
||||
rhov[2] = _mm256_setzero_pd();
|
||||
rhov[3] = _mm256_setzero_pd();
|
||||
rhov[4] = _mm256_setzero_pd();
|
||||
|
||||
for ( i = 0; (i + 39) < n; i += 40 )
|
||||
{
|
||||
// 80 elements will be processed per loop; 10 FMAs will run per loop.
|
||||
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
|
||||
xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg );
|
||||
xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg );
|
||||
xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg );
|
||||
xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg );
|
||||
xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg );
|
||||
xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg );
|
||||
xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg );
|
||||
xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg );
|
||||
|
||||
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
|
||||
yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
|
||||
yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg );
|
||||
yv[5] = _mm256_loadu_pd( y0 + 5*n_elem_per_reg );
|
||||
yv[6] = _mm256_loadu_pd( y0 + 6*n_elem_per_reg );
|
||||
yv[7] = _mm256_loadu_pd( y0 + 7*n_elem_per_reg );
|
||||
yv[8] = _mm256_loadu_pd( y0 + 8*n_elem_per_reg );
|
||||
yv[9] = _mm256_loadu_pd( y0 + 9*n_elem_per_reg );
|
||||
|
||||
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
|
||||
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
|
||||
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
|
||||
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
|
||||
rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v );
|
||||
rhov[5].v = _mm256_fmadd_pd( xv[5], yv[5], rhov[5].v );
|
||||
rhov[6].v = _mm256_fmadd_pd( xv[6], yv[6], rhov[6].v );
|
||||
rhov[7].v = _mm256_fmadd_pd( xv[7], yv[7], rhov[7].v );
|
||||
rhov[8].v = _mm256_fmadd_pd( xv[8], yv[8], rhov[8].v );
|
||||
rhov[9].v = _mm256_fmadd_pd( xv[9], yv[9], rhov[9].v );
|
||||
|
||||
x0 += 10*n_elem_per_reg;
|
||||
y0 += 10*n_elem_per_reg;
|
||||
}
|
||||
|
||||
rhov[0].v += rhov[5].v;
|
||||
rhov[1].v += rhov[6].v;
|
||||
rhov[2].v += rhov[7].v;
|
||||
rhov[3].v += rhov[8].v;
|
||||
rhov[4].v += rhov[9].v;
|
||||
|
||||
for ( ; (i + 19) < n; i += 20 )
|
||||
for ( i = 0; (i + 19) < n; i += 20 )
|
||||
{
|
||||
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
|
||||
@@ -367,43 +317,41 @@ void bli_ddotv_zen_int10
|
||||
yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
|
||||
yv[4] = _mm256_loadu_pd( y0 + 4*n_elem_per_reg );
|
||||
|
||||
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
|
||||
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
|
||||
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
|
||||
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
|
||||
rhov[4].v = _mm256_fmadd_pd( xv[4], yv[4], rhov[4].v );
|
||||
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
|
||||
rhov[1] = _mm256_fmadd_pd( xv[1], yv[1], rhov[1] );
|
||||
rhov[2] = _mm256_fmadd_pd( xv[2], yv[2], rhov[2] );
|
||||
rhov[3] = _mm256_fmadd_pd( xv[3], yv[3], rhov[3] );
|
||||
rhov[4] = _mm256_fmadd_pd( xv[4], yv[4], rhov[4] );
|
||||
|
||||
x0 += 5*n_elem_per_reg;
|
||||
y0 += 5*n_elem_per_reg;
|
||||
}
|
||||
|
||||
rhov[0].v += rhov[4].v;
|
||||
rhov[0] = _mm256_add_pd( rhov[3], rhov[0]) ;
|
||||
rhov[1] = _mm256_add_pd( rhov[4], rhov[1]) ;
|
||||
|
||||
for ( ; (i + 15) < n; i += 16 )
|
||||
if ( (i + 11) < n )
|
||||
{
|
||||
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
|
||||
xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg );
|
||||
xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg );
|
||||
|
||||
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg );
|
||||
yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg );
|
||||
|
||||
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
|
||||
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
|
||||
rhov[2].v = _mm256_fmadd_pd( xv[2], yv[2], rhov[2].v );
|
||||
rhov[3].v = _mm256_fmadd_pd( xv[3], yv[3], rhov[3].v );
|
||||
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
|
||||
rhov[1] = _mm256_fmadd_pd( xv[1], yv[1], rhov[1] );
|
||||
rhov[2] = _mm256_fmadd_pd( xv[2], yv[2], rhov[2] );
|
||||
|
||||
x0 += 4*n_elem_per_reg;
|
||||
y0 += 4*n_elem_per_reg;
|
||||
x0 += 3*n_elem_per_reg;
|
||||
y0 += 3*n_elem_per_reg;
|
||||
i += 3*n_elem_per_reg;
|
||||
}
|
||||
|
||||
rhov[0].v += rhov[2].v;
|
||||
rhov[1].v += rhov[3].v;
|
||||
rhov[0] = _mm256_add_pd( rhov[2], rhov[0]) ;
|
||||
|
||||
for ( ; (i + 7) < n; i += 8 )
|
||||
if ( (i + 7) < n )
|
||||
{
|
||||
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg );
|
||||
@@ -411,40 +359,45 @@ void bli_ddotv_zen_int10
|
||||
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg );
|
||||
|
||||
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
|
||||
rhov[1].v = _mm256_fmadd_pd( xv[1], yv[1], rhov[1].v );
|
||||
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
|
||||
rhov[1] = _mm256_fmadd_pd( xv[1], yv[1], rhov[1] );
|
||||
|
||||
x0 += 2*n_elem_per_reg;
|
||||
y0 += 2*n_elem_per_reg;
|
||||
i += 2*n_elem_per_reg;
|
||||
}
|
||||
|
||||
rhov[0].v += rhov[1].v;
|
||||
rhov[0] = _mm256_add_pd( rhov[1], rhov[0]) ;
|
||||
|
||||
for ( ; (i + 3) < n; i += 4 )
|
||||
if ( (i + 3) < n )
|
||||
{
|
||||
xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );
|
||||
|
||||
yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
|
||||
|
||||
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
|
||||
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
|
||||
|
||||
x0 += 1*n_elem_per_reg;
|
||||
y0 += 1*n_elem_per_reg;
|
||||
x0 += n_elem_per_reg;
|
||||
y0 += n_elem_per_reg;
|
||||
i += n_elem_per_reg;
|
||||
}
|
||||
|
||||
if(i < n)
|
||||
if( i < n )
|
||||
{
|
||||
__m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[(n - i)]);
|
||||
|
||||
xv[0] = _mm256_maskload_pd( x0, maskVec );
|
||||
yv[0] = _mm256_maskload_pd( y0, maskVec );
|
||||
|
||||
rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v );
|
||||
rhov[0] = _mm256_fmadd_pd( xv[0], yv[0], rhov[0] );
|
||||
i = n;
|
||||
}
|
||||
|
||||
// Manually add the results from above to finish the sum.
|
||||
rho0 += rhov[0].d[0] + rhov[0].d[1] + rhov[0].d[2] + rhov[0].d[3];
|
||||
// Perform horizontal addition of the elements in the vector.
|
||||
rh.v = _mm256_hadd_pd( rhov[0], rhov[0] );
|
||||
|
||||
// Manually add the first and third element from above vector to finish the sum.
|
||||
rho0 += rh.d[0] + rh.d[2];
|
||||
|
||||
// Issue vzeroupper instruction to clear upper lanes of ymm registers.
|
||||
// This avoids a performance penalty caused by false dependencies when
|
||||
|
||||
Reference in New Issue
Block a user