mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Fix vectorized version of bli_amaxv (#382)
* Fix vectorized version of bli_amaxv To match Netlib, i?amax should return: - the lowest index among equal values - the first NaN if one is encountered * Fix typos. * And another one... * Update ref. amaxv kernel too. * Re-enabled optimized amaxv kernels. Details: - Re-enabled the optimized, intrinsics-based amaxv kernels in the 'zen' kernel set for use in haswell, zen, zen2, knl, and skx subconfigs. These two kernels (for s and d datatypes) were temporarily disabled ine186d71as part of issue #380. However, the key missing semantic properties that prompted the disabling of these kernels--returning the index of the *first* rather than of the last element with largest absolute value, and returning the index of the first NaN if one is encountered--were added as part of #382 thanks to Devin Matthews. Thus, now that the kernels are working as expected once more, this commit causes these kernels to once again be registered for the affected subconfigs, which effectively reverts all code changes included ine186d71. - Whitespace/formatting updates to new macros in bli_amaxv_zen_int.c. Co-authored-by: Field G. Van Zee <field@cs.utexas.edu>
This commit is contained in:
@@ -89,10 +89,8 @@ void bli_cntx_init_haswell( cntx_t* cntx )
|
||||
// Update the context with optimized level-1v kernels.
|
||||
bli_cntx_set_l1v_kers
|
||||
(
|
||||
8,
|
||||
#if 0
|
||||
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
|
||||
// See issue #380 for more details.
|
||||
10,
|
||||
#if 1
|
||||
// amaxv
|
||||
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
|
||||
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,
|
||||
|
||||
@@ -78,10 +78,8 @@ void bli_cntx_init_knl( cntx_t* cntx )
|
||||
// Update the context with optimized level-1v kernels.
|
||||
bli_cntx_set_l1v_kers
|
||||
(
|
||||
8,
|
||||
#if 0
|
||||
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
|
||||
// See issue #380 for more details.
|
||||
10,
|
||||
#if 1
|
||||
// amaxv
|
||||
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
|
||||
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,
|
||||
|
||||
@@ -149,10 +149,8 @@ void bli_cntx_init_haswell( cntx_t* cntx )
|
||||
// Update the context with optimized level-1v kernels.
|
||||
bli_cntx_set_l1v_kers
|
||||
(
|
||||
8,
|
||||
#if 0
|
||||
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
|
||||
// See issue #380 for more details.
|
||||
10,
|
||||
#if 1
|
||||
// amaxv
|
||||
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
|
||||
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,
|
||||
|
||||
@@ -70,10 +70,8 @@ void bli_cntx_init_skx( cntx_t* cntx )
|
||||
// Update the context with optimized level-1v kernels.
|
||||
bli_cntx_set_l1v_kers
|
||||
(
|
||||
8,
|
||||
#if 0
|
||||
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
|
||||
// See issue #380 for more details.
|
||||
10,
|
||||
#if 1
|
||||
// amaxv
|
||||
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
|
||||
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,
|
||||
|
||||
@@ -82,10 +82,8 @@ void bli_cntx_init_zen( cntx_t* cntx )
|
||||
// Update the context with optimized level-1v kernels.
|
||||
bli_cntx_set_l1v_kers
|
||||
(
|
||||
8,
|
||||
#if 0
|
||||
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
|
||||
// See issue #380 for more details.
|
||||
10,
|
||||
#if 1
|
||||
// amaxv
|
||||
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
|
||||
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,
|
||||
|
||||
@@ -79,10 +79,8 @@ void bli_cntx_init_zen2( cntx_t* cntx )
|
||||
// Update the context with optimized level-1v kernels.
|
||||
bli_cntx_set_l1v_kers
|
||||
(
|
||||
8,
|
||||
#if 0
|
||||
// NOTE: Disabled vectorized amaxv kernels due to incorrect semantics.
|
||||
// See issue #380 for more details.
|
||||
10,
|
||||
#if 1
|
||||
// amaxv
|
||||
BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int,
|
||||
BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int,
|
||||
|
||||
@@ -65,6 +65,38 @@ typedef union
|
||||
double d[2];
|
||||
}v2dd_t;
|
||||
|
||||
// return a mask which indicates either:
|
||||
// - v1 > v2
|
||||
// - v1 is NaN and v2 is not
|
||||
// assumes that idx(v1) > idx(v2)
|
||||
// all "OQ" comparisons false if either operand NaN
|
||||
#define CMP256( dt, v1, v2 ) \
|
||||
_mm256_or_p##dt( _mm256_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* v1 > v2 || */ \
|
||||
_mm256_andnot_p##dt( _mm256_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \
|
||||
_mm256_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) */ \
|
||||
) \
|
||||
);
|
||||
|
||||
// return a mask which indicates either:
|
||||
// - v1 > v2
|
||||
// - v1 is NaN and v2 is not
|
||||
// - v1 == v2 (maybe == NaN) and i1 < i2
|
||||
// all "OQ" comparisons false if either operand NaN
|
||||
#define CMP128( dt, v1, v2, i1, i2 ) \
|
||||
_mm_or_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* ( v1 > v2 || */ \
|
||||
_mm_andnot_p##dt( _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \
|
||||
_mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) ) || */ \
|
||||
) \
|
||||
), \
|
||||
_mm_and_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_EQ_OQ ), /* ( ( v1 == v2 || */ \
|
||||
_mm_and_p##dt( _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ), /* ( isnan(v1) && */ \
|
||||
_mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ) /* isnan(v2) ) ) && */ \
|
||||
) \
|
||||
), \
|
||||
_mm_cmp_p##dt( i1, i2, _CMP_LT_OQ ) /* i1 < i2 ) */ \
|
||||
) \
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
void bli_samaxv_zen_int
|
||||
@@ -122,8 +154,8 @@ void bli_samaxv_zen_int
|
||||
the previous largest, save it and its index. If NaN is
|
||||
encountered, then treat it the same as if it were a valid
|
||||
value that was smaller than any previously seen. This
|
||||
behavior mimics that of LAPACK's ?lange(). */
|
||||
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
|
||||
behavior mimics that of LAPACK's i?amax(). */
|
||||
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
|
||||
{
|
||||
abs_chi1_max = abs_chi1;
|
||||
i_max_l = i;
|
||||
@@ -157,7 +189,7 @@ void bli_samaxv_zen_int
|
||||
// Get the absolute value of the vector element.
|
||||
x_vec.v = _mm256_andnot_ps( sign_mask.v, x_vec.v );
|
||||
|
||||
mask_vec.v = _mm256_cmp_ps( x_vec.v, max_vec.v, _CMP_GT_OS );
|
||||
mask_vec.v = CMP256( s, x_vec.v, max_vec.v );
|
||||
|
||||
max_vec.v = _mm256_blendv_ps( max_vec.v, x_vec.v, mask_vec.v );
|
||||
maxInx_vec.v = _mm256_blendv_ps( maxInx_vec.v, idx_vec.v, mask_vec.v );
|
||||
@@ -166,33 +198,34 @@ void bli_samaxv_zen_int
|
||||
x += num_vec_elements;
|
||||
}
|
||||
|
||||
max_vec_lo.v = _mm256_extractf128_ps( max_vec.v, 0 );
|
||||
max_vec_hi.v = _mm256_extractf128_ps( max_vec.v, 1 );
|
||||
mask_vec_lo.v = _mm_cmp_ps( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm256_extractf128_ps( max_vec.v, 0 );
|
||||
max_vec_hi.v = _mm256_extractf128_ps( max_vec.v, 1 );
|
||||
maxInx_vec_lo.v = _mm256_extractf128_ps( maxInx_vec.v, 0 );
|
||||
maxInx_vec_hi.v = _mm256_extractf128_ps( maxInx_vec.v, 1 );
|
||||
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 14 );
|
||||
maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 14 );
|
||||
mask_vec_lo.v = _mm_cmp_ps( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS );
|
||||
|
||||
mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
if ( max_vec_lo.f[0] > max_vec_lo.f[1] )
|
||||
{
|
||||
abs_chi1_max = max_vec_lo.f[0];
|
||||
i_max_l = maxInx_vec_lo.f[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
abs_chi1_max = max_vec_lo.f[1];
|
||||
i_max_l = maxInx_vec_lo.f[1];
|
||||
}
|
||||
max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 14 );
|
||||
maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 14 );
|
||||
|
||||
mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
max_vec_hi.v = _mm_permute_ps( max_vec_lo.v, 1 );
|
||||
maxInx_vec_hi.v = _mm_permute_ps( maxInx_vec_lo.v, 1 );
|
||||
|
||||
mask_vec_lo.v = CMP128( s, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_ps( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
maxInx_vec_lo.v = _mm_blendv_ps( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
abs_chi1_max = max_vec_lo.f[0];
|
||||
i_max_l = maxInx_vec_lo.f[0];
|
||||
|
||||
for ( i = n - n_left; i < n; i++ )
|
||||
{
|
||||
@@ -208,8 +241,8 @@ void bli_samaxv_zen_int
|
||||
the previous largest, save it and its index. If NaN is
|
||||
encountered, then treat it the same as if it were a valid
|
||||
value that was smaller than any previously seen. This
|
||||
behavior mimics that of LAPACK's ?lange(). */
|
||||
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
|
||||
behavior mimics that of LAPACK's i?amax(). */
|
||||
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
|
||||
{
|
||||
abs_chi1_max = abs_chi1;
|
||||
i_max_l = i;
|
||||
@@ -286,8 +319,8 @@ void bli_damaxv_zen_int
|
||||
the previous largest, save it and its index. If NaN is
|
||||
encountered, then treat it the same as if it were a valid
|
||||
value that was smaller than any previously seen. This
|
||||
behavior mimics that of LAPACK's ?lange(). */
|
||||
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
|
||||
behavior mimics that of LAPACK's i?amax(). */
|
||||
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
|
||||
{
|
||||
abs_chi1_max = abs_chi1;
|
||||
i_max_l = i;
|
||||
@@ -321,7 +354,7 @@ void bli_damaxv_zen_int
|
||||
// Get the absolute value of the vector element.
|
||||
x_vec.v = _mm256_andnot_pd( sign_mask.v, x_vec.v );
|
||||
|
||||
mask_vec.v = _mm256_cmp_pd( x_vec.v, max_vec.v, _CMP_GT_OS );
|
||||
mask_vec.v = CMP256( d, x_vec.v, max_vec.v );
|
||||
|
||||
max_vec.v = _mm256_blendv_pd( max_vec.v, x_vec.v, mask_vec.v );
|
||||
maxInx_vec.v = _mm256_blendv_pd( maxInx_vec.v, idx_vec.v, mask_vec.v );
|
||||
@@ -330,26 +363,26 @@ void bli_damaxv_zen_int
|
||||
x += num_vec_elements;
|
||||
}
|
||||
|
||||
max_vec_lo.v = _mm256_extractf128_pd( max_vec.v, 0 );
|
||||
max_vec_hi.v = _mm256_extractf128_pd( max_vec.v, 1 );
|
||||
mask_vec_lo.v = _mm_cmp_pd( max_vec_hi.v, max_vec_lo.v, _CMP_GT_OS );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm256_extractf128_pd( max_vec.v, 0 );
|
||||
max_vec_hi.v = _mm256_extractf128_pd( max_vec.v, 1 );
|
||||
maxInx_vec_lo.v = _mm256_extractf128_pd( maxInx_vec.v, 0 );
|
||||
maxInx_vec_hi.v = _mm256_extractf128_pd( maxInx_vec.v, 1 );
|
||||
|
||||
mask_vec_lo.v = CMP128( d, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
maxInx_vec_lo.v = _mm_blendv_pd( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
max_vec_hi.v = _mm_permute_pd( max_vec_lo.v, 1 );
|
||||
maxInx_vec_hi.v = _mm_permute_pd( maxInx_vec_lo.v, 1 );
|
||||
|
||||
mask_vec_lo.v = CMP128( d, max_vec_hi.v, max_vec_lo.v, maxInx_vec_hi.v, maxInx_vec_lo.v );
|
||||
|
||||
max_vec_lo.v = _mm_blendv_pd( max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v );
|
||||
maxInx_vec_lo.v = _mm_blendv_pd( maxInx_vec_lo.v, maxInx_vec_hi.v, mask_vec_lo.v );
|
||||
|
||||
if ( max_vec_lo.d[0] > max_vec_lo.d[1] )
|
||||
{
|
||||
abs_chi1_max = max_vec_lo.d[0];
|
||||
i_max_l = maxInx_vec_lo.d[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
abs_chi1_max = max_vec_lo.d[1];
|
||||
i_max_l = maxInx_vec_lo.d[1];
|
||||
}
|
||||
abs_chi1_max = max_vec_lo.d[0];
|
||||
i_max_l = maxInx_vec_lo.d[0];
|
||||
|
||||
for ( i = n - n_left; i < n; i++ )
|
||||
{
|
||||
@@ -363,10 +396,9 @@ void bli_damaxv_zen_int
|
||||
|
||||
/* If the absolute value of the current element exceeds that of
|
||||
the previous largest, save it and its index. If NaN is
|
||||
encountered, then treat it the same as if it were a valid
|
||||
value that was smaller than any previously seen. This
|
||||
behavior mimics that of LAPACK's ?lange(). */
|
||||
if ( abs_chi1_max < abs_chi1 || isnan( abs_chi1 ) )
|
||||
encountered, return the index of the first NaN. This
|
||||
behavior mimics that of LAPACK's i?amax(). */
|
||||
if ( abs_chi1_max < abs_chi1 || ( isnan( abs_chi1 ) && !isnan( abs_chi1_max ) ) )
|
||||
{
|
||||
abs_chi1_max = abs_chi1;
|
||||
i_max_l = i;
|
||||
|
||||
@@ -97,7 +97,7 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
encountered, then treat it the same as if it were a valid
|
||||
value that was smaller than any previously seen. This
|
||||
behavior mimics that of LAPACK's ?lange(). */ \
|
||||
if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \
|
||||
if ( abs_chi1_max < abs_chi1 || ( bli_isnan( abs_chi1 ) && !bli_isnan( abs_chi1_max ) ) ) \
|
||||
{ \
|
||||
abs_chi1_max = abs_chi1; \
|
||||
i_max_l = i; \
|
||||
@@ -129,7 +129,7 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
encountered, then treat it the same as if it were a valid
|
||||
value that was smaller than any previously seen. This
|
||||
behavior mimics that of LAPACK's ?lange(). */ \
|
||||
if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \
|
||||
if ( abs_chi1_max < abs_chi1 || ( bli_isnan( abs_chi1 ) && !bli_isnan( abs_chi1_max ) ) ) \
|
||||
{ \
|
||||
abs_chi1_max = abs_chi1; \
|
||||
i_max_l = i; \
|
||||
|
||||
Reference in New Issue
Block a user