Improved DGEMV performance for smaller sizes

- Introduced two new ddotxf functions with lower fuse
  factor.
- Changed the DGEMV framework to use new kernels to
  improve problem decomposition.

Change-Id: I523e158fd33260d06224118fbf74f2314e03a617
This commit is contained in:
Harihara Sudhan S
2021-12-14 12:01:12 +05:30
parent 0f43db8347
commit 8201bcfdaf
3 changed files with 1067 additions and 253 deletions

View File

@@ -34,7 +34,6 @@
*/
#include "blis.h"
#define BLIS_DGEMV_VAR1_FUSE 8
#undef GENTFUNC
#define GENTFUNC( ctype, ch, varname ) \
@@ -121,30 +120,30 @@ void bli_dgemv_unf_var1
)
{
double* A1;
double* y1;
dim_t i;
dim_t f;
dim_t n_elem, n_iter;
inc_t rs_at, cs_at;
conj_t conja;
double *A1;
double *y1;
dim_t i;
dim_t f;
dim_t n_elem, n_iter;
inc_t rs_at, cs_at;
conj_t conja;
//memory pool declarations for packing vector X.
mem_t mem_bufX;
rntm_t rntm;
double *x_buf = x;
inc_t buf_incx = incx;
mem_t mem_bufX;
rntm_t rntm;
double *x_buf = x;
inc_t buf_incx = incx;
bli_init_once();
if( cntx == NULL ) cntx = bli_gks_query_cntx();
if (cntx == NULL)
cntx = bli_gks_query_cntx();
bli_set_dims_incs_with_trans( transa,
m, n, rs_a, cs_a,
&n_iter, &n_elem, &rs_at, &cs_at );
bli_set_dims_incs_with_trans(transa,
m, n, rs_a, cs_a,
&n_iter, &n_elem, &rs_at, &cs_at);
conja = bli_extract_conj( transa );
conja = bli_extract_conj(transa);
// When dynamic dispatch is enabled i.e. library is built for amdzen configuration.
// This function is invoked on all architectures including generic.
// Invoke architecture specific kernels only if we are sure that we are running on zen,
// zen2 or zen3 otherwise fall back to reference kernels (via framework and context).
@@ -193,88 +192,154 @@ void bli_dgemv_unf_var1
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
return;
}
if (incx > 1)
{
/*
/*
Initialize mem pool buffer to NULL and size to 0
"buf" and "size" fields are assigned once memory
is allocated from the pool in bli_membrk_acquire_m().
This will ensure bli_mem_is_alloc() will be passed on
an allocated memory if created or a NULL .
*/
mem_bufX.pblk.buf = NULL; mem_bufX.pblk.block_size = 0;
mem_bufX.buf_type = 0; mem_bufX.size = 0;
mem_bufX.pool = NULL;
*/
/* In order to get the buffer from pool via rntm access to memory broker
mem_bufX.pblk.buf = NULL;
mem_bufX.pblk.block_size = 0;
mem_bufX.buf_type = 0;
mem_bufX.size = 0;
mem_bufX.pool = NULL;
/* In order to get the buffer from pool via rntm access to memory broker
is needed.Following are initializations for rntm */
bli_rntm_init_from_global( &rntm );
bli_rntm_set_num_threads_only( 1, &rntm );
bli_membrk_rntm_set_membrk( &rntm );
bli_rntm_init_from_global(&rntm);
bli_rntm_set_num_threads_only(1, &rntm);
bli_membrk_rntm_set_membrk(&rntm);
//calculate the size required for n_elem double elements in vector X.
size_t buffer_size = n_elem * sizeof(double);
//calculate the size required for n_elem double elements in vector X.
size_t buffer_size = n_elem * sizeof(double);
#ifdef BLIS_ENABLE_MEM_TRACING
printf( "bli_dgemv_unf_var1(): get mem pool block\n" );
#endif
#ifdef BLIS_ENABLE_MEM_TRACING
printf("bli_dgemv_unf_var1(): get mem pool block\n");
#endif
/*acquire a Buffer(n_elem*size(double)) from the memory broker
and save the associated mem_t entry to mem_bufX.*/
bli_membrk_acquire_m(&rntm,
buffer_size,
BLIS_BUFFER_FOR_B_PANEL,
&mem_bufX);
/*acquire a Buffer(n_elem*size(double)) from the memory broker
and save the associated mem_t entry to mem_bufX.*/
bli_membrk_acquire_m(&rntm,
buffer_size,
BLIS_BUFFER_FOR_B_PANEL,
&mem_bufX);
/*Continue packing X if buffer memory is allocated*/
if ((bli_mem_is_alloc( &mem_bufX )))
{
x_buf = bli_mem_buffer(&mem_bufX);
//pack X vector with non-unit stride to a temp buffer x_buf with unit stride
for(dim_t x_index = 0 ; x_index < n_elem ; x_index++)
{
*(x_buf + x_index) = *(x + (x_index * incx)) ;
}
// stride of vector x_buf =1
buf_incx = 1;
}
}
for ( i = 0; i < n_iter; i += f )
/*Continue packing X if buffer memory is allocated*/
if ((bli_mem_is_alloc(&mem_bufX)))
{
f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR1_FUSE );
x_buf = bli_mem_buffer(&mem_bufX);
A1 = a + (i )*rs_at + (0 )*cs_at;
y1 = y + (i )*incy;
//pack X vector with non-unit stride to a temp buffer x_buf with unit stride
for (dim_t x_index = 0; x_index < n_elem; x_index++)
{
*(x_buf + x_index) = *(x + (x_index * incx));
}
// stride of vector x_buf =1
buf_incx = 1;
}
}
/* y1 = beta * y1 + alpha * A1 * x; */
bli_ddotxf_zen_int_8
(
dim_t fuse_factor = 8;
dim_t f_temp =0;
if (n < 4)
{
fuse_factor = 2;
} else if (n < 8)
{
fuse_factor = 4;
}
for (i = 0; i < n_iter; i += f)
{
f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor);
//A = a + i * row_increment + 0 * column_increment
A1 = a + (i)*rs_at;
y1 = y + (i)*incy;
/* y1 = beta * y1 + alpha * A1 * x; */
switch (f)
{
case 8:
bli_ddotxf_zen_int_8(
conja,
conjx,
n_elem,
f,
alpha,
A1, cs_at, rs_at,
x_buf, buf_incx,
A1, cs_at, rs_at,
x_buf, buf_incx,
beta,
y1, incy,
cntx
);
y1, incy,
cntx);
break;
default:
if (f < 4)
{
bli_ddotxf_zen_int_2(
conja,
conjx,
n_elem,
f,
alpha,
A1, cs_at, rs_at,
x_buf, buf_incx,
beta,
y1, incy,
cntx);
}
else
{
bli_ddotxf_zen_int_4(
conja,
conjx,
n_elem,
f,
alpha,
A1, cs_at, rs_at,
x_buf, buf_incx,
beta,
y1, incy,
cntx);
}
}
if ((incx > 1) && bli_mem_is_alloc( &mem_bufX ))
f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor);
if (f_temp < fuse_factor)
{
#ifdef BLIS_ENABLE_MEM_TRACING
printf( "bli_dgemv_unf_var1(): releasing mem pool block\n" );
#endif
// Return the buffer to pool
bli_membrk_release(&rntm , &mem_bufX);
switch (fuse_factor)
{
case 8:
fuse_factor = 4;
break;
case 4:
fuse_factor = 2;
break;
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
if ((incx > 1) && bli_mem_is_alloc(&mem_bufX))
{
#ifdef BLIS_ENABLE_MEM_TRACING
printf("bli_dgemv_unf_var1(): releasing mem pool block\n");
#endif
// Return the buffer to pool
bli_membrk_release(&rntm, &mem_bufX);
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
void bli_sgemv_unf_var1

File diff suppressed because it is too large Load Diff

View File

@@ -118,6 +118,8 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 )
// dotxf (intrinsics)
DOTXF_KER_PROT( float, s, dotxf_zen_int_8 )
DOTXF_KER_PROT( double, d, dotxf_zen_int_8 )
DOTXF_KER_PROT( double, d, dotxf_zen_int_4 )
DOTXF_KER_PROT( double, d, dotxf_zen_int_2 )
// dotxaxpyf (intrinsics)
DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 )