Added vector packing logic to ZGEMV variant 2

- In cases when incy != 1, a buffer is created for y vector. The
  contents of vector y is scaled by beta and stored in this buffer.
- After performing the compute using ZAXPYF kernel, the results in
  y buffer memory is copied back to the orginal buffer using ZCOPYV.
- In cases when alpha is zero, we only scale the y vector by beta
  without using the buffer and return.
- The kernels are picked based on the architecture ID. For any zen
  based architecture, AVX2 kernels are invoked. For other, the
  kernels are invoked based on the context.
- In ZSCAL2V, query for the context if NULL pointer is passed.

AMD-Internal: [CPUPL-2773]
Change-Id: If409ca5c438fc2eebe73480c011577088d52c65f
This commit is contained in:
Harihara Sudhan S
2023-03-01 11:41:09 +05:30
committed by HariharaSudhan S
parent a67205d8bd
commit 4b36529a8b
2 changed files with 202 additions and 150 deletions

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2020-23, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -576,175 +576,225 @@ void bli_zgemv_unf_var2
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3);
dcomplex* A1;
dcomplex* x1;
dcomplex* y1;
dim_t i;
dim_t b_fuse, f;
dim_t n_elem, n_iter;
inc_t rs_at, cs_at;
conj_t conja;
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3);
// For AMD these APIS are invoked skipping intermediate framework layers
// Hence we need to ensure that cntx is set here.
bli_init_once();
if(cntx == NULL) cntx = bli_gks_query_cntx();
dcomplex *A1;
dcomplex *x1;
dcomplex *y1;
bli_set_dims_incs_with_trans( transa,
m, n, rs_a, cs_a,
&n_elem, &n_iter, &rs_at, &cs_at );
dim_t i, b_fuse, f;
dim_t n_elem, n_iter;
inc_t rs_at, cs_at;
conj_t conja;
conja = bli_extract_conj( transa );
// Memory pool declarations for packing vector Y.
mem_t mem_bufY;
rntm_t rntm;
dcomplex *y_buf = y;
inc_t buf_incy = incy;
/* If beta is zero, use setv. Otherwise, scale by beta. */
/* y = beta * y; */
bli_set_dims_incs_with_trans(transa,
m, n, rs_a, cs_a,
&n_elem, &n_iter, &rs_at, &cs_at);
/* beta=0 case is hadled by scalv internally */
/* bli_zscalv_zen_int10
(
BLIS_NO_CONJUGATE,
n_elem,
beta,
y,
incy,
cntx
);*/
conja = bli_extract_conj(transa);
// This function is invoked on all architectures including generic.
// Non-AVX platforms will use the kernels derived from the context.
if (bli_cpuid_is_avx_supported() == FALSE)
// Query the architecture ID
arch_t id = bli_arch_query_id();
/*
Function pointer declaration for the functions
that will be used by this API
*/
zaxpyf_ker_ft axpyf_kr_ptr; // ZAXPYF
zscal2v_ker_ft scal2v_kr_ptr; // ZSCAL2V
zscalv_ker_ft scalv_kr_ptr; // ZSCALV
zcopyv_ker_ft copyv_kr_ptr; // ZCOPYV
/*
Boolean to check if the y has been packed
and memory needs to be freed in the end
*/
bool is_y_temp_buf_created = FALSE;
switch (id)
{
case BLIS_ARCH_ZEN4:
case BLIS_ARCH_ZEN:
case BLIS_ARCH_ZEN2:
case BLIS_ARCH_ZEN3:
/*
Assign the AVX2 based kernel function pointers for
ZAXPYF, ZSCAL2V, ZSCALV, ZCOPYV and corresponding fusing
factor of ZAXPYF kernel
*/
axpyf_kr_ptr = bli_zaxpyf_zen_int_4;
b_fuse = 4;
scal2v_kr_ptr = bli_zscal2v_zen_int;
scalv_kr_ptr = bli_zscalv_zen_int;
copyv_kr_ptr = bli_zcopyv_zen_int;
break;
default:
// For non-Zen architectures, query the context if it is NULL
if(cntx == NULL) cntx = bli_gks_query_cntx();
/*
Query the context for the kernel function pointers for
ZAXPYF, ZSCAL2V, ZSCALV, ZCOPYV and corresponding fusing
factor of ZAXPYF kernel
*/
axpyf_kr_ptr = bli_cntx_get_l1f_ker_dt(BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx);
b_fuse = bli_cntx_get_blksz_def_dt(BLIS_DCOMPLEX, BLIS_AF, cntx);
scal2v_kr_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_SCAL2V_KER, cntx);
scalv_kr_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_SCALV_KER, cntx);
copyv_kr_ptr = bli_cntx_get_l1v_ker_dt(BLIS_DCOMPLEX, BLIS_COPYV_KER, cntx);
}
/*
If alpha is equal to zero, y = beta * y + alpha * A * x
becomes y = beat * y in that case packing will be costly.
y is only scaled with SCALV and returned.
*/
if (incy > 1 && (!bli_zeq0(*alpha)))
{
/*
Initialize mem pool buffer to NULL and size to 0
"buf" and "size" fields are assigned once memory
is allocated from the pool in bli_membrk_acquire_m().
This will ensure bli_mem_is_alloc() will be passed on
an allocated memory if created or a NULL .
*/
mem_bufY.pblk.buf = NULL;
mem_bufY.pblk.block_size = 0;
mem_bufY.buf_type = 0;
mem_bufY.size = 0;
mem_bufY.pool = NULL;
/*
In order to get the buffer from pool via rntm access to memory broker
is needed.Following are initializations for rntm
*/
bli_rntm_init_from_global(&rntm);
bli_rntm_set_num_threads_only(1, &rntm);
bli_membrk_rntm_set_membrk(&rntm);
// Calculate the size required for n_elem double elements in vector Y.
size_t buffer_size = n_elem * sizeof(dcomplex);
#ifdef BLIS_ENABLE_MEM_TRACING
printf("bli_zgemv_unf_var2(): get mem pool block\n");
#endif
/*
Acquire a Buffer(n_elem*size(double)) from the memory broker
and save the associated mem_t entry to mem_bufY.
*/
bli_membrk_acquire_m(&rntm,
buffer_size,
BLIS_BUFFER_FOR_B_PANEL,
&mem_bufY);
/* Continue packing Y if buffer memory is allocated */
if ((bli_mem_is_alloc(&mem_bufY)))
{
const num_t dt = PASTEMAC(z,type);
/* If beta is zero, use setv. Otherwise, scale by beta. */
if ( PASTEMAC(z,eq0)( *beta ) )
{
dcomplex* zero = PASTEMAC(z,0);
/* y = 0; */
PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF)
(
BLIS_NO_CONJUGATE,
n_elem,
zero,
y, incy,
cntx,
NULL
);
}
else
{
/* y = beta * y; */
PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF)
(
BLIS_NO_CONJUGATE,
n_elem,
beta,
y, incy,
cntx,
NULL
);
}
y_buf = bli_mem_buffer(&mem_bufY);
buf_incy = 1;
PASTECH(z,axpyf_ker_ft) kfp_af;
// Invoke the ZSCAL2V function using the function pointer
scal2v_kr_ptr
(
BLIS_NO_CONJUGATE,
n_elem,
beta,
y, incy,
y_buf, buf_incy,
NULL
);
/* Query the context for the kernel function pointer and fusing factor. */
kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx );
b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx );
for ( i = 0; i < n_iter; i += f )
{
f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
A1 = a + (0 )*rs_at + (i )*cs_at;
x1 = x + (i )*incx;
y1 = y + (0 )*incy;
/* y = y + alpha * A1 * x1; */
kfp_af
(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y1, incy,
cntx
);
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
return;
/*
Set y is packed as the memory allocation was
successful and contents have been copied
*/
is_y_temp_buf_created = TRUE;
}
bli_zscalv_ex
}
else
{
// Invoke the ZSCALV function using the function pointer
scalv_kr_ptr
(
BLIS_NO_CONJUGATE,
n_elem,
beta,
y_buf, buf_incy,
NULL
);
}
// If alpha is zero(0), we only need to scalv y and return
if (bli_zeq0(*alpha))
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
// Return early for alpha is zero(0)
return;
}
for (i = 0; i < n_iter; i += f)
{
f = bli_determine_blocksize_dim_f(i, n_iter, b_fuse);
A1 = a + (0) * rs_at + (i)*cs_at;
x1 = x + (i)*incx;
y1 = y_buf + (0) * buf_incy;
// Invoke the ZAXPYF function using the function pointer
axpyf_kr_ptr
(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y1, buf_incy,
cntx
);
}
// Check if temp y buffer was used for compute
if (is_y_temp_buf_created)
{
// Store the result from unit strided y_buf to non-unit strided Y
// Invoke the ZCOPYV function using the function pointer
copyv_kr_ptr
(
BLIS_NO_CONJUGATE,
n_elem,
y_buf, buf_incy,
y, incy,
cntx,
NULL
);
if( bli_zeq0( *alpha ) )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
return;
}
#ifdef BLIS_ENABLE_MEM_TRACING
printf("bli_zgemv_unf_var2(): releasing mem pool block\n");
#endif
// for non-unit incx, incy and rs_at and conjugate will be added in the next patch
if( (incx == 1 && incy == 1 && rs_at == 1 ) &&
!bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa))
{
// This gemv code deals with the followint conditions only
// 1. incx, incy, and row stride equal to one
// 2. Non conjugate A matrix and X vector
// 3. No Transpose for A Martix
// Rest is taken care by the else part (axpyf implementation)
bli_zgemv_zen_int_4x4
(
conja,
conjx,
m,
n,
alpha,
a, rs_at, cs_at,
x, incx,
beta,
y, incy,
cntx
);
}
else
{
/* fusing factor */
b_fuse = 4;
// Return the buffer to pool
bli_membrk_release(&rntm, &mem_bufY);
}
for ( i = 0; i < n_iter; i += f )
{
f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse );
A1 = a + (0 )*rs_at + (i )*cs_at;
x1 = x + (i )*incx;
y1 = y + (0 )*incy;
/* y = y + alpha * A1 * x1; */
bli_zaxpyf_zen_int_4
(
conja,
conjx,
n_elem,
f,
alpha,
A1, rs_at, cs_at,
x1, incx,
y1, incy,
cntx
);
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
void bli_cgemv_unf_var2

View File

@@ -137,6 +137,8 @@ void bli_zscal2v_zen_int
/* If alpha is zero, use setv. */
dcomplex *zero = PASTEMAC(z, 0);
if(cntx == NULL) cntx = bli_gks_query_cntx();
/* Query the context for the kernel function pointer. */
const num_t dt = PASTEMAC(z, type);