mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Fix extreme-values handling in GEMV

- When alpha == 0, we are expected to only scale the y vector with beta and not read A or X at all. This scenario is not handled properly in all code paths, which causes NaN and Inf values from A and X to be wrongly propagated. For example, for non-Zen CPUs (the default block in the switch case) no such check is present; similarly, some of the AVX512 kernels are also missing these checks.
- When beta == 0, we are not expected to read Y at all; this is also not handled correctly in one of the AVX512 kernels.
- To fix these, an early-return condition for alpha == 0 is added to the BLAS layer itself, so that each individual kernel does not have to implement the logic.
- The DGEMV AVX512 transpose kernel has been fixed to load vector Y only when beta != 0.

AMD-Internal: [CPUPL-7585]
This commit is contained in:
@@ -285,7 +285,6 @@ void bli_dgemv_unf_var1
|
||||
|
||||
// Function pointer declaration for the functions that will be used.
|
||||
dgemv_ker_ft_conja gemv_kr_ptr = NULL; // DGEMV
|
||||
dscalv_ker_ft scalv_kr_ptr = NULL; // DSCALV
|
||||
dcopyv_ker_ft copyv_kr_ptr = NULL; // DCOPYV
|
||||
|
||||
/*
|
||||
@@ -307,7 +306,6 @@ void bli_dgemv_unf_var1
|
||||
case BLIS_ARCH_ZEN5:
|
||||
#if defined(BLIS_KERNELS_ZEN5)
|
||||
gemv_kr_ptr = bli_dgemv_t_zen4_int; // DGEMV
|
||||
scalv_kr_ptr = bli_dscalv_zen4_int; // DSCALV
|
||||
copyv_kr_ptr = bli_dcopyv_zen5_asm; // DCOPYV
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
fast_path_thresh = 12000;
|
||||
@@ -318,7 +316,6 @@ void bli_dgemv_unf_var1
|
||||
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
gemv_kr_ptr = bli_dgemv_t_zen4_int; // DGEMV
|
||||
scalv_kr_ptr = bli_dscalv_zen4_int; // DSCALV
|
||||
copyv_kr_ptr = bli_dcopyv_zen4_asm; // DCOPYV
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
fast_path_thresh = 11000;
|
||||
@@ -331,7 +328,6 @@ void bli_dgemv_unf_var1
|
||||
case BLIS_ARCH_ZEN3:
|
||||
|
||||
gemv_kr_ptr = bli_dgemv_t_zen_int; // DGEMV
|
||||
scalv_kr_ptr = bli_dscalv_zen_int; // DSCALV
|
||||
copyv_kr_ptr = bli_dcopyv_zen_int; // DCOPYV
|
||||
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
|
||||
@@ -388,24 +384,6 @@ void bli_dgemv_unf_var1
|
||||
return;
|
||||
}
|
||||
|
||||
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
|
||||
// y is only scaled by beta and returned.
|
||||
if( bli_deq0( *alpha ) )
|
||||
{
|
||||
// Invoke the SCALV function using the function pointer
|
||||
scalv_kr_ptr
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
n0,
|
||||
beta,
|
||||
y_buf, incy,
|
||||
cntx
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
|
||||
return;
|
||||
}
|
||||
|
||||
// If x has non-unit increments , x is packed and copied to a temp buffer (x_buf)
|
||||
// and passed to the kernels. At the end, the memory is freed.
|
||||
if ( incx != 1 )
|
||||
|
||||
@@ -475,12 +475,6 @@ void bli_dgemv_unf_var2 (
|
||||
);
|
||||
}
|
||||
|
||||
if( bli_deq0( *alpha ) )
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; i < n_iter; i += f)
|
||||
{
|
||||
f = bli_determine_blocksize_dim_f(i, n_iter, b_fuse);
|
||||
|
||||
@@ -287,6 +287,29 @@ void dgemv_blis_impl
|
||||
incy0 = ( inc_t )(*incy);
|
||||
}
|
||||
|
||||
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
|
||||
// y is only scaled by beta and returned.
|
||||
if( bli_deq0( *alpha ) == TRUE )
|
||||
{
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
// Query the context for the SCALV function pointer.
|
||||
dscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SCALV_KER, cntx );
|
||||
|
||||
// Invoke the SCALV function using the function pointer
|
||||
scalv_kr_ptr
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
m_y,
|
||||
(double*)beta,
|
||||
y0, incy0,
|
||||
cntx
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
|
||||
return;
|
||||
}
|
||||
|
||||
/* Set the row and column strides of A. */
|
||||
rs_a = 1;
|
||||
cs_a = *lda;
|
||||
@@ -526,6 +549,29 @@ void sgemv_blis_impl
|
||||
incy0 = ( inc_t )(*incy);
|
||||
}
|
||||
|
||||
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
|
||||
// y is only scaled by beta and returned.
|
||||
if( bli_seq0( *alpha ) == TRUE )
|
||||
{
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
// Query the context for the SCALV function pointer.
|
||||
sscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SCALV_KER, cntx );
|
||||
|
||||
// Invoke the SCALV function using the function pointer
|
||||
scalv_kr_ptr
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
m_y,
|
||||
(float*)beta,
|
||||
y0, incy0,
|
||||
cntx
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
|
||||
return;
|
||||
}
|
||||
|
||||
/* Set the row and column strides of A. */
|
||||
rs_a = 1;
|
||||
cs_a = *lda;
|
||||
@@ -729,6 +775,29 @@ void cgemv_blis_impl
|
||||
incy0 = ( inc_t )(*incy);
|
||||
}
|
||||
|
||||
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
|
||||
// y is only scaled by beta and returned.
|
||||
if( bli_ceq0( *alpha ) == TRUE )
|
||||
{
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
// Query the context for the SCALV function pointer.
|
||||
cscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_SCALV_KER, cntx );
|
||||
|
||||
// Invoke the SCALV function using the function pointer
|
||||
scalv_kr_ptr
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
m_y,
|
||||
(scomplex*)beta,
|
||||
y0, incy0,
|
||||
cntx
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
|
||||
return;
|
||||
}
|
||||
|
||||
/* Set the row and column strides of A. */
|
||||
rs_a = 1;
|
||||
cs_a = *lda;
|
||||
@@ -985,6 +1054,29 @@ void zgemv_blis_impl
|
||||
incy0 = ( inc_t )(*incy);
|
||||
}
|
||||
|
||||
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
|
||||
// y is only scaled by beta and returned.
|
||||
if( bli_zeq0( *alpha ) == TRUE )
|
||||
{
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
// Query the context for the SCALV function pointer.
|
||||
zscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_SCALV_KER, cntx );
|
||||
|
||||
// Invoke the SCALV function using the function pointer
|
||||
scalv_kr_ptr
|
||||
(
|
||||
BLIS_NO_CONJUGATE,
|
||||
m_y,
|
||||
(dcomplex*)beta,
|
||||
y0, incy0,
|
||||
cntx
|
||||
);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
|
||||
return;
|
||||
}
|
||||
|
||||
/* Set the row and column strides of A. */
|
||||
rs_a = 1;
|
||||
cs_a = *lda;
|
||||
|
||||
@@ -193,6 +193,10 @@ void bli_dgemv_t_zen_int
|
||||
rhov[6].v = _mm256_setzero_pd();
|
||||
rhov[7].v = _mm256_setzero_pd();
|
||||
|
||||
// Setting y register to zero before loading
|
||||
yv0.v = _mm256_setzero_pd();
|
||||
yv1.v = _mm256_setzero_pd();
|
||||
|
||||
// Loading data from vector y
|
||||
yv0.v = _mm256_maskload_pd( y_buf, beta_mask ); // yv0 = y_buf[0:3]
|
||||
yv1.v = _mm256_maskload_pd( y_buf + 4, beta_mask ); // yv1 = y_buf[4:7]
|
||||
@@ -539,17 +543,6 @@ void bli_dgemv_t_zen_int
|
||||
rhov[6].v = _mm256_setzero_pd();
|
||||
rhov[7].v = _mm256_setzero_pd();
|
||||
|
||||
// Loading data from vector y
|
||||
yv0.d[0] = *( y_buf + 0 * incy); // yv0[0] = y_buf[0]
|
||||
yv0.d[1] = *( y_buf + 1 * incy); // yv0[1] = y_buf[1]
|
||||
yv0.d[2] = *( y_buf + 2 * incy); // yv0[2] = y_buf[2]
|
||||
yv0.d[3] = *( y_buf + 3 * incy); // yv0[3] = y_buf[3]
|
||||
|
||||
yv1.d[0] = *( y_buf + 4 * incy); // yv1[0] = y_buf[4]
|
||||
yv1.d[1] = *( y_buf + 5 * incy); // yv1[1] = y_buf[5]
|
||||
yv1.d[2] = *( y_buf + 6 * incy); // yv1[2] = y_buf[6]
|
||||
yv1.d[3] = *( y_buf + 7 * incy); // yv1[3] = y_buf[7]
|
||||
|
||||
// Calculating beta * y
|
||||
if (bli_deq0( *beta ))
|
||||
{
|
||||
@@ -558,6 +551,17 @@ void bli_dgemv_t_zen_int
|
||||
}
|
||||
else
|
||||
{
|
||||
// Loading data from vector y
|
||||
yv0.d[0] = *( y_buf + 0 * incy); // yv0[0] = y_buf[0]
|
||||
yv0.d[1] = *( y_buf + 1 * incy); // yv0[1] = y_buf[1]
|
||||
yv0.d[2] = *( y_buf + 2 * incy); // yv0[2] = y_buf[2]
|
||||
yv0.d[3] = *( y_buf + 3 * incy); // yv0[3] = y_buf[3]
|
||||
|
||||
yv1.d[0] = *( y_buf + 4 * incy); // yv1[0] = y_buf[4]
|
||||
yv1.d[1] = *( y_buf + 5 * incy); // yv1[1] = y_buf[5]
|
||||
yv1.d[2] = *( y_buf + 6 * incy); // yv1[2] = y_buf[6]
|
||||
yv1.d[3] = *( y_buf + 7 * incy); // yv1[3] = y_buf[7]
|
||||
|
||||
yv0.v = _mm256_mul_pd ( betav.v, yv0.v ); // yv0 = beta * y_buf[0:3]
|
||||
yv1.v = _mm256_mul_pd ( betav.v, yv1.v ); // yv1 = beta * y_buf[4:7]
|
||||
}
|
||||
|
||||
@@ -256,7 +256,7 @@ static dgemv_ker_ft_conja n_ker_fp[8] =
|
||||
rhov[7].v = _mm512_setzero_pd();
|
||||
|
||||
// Loading the value of y into yv0
|
||||
yv0.v = _mm512_loadu_pd( y_buf ); // yv0 = y_buf[0:7]
|
||||
yv0.v = _mm512_maskz_loadu_pd( beta_mask, y_buf ); // yv0 = y_buf[0:7]
|
||||
yv0.v = _mm512_maskz_mul_pd( beta_mask, betav.v, yv0.v ); // yv0 = beta * y_buf[0:7]
|
||||
|
||||
// Handles (a_buf[0:31, 0:7] * x_buf[0:31])
|
||||
@@ -675,16 +675,23 @@ static dgemv_ker_ft_conja n_ker_fp[8] =
|
||||
|
||||
// In case of non-unit stride y,
|
||||
// The inputs on vector y are manually moved to register yv0
|
||||
yv0.d[0] = *( y_buf + (0 * incy) ); // yv0[0] = y_buf[0]
|
||||
yv0.d[1] = *( y_buf + (1 * incy) ); // yv0[1] = y_buf[1]
|
||||
yv0.d[2] = *( y_buf + (2 * incy) ); // yv0[2] = y_buf[2]
|
||||
yv0.d[3] = *( y_buf + (3 * incy) ); // yv0[3] = y_buf[3]
|
||||
yv0.d[4] = *( y_buf + (4 * incy) ); // yv0[4] = y_buf[4]
|
||||
yv0.d[5] = *( y_buf + (5 * incy) ); // yv0[5] = y_buf[5]
|
||||
yv0.d[6] = *( y_buf + (6 * incy) ); // yv0[6] = y_buf[6]
|
||||
yv0.d[7] = *( y_buf + (7 * incy) ); // yv0[7] = y_buf[7]
|
||||
if ( bli_deq0( *beta ) )
|
||||
{
|
||||
yv0.v = _mm512_setzero_pd();
|
||||
}
|
||||
else
|
||||
{
|
||||
yv0.d[0] = *( y_buf + (0 * incy) ); // yv0[0] = y_buf[0]
|
||||
yv0.d[1] = *( y_buf + (1 * incy) ); // yv0[1] = y_buf[1]
|
||||
yv0.d[2] = *( y_buf + (2 * incy) ); // yv0[2] = y_buf[2]
|
||||
yv0.d[3] = *( y_buf + (3 * incy) ); // yv0[3] = y_buf[3]
|
||||
yv0.d[4] = *( y_buf + (4 * incy) ); // yv0[4] = y_buf[4]
|
||||
yv0.d[5] = *( y_buf + (5 * incy) ); // yv0[5] = y_buf[5]
|
||||
yv0.d[6] = *( y_buf + (6 * incy) ); // yv0[6] = y_buf[6]
|
||||
yv0.d[7] = *( y_buf + (7 * incy) ); // yv0[7] = y_buf[7]
|
||||
|
||||
yv0.v = _mm512_maskz_mul_pd( beta_mask, betav.v, yv0.v ); // yv0 = beta * y_buf[0:7]
|
||||
yv0.v = _mm512_maskz_mul_pd( beta_mask, betav.v, yv0.v ); // yv0 = beta * y_buf[0:7]
|
||||
}
|
||||
|
||||
// Handles (a_buf[0:31, 0:7] * x_buf[0:31])
|
||||
for ( j = 0; (j + 31) < m; j += 32 )
|
||||
|
||||
Reference in New Issue
Block a user