Fix vectorized version of bli_amaxv (#382)

* Fix vectorized version of bli_amaxv

To match Netlib, i?amax should return:
- the lowest index among equal values
- the first NaN if one is encountered

* Fix typos.

* And another one...

* Update ref. amaxv kernel too.

* Re-enabled optimized amaxv kernels.

Details:
- Re-enabled the optimized, intrinsics-based amaxv kernels in the 'zen'
  kernel set for use in haswell, zen, zen2, knl, and skx subconfigs.
  These two kernels (for s and d datatypes) were temporarily disabled in
  e186d71 as part of issue #380. However, the key missing semantic
  properties that prompted the disabling of these kernels--returning the
  index of the *first* rather than of the last element with largest
  absolute value, and returning the index of the first NaN if one is
  encountered--were added as part of #382 thanks to Devin Matthews.
  Thus, now that the kernels are working as expected once more, this
  commit causes these kernels to once again be registered for the
  affected subconfigs, which effectively reverts all code changes
  included in e186d71.
- Whitespace/formatting updates to new macros in bli_amaxv_zen_int.c.

Co-authored-by: Field G. Van Zee <field@cs.utexas.edu>
This commit is contained in:
Devin Matthews
2020-03-24 17:28:47 -05:00
committed by GitHub
parent e186d7141a
commit 492a736fab
8 changed files with 95 additions and 75 deletions

View File

@@ -89,10 +89,8 @@ void bli_cntx_init_haswell( cntx_t* cntx )
// Update the context with optimized level-1v kernels.
bli_cntx_set_l1v_kers
(
8,
#if 0
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
// See issue #380 for more details.
10,
#if 1
// amaxv
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,

View File

@@ -78,10 +78,8 @@ void bli_cntx_init_knl( cntx_t* cntx )
// Update the context with optimized level-1v kernels.
bli_cntx_set_l1v_kers
(
8,
#if 0
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
// See issue #380 for more details.
10,
#if 1
// amaxv
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,

View File

@@ -149,10 +149,8 @@ void bli_cntx_init_haswell( cntx_t* cntx )
// Update the context with optimized level-1v kernels.
bli_cntx_set_l1v_kers
(
8,
#if 0
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
// See issue #380 for more details.
10,
#if 1
// amaxv
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,

View File

@@ -70,10 +70,8 @@ void bli_cntx_init_skx( cntx_t* cntx )
// Update the context with optimized level-1v kernels.
bli_cntx_set_l1v_kers
(
8,
#if 0
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
// See issue #380 for more details.
10,
#if 1
// amaxv
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,

View File

@@ -82,10 +82,8 @@ void bli_cntx_init_zen( cntx_t* cntx )
// Update the context with optimized level-1v kernels.
bli_cntx_set_l1v_kers
(
8,
#if 0
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
// See issue #380 for more details.
10,
#if 1
// amaxv
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,

View File

@@ -79,10 +79,8 @@ void bli_cntx_init_zen2( cntx_t* cntx )
// Update the context with optimized level-1v kernels.
bli_cntx_set_l1v_kers
(
8,
#if 0
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
// See issue #380 for more details.
10,
#if 1
// amaxv
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,

View File

@@ -65,6 +65,38 @@ typedef union
double d[2];
}v2dd_t;
// return a mask which indicates either:
// - v1 > v2
// - v1 is NaN and v2 is not
// assumes that idx(v1) > idx(v2)
// all "OQ" comparisons false if either operand NaN
#define CMP256( dt, v1, v2 ) \
_mm256_or_p##dt( _mm256_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* v1 > v2 || */ \
_mm256_andnot_p##dt( _mm256_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \
_mm256_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) */ \
) \
);
// return a mask which indicates either:
// - v1 > v2
// - v1 is NaN and v2 is not
// - v1 == v2 (maybe == NaN) and i1 < i2
// all "OQ" comparisons false if either operand NaN
#define CMP128( dt, v1, v2, i1, i2 ) \
_mm_or_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* ( v1 > v2 || */ \
_mm_andnot_p##dt( _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \
_mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) ) || */ \
) \
), \
_mm_and_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_EQ_OQ ), /* ( ( v1 == v2 || */ \
_mm_and_p##dt( _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ), /* ( isnan(v1) && */ \
_mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ) /* isnan(v2) ) ) && */ \
) \
), \
_mm_cmp_p##dt( i1, i2, _CMP_LT_OQ ) /* i1 < i2 ) */ \
) \
);
// -----------------------------------------------------------------------------
void bli_samaxv_zen_int
@@ -122,8 +154,8 @@ void bli_samaxv_zen_int
the previous largest, save it and its index. If NaN is
encountered, then treat it the same as if it were a valid
value that was smaller than any previously seen. This
behavior mimics that of LAPACK's ?lange(). */
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
behavior mimics that of LAPACK's i?amax(). */
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
{
abs_chi1_max = abs_chi1;
i_max_l = i;
@@ -157,7 +189,7 @@ void bli_samaxv_zen_int
// Get the absolute value of the vector element.
x_vec.v = _mm256_andnot_ps( sign_mask.v, x_vec.v );
mask_vec.v = _mm256_cmp_ps( x_vec.v, max_vec.v, _CMP_GT_OS );
mask_vec.v = CMP256( s, x_vec.v, max_vec.v );
max_vec.v = _mm256_blendv_ps( max_vec.v, x_vec.v, mask_vec.v );
maxInx_vec.v = _mm256_blendv_ps( maxInx_vec.v, idx_vec.v, mask_vec.v );
@@ -166,33 +198,34 @@ void bli_samaxv_zen_int
x += num_vec_elements;
}
max_vec_lo.v = _mm256_extractf128_ps( max_vec.v, 0 );
max_vec_hi.v = _mm256_extractf128_ps( max_vec.v, 1 );
mask_vec_lo.v = _mm_cmp_ps( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS );
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
max_vec_lo.v = _mm256_extractf128_ps( max_vec.v, 0 );
max_vec_hi.v = _mm256_extractf128_ps( max_vec.v, 1 );
maxInx_vec_lo.v = _mm256_extractf128_ps( maxInx_vec.v, 0 );
maxInx_vec_hi.v = _mm256_extractf128_ps( maxInx_vec.v, 1 );
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 14 );
maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 14 );
mask_vec_lo.v = _mm_cmp_ps( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS );
mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
if ( max_vec_lo.f[0] > max_vec_lo.f[1] )
{
abs_chi1_max = max_vec_lo.f[0];
i_max_l = maxInx_vec_lo.f[0];
}
else
{
abs_chi1_max = max_vec_lo.f[1];
i_max_l = maxInx_vec_lo.f[1];
}
max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 14 );
maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 14 );
mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 1 );
maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 1 );
mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
abs_chi1_max = max_vec_lo.f[0];
i_max_l = maxInx_vec_lo.f[0];
for ( i = n - n_left; i < n; i++ )
{
@@ -208,8 +241,8 @@ void bli_samaxv_zen_int
the previous largest, save it and its index. If NaN is
encountered, then treat it the same as if it were a valid
value that was smaller than any previously seen. This
behavior mimics that of LAPACK's ?lange(). */
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
behavior mimics that of LAPACK's i?amax(). */
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
{
abs_chi1_max = abs_chi1;
i_max_l = i;
@@ -286,8 +319,8 @@ void bli_damaxv_zen_int
the previous largest, save it and its index. If NaN is
encountered, then treat it the same as if it were a valid
value that was smaller than any previously seen. This
behavior mimics that of LAPACK's ?lange(). */
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
behavior mimics that of LAPACK's i?amax(). */
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
{
abs_chi1_max = abs_chi1;
i_max_l = i;
@@ -321,7 +354,7 @@ void bli_damaxv_zen_int
// Get the absolute value of the vector element.
x_vec.v = _mm256_andnot_pd( sign_mask.v, x_vec.v );
mask_vec.v = _mm256_cmp_pd( x_vec.v, max_vec.v, _CMP_GT_OS );
mask_vec.v = CMP256( d, x_vec.v, max_vec.v );
max_vec.v = _mm256_blendv_pd( max_vec.v, x_vec.v, mask_vec.v );
maxInx_vec.v = _mm256_blendv_pd( maxInx_vec.v, idx_vec.v, mask_vec.v );
@@ -330,26 +363,26 @@ void bli_damaxv_zen_int
x += num_vec_elements;
}
max_vec_lo.v = _mm256_extractf128_pd( max_vec.v, 0 );
max_vec_hi.v = _mm256_extractf128_pd( max_vec.v, 1 );
mask_vec_lo.v = _mm_cmp_pd( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS );
max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
max_vec_lo.v = _mm256_extractf128_pd( max_vec.v, 0 );
max_vec_hi.v = _mm256_extractf128_pd( max_vec.v, 1 );
maxInx_vec_lo.v = _mm256_extractf128_pd( maxInx_vec.v, 0 );
maxInx_vec_hi.v = _mm256_extractf128_pd( maxInx_vec.v, 1 );
mask_vec_lo.v = CMP128( d, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
maxInx_vec_lo.v = _mm_blendv_pd( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
max_vec_hi.v = _mm_permute_pd( max_vec_lo.v, 1 );
maxInx_vec_hi.v = _mm_permute_pd( maxInx_vec_lo.v, 1 );
mask_vec_lo.v = CMP128( d, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
maxInx_vec_lo.v = _mm_blendv_pd( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
if ( max_vec_lo.d[0] > max_vec_lo.d[1] )
{
abs_chi1_max = max_vec_lo.d[0];
i_max_l = maxInx_vec_lo.d[0];
}
else
{
abs_chi1_max = max_vec_lo.d[1];
i_max_l = maxInx_vec_lo.d[1];
}
abs_chi1_max = max_vec_lo.d[0];
i_max_l = maxInx_vec_lo.d[0];
for ( i = n - n_left; i < n; i++ )
{
@@ -363,10 +396,9 @@ void bli_damaxv_zen_int
/* If the absolute value of the current element exceeds that of
the previous largest, save it and its index. If NaN is
encountered, then treat it the same as if it were a valid
value that was smaller than any previously seen. This
behavior mimics that of LAPACK's ?lange(). */
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
encountered, return the index of the first NaN. This
behavior mimics that of LAPACK's i?amax(). */
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
{
abs_chi1_max = abs_chi1;
i_max_l = i;

View File

@@ -97,7 +97,7 @@ void PASTEMAC3(ch,opname,arch,suf) \
encountered, then treat it the same as if it were a valid
value that was smaller than any previously seen. This
behavior mimics that of LAPACK's ?lange(). */ \
if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \
if ( abs_chi1_max < abs_chi1 || ( bli_isnan( abs_chi1 ) && !bli_isnan( abs_chi1_max ) ) ) \
{ \
abs_chi1_max = abs_chi1; \
i_max_l = i; \
@@ -129,7 +129,7 @@ void PASTEMAC3(ch,opname,arch,suf) \
encountered, then treat it the same as if it were a valid
value that was smaller than any previously seen. This
behavior mimics that of LAPACK's ?lange(). */ \
if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \
if ( abs_chi1_max < abs_chi1 || ( bli_isnan( abs_chi1 ) && !bli_isnan( abs_chi1_max ) ) ) \
{ \
abs_chi1_max = abs_chi1; \
i_max_l = i; \