Fix extreme values handling in GEMV

- When alpha == 0, we are expected to only scale y vector with beta and not read A or X at all.
- This scenario is not handled properly in all code paths, which causes NaN and Inf values from A and X to be wrongly propagated. For example, for non-Zen CPUs (the default block in the switch case) no such check is present; similarly, some of the AVX512 kernels are also missing these checks.
- When beta == 0, we are not expected to read Y at all; this also is not handled correctly in one of the AVX512 kernels.
- To fix these, an early-return condition for alpha == 0 is added to the BLAS layer itself so that each kernel does not have to implement the logic.
- The DGEMV AVX512 transpose kernel has been fixed to load vector Y only when beta != 0.

AMD-Internal: [CPUPL-7585]
This commit is contained in:
S, Hari Govind
2025-11-07 19:55:26 +05:30
committed by GitHub
parent b729473839
commit 7b4e665273
5 changed files with 124 additions and 49 deletions

View File

@@ -285,7 +285,6 @@ void bli_dgemv_unf_var1
// Function pointer declaration for the functions that will be used.
dgemv_ker_ft_conja gemv_kr_ptr = NULL; // DGEMV
dscalv_ker_ft scalv_kr_ptr = NULL; // DSCALV
dcopyv_ker_ft copyv_kr_ptr = NULL; // DCOPYV
/*
@@ -307,7 +306,6 @@ void bli_dgemv_unf_var1
case BLIS_ARCH_ZEN5:
#if defined(BLIS_KERNELS_ZEN5)
gemv_kr_ptr = bli_dgemv_t_zen4_int; // DGEMV
scalv_kr_ptr = bli_dscalv_zen4_int; // DSCALV
copyv_kr_ptr = bli_dcopyv_zen5_asm; // DCOPYV
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
fast_path_thresh = 12000;
@@ -318,7 +316,6 @@ void bli_dgemv_unf_var1
#if defined(BLIS_KERNELS_ZEN4)
gemv_kr_ptr = bli_dgemv_t_zen4_int; // DGEMV
scalv_kr_ptr = bli_dscalv_zen4_int; // DSCALV
copyv_kr_ptr = bli_dcopyv_zen4_asm; // DCOPYV
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
fast_path_thresh = 11000;
@@ -331,7 +328,6 @@ void bli_dgemv_unf_var1
case BLIS_ARCH_ZEN3:
gemv_kr_ptr = bli_dgemv_t_zen_int; // DGEMV
scalv_kr_ptr = bli_dscalv_zen_int; // DSCALV
copyv_kr_ptr = bli_dcopyv_zen_int; // DCOPYV
#if defined(BLIS_ENABLE_OPENMP) && defined(AOCL_DYNAMIC)
@@ -388,24 +384,6 @@ void bli_dgemv_unf_var1
return;
}
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
// y is only scaled by beta and returned.
if( bli_deq0( *alpha ) )
{
// Invoke the SCALV function using the function pointer
scalv_kr_ptr
(
BLIS_NO_CONJUGATE,
n0,
beta,
y_buf, incy,
cntx
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
return;
}
// If x has non-unit increments , x is packed and copied to a temp buffer (x_buf)
// and passed to the kernels. At the end, the memory is freed.
if ( incx != 1 )

View File

@@ -475,12 +475,6 @@ void bli_dgemv_unf_var2 (
);
}
if( bli_deq0( *alpha ) )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
return;
}
for (i = 0; i < n_iter; i += f)
{
f = bli_determine_blocksize_dim_f(i, n_iter, b_fuse);

View File

@@ -287,6 +287,29 @@ void dgemv_blis_impl
incy0 = ( inc_t )(*incy);
}
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
// y is only scaled by beta and returned.
if( bli_deq0( *alpha ) == TRUE )
{
cntx_t* cntx = bli_gks_query_cntx();
// Query the context for the SCALV function pointer.
dscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SCALV_KER, cntx );
// Invoke the SCALV function using the function pointer
scalv_kr_ptr
(
BLIS_NO_CONJUGATE,
m_y,
(double*)beta,
y0, incy0,
cntx
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return;
}
/* Set the row and column strides of A. */
rs_a = 1;
cs_a = *lda;
@@ -526,6 +549,29 @@ void sgemv_blis_impl
incy0 = ( inc_t )(*incy);
}
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
// y is only scaled by beta and returned.
if( bli_seq0( *alpha ) == TRUE )
{
cntx_t* cntx = bli_gks_query_cntx();
// Query the context for the SCALV function pointer.
sscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SCALV_KER, cntx );
// Invoke the SCALV function using the function pointer
scalv_kr_ptr
(
BLIS_NO_CONJUGATE,
m_y,
(float*)beta,
y0, incy0,
cntx
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return;
}
/* Set the row and column strides of A. */
rs_a = 1;
cs_a = *lda;
@@ -729,6 +775,29 @@ void cgemv_blis_impl
incy0 = ( inc_t )(*incy);
}
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
// y is only scaled by beta and returned.
if( bli_ceq0( *alpha ) == TRUE )
{
cntx_t* cntx = bli_gks_query_cntx();
// Query the context for the SCALV function pointer.
cscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_SCALV_KER, cntx );
// Invoke the SCALV function using the function pointer
scalv_kr_ptr
(
BLIS_NO_CONJUGATE,
m_y,
(scomplex*)beta,
y0, incy0,
cntx
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return;
}
/* Set the row and column strides of A. */
rs_a = 1;
cs_a = *lda;
@@ -985,6 +1054,29 @@ void zgemv_blis_impl
incy0 = ( inc_t )(*incy);
}
// If alpha is zero, the GEMV operation is reduced to y := beta * y, thus,
// y is only scaled by beta and returned.
if( bli_zeq0( *alpha ) == TRUE )
{
cntx_t* cntx = bli_gks_query_cntx();
// Query the context for the SCALV function pointer.
zscalv_ker_ft scalv_kr_ptr = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_SCALV_KER, cntx );
// Invoke the SCALV function using the function pointer
scalv_kr_ptr
(
BLIS_NO_CONJUGATE,
m_y,
(dcomplex*)beta,
y0, incy0,
cntx
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return;
}
/* Set the row and column strides of A. */
rs_a = 1;
cs_a = *lda;

View File

@@ -193,6 +193,10 @@ void bli_dgemv_t_zen_int
rhov[6].v = _mm256_setzero_pd();
rhov[7].v = _mm256_setzero_pd();
// Setting y register to zero before loading
yv0.v = _mm256_setzero_pd();
yv1.v = _mm256_setzero_pd();
// Loading data from vector y
yv0.v = _mm256_maskload_pd( y_buf, beta_mask ); // yv0 = y_buf[0:3]
yv1.v = _mm256_maskload_pd( y_buf + 4, beta_mask ); // yv1 = y_buf[4:7]
@@ -539,17 +543,6 @@ void bli_dgemv_t_zen_int
rhov[6].v = _mm256_setzero_pd();
rhov[7].v = _mm256_setzero_pd();
// Loading data from vector y
yv0.d[0] = *( y_buf + 0 * incy); // yv0[0] = y_buf[0]
yv0.d[1] = *( y_buf + 1 * incy); // yv0[1] = y_buf[1]
yv0.d[2] = *( y_buf + 2 * incy); // yv0[2] = y_buf[2]
yv0.d[3] = *( y_buf + 3 * incy); // yv0[3] = y_buf[3]
yv1.d[0] = *( y_buf + 4 * incy); // yv1[0] = y_buf[4]
yv1.d[1] = *( y_buf + 5 * incy); // yv1[1] = y_buf[5]
yv1.d[2] = *( y_buf + 6 * incy); // yv1[2] = y_buf[6]
yv1.d[3] = *( y_buf + 7 * incy); // yv1[3] = y_buf[7]
// Calculating beta * y
if (bli_deq0( *beta ))
{
@@ -558,6 +551,17 @@ void bli_dgemv_t_zen_int
}
else
{
// Loading data from vector y
yv0.d[0] = *( y_buf + 0 * incy); // yv0[0] = y_buf[0]
yv0.d[1] = *( y_buf + 1 * incy); // yv0[1] = y_buf[1]
yv0.d[2] = *( y_buf + 2 * incy); // yv0[2] = y_buf[2]
yv0.d[3] = *( y_buf + 3 * incy); // yv0[3] = y_buf[3]
yv1.d[0] = *( y_buf + 4 * incy); // yv1[0] = y_buf[4]
yv1.d[1] = *( y_buf + 5 * incy); // yv1[1] = y_buf[5]
yv1.d[2] = *( y_buf + 6 * incy); // yv1[2] = y_buf[6]
yv1.d[3] = *( y_buf + 7 * incy); // yv1[3] = y_buf[7]
yv0.v = _mm256_mul_pd ( betav.v, yv0.v ); // yv0 = beta * y_buf[0:3]
yv1.v = _mm256_mul_pd ( betav.v, yv1.v ); // yv1 = beta * y_buf[4:7]
}

View File

@@ -256,7 +256,7 @@ static dgemv_ker_ft_conja n_ker_fp[8] =
rhov[7].v = _mm512_setzero_pd();
// Loading the value of y into yv0
yv0.v = _mm512_loadu_pd( y_buf ); // yv0 = y_buf[0:7]
yv0.v = _mm512_maskz_loadu_pd( beta_mask, y_buf ); // yv0 = y_buf[0:7]
yv0.v = _mm512_maskz_mul_pd( beta_mask, betav.v, yv0.v ); // yv0 = beta * y_buf[0:7]
// Handles (a_buf[0:31, 0:7] * x_buf[0:31])
@@ -675,16 +675,23 @@ static dgemv_ker_ft_conja n_ker_fp[8] =
// In case of non-unit stride y,
// The inputs on vector y are manually moved to register yv0
yv0.d[0] = *( y_buf + (0 * incy) ); // yv0[0] = y_buf[0]
yv0.d[1] = *( y_buf + (1 * incy) ); // yv0[1] = y_buf[1]
yv0.d[2] = *( y_buf + (2 * incy) ); // yv0[2] = y_buf[2]
yv0.d[3] = *( y_buf + (3 * incy) ); // yv0[3] = y_buf[3]
yv0.d[4] = *( y_buf + (4 * incy) ); // yv0[4] = y_buf[4]
yv0.d[5] = *( y_buf + (5 * incy) ); // yv0[5] = y_buf[5]
yv0.d[6] = *( y_buf + (6 * incy) ); // yv0[6] = y_buf[6]
yv0.d[7] = *( y_buf + (7 * incy) ); // yv0[7] = y_buf[7]
if ( bli_deq0( *beta ) )
{
yv0.v = _mm512_setzero_pd();
}
else
{
yv0.d[0] = *( y_buf + (0 * incy) ); // yv0[0] = y_buf[0]
yv0.d[1] = *( y_buf + (1 * incy) ); // yv0[1] = y_buf[1]
yv0.d[2] = *( y_buf + (2 * incy) ); // yv0[2] = y_buf[2]
yv0.d[3] = *( y_buf + (3 * incy) ); // yv0[3] = y_buf[3]
yv0.d[4] = *( y_buf + (4 * incy) ); // yv0[4] = y_buf[4]
yv0.d[5] = *( y_buf + (5 * incy) ); // yv0[5] = y_buf[5]
yv0.d[6] = *( y_buf + (6 * incy) ); // yv0[6] = y_buf[6]
yv0.d[7] = *( y_buf + (7 * incy) ); // yv0[7] = y_buf[7]
yv0.v = _mm512_maskz_mul_pd( beta_mask, betav.v, yv0.v ); // yv0 = beta * y_buf[0:7]
yv0.v = _mm512_maskz_mul_pd( beta_mask, betav.v, yv0.v ); // yv0 = beta * y_buf[0:7]
}
// Handles (a_buf[0:31, 0:7] * x_buf[0:31])
for ( j = 0; (j + 31) < m; j += 32 )