DAMAXV AVX512 micro kernel bug fix.

-DAMAXV AVX512 gives wrong results when the max element is present at
an index in [n-u, n), where u < 32. This is fallout from using the wrong
start offset for the non-loop-unrolled code.
-The functions for replacing NaN with negative numbers are replaced with
macros to avoid function call overhead and to remove the static variables
used for the stateful NaN replacement numbers.

AMD-Internal: [CPUPL-2190]
Change-Id: Ie1435c38b264a271f869782793d0b52bbe6e1b2a
This commit is contained in:
mkadavil
2022-06-10 18:19:08 +05:30
parent 6c112632a7
commit e073e8b669

View File

@@ -80,33 +80,24 @@ typedef union
the times the function is called to ensure that
bigger numbers are assigned for nan which showed
up first.*/
__m512 remove_NAN_512_s(__m512 vec)
{
    // Call counter preserved across invocations: every call hands out a
    // replacement value one lower than the call before it.
    // NOTE(review): being static, the counter is shared by all callers and
    // is never reset.
    static int iter = -1;
    iter -= 1;

    // -0.f has only its sign bit set. The input is multiplied by it and the
    // sign bit of each product decides which lanes are kept below.
    // NOTE(review): the visible call sites clear the sign bit of the input
    // before calling; the NaN detection presumably relies on that - confirm.
    const __m512 neg_zero = _mm512_set1_ps(-0.f);
    const __m512 product = _mm512_mul_ps(vec, neg_zero);

    // Collect the sign bit of every 32-bit lane into a 16-bit mask. The cast
    // only reinterprets the register as integers.
    const __mmask16 keep_lanes = _mm512_movepi32_mask(_mm512_castps_si512(product));

    // Lanes whose mask bit is set keep the input value; the remaining lanes
    // receive the negative replacement number.
    const __m512 replacement = _mm512_set1_ps(iter);
    return _mm512_mask_blend_ps(keep_lanes, replacement, vec);
}
/*
 * REMOVE_NAN_512S( reg_512 )
 *
 * Rewrites the 16-float register lvalue reg_512 in place. Every lane whose
 * product with -0.0f has a clear sign bit is overwritten with the current
 * value of nan_repl; all other lanes are left untouched. nan_repl is then
 * decremented by 1 so that the next expansion uses a different replacement.
 *
 * Requirements at the expansion site:
 *  - a modifiable floating-point variable named nan_repl must be in scope;
 *  - reg_512 must be a modifiable __m512 lvalue. It is evaluated more than
 *    once, so it must not have side effects.
 *
 * NOTE(review): the visible call sites clear the sign bit of reg_512 before
 * expanding this macro; the NaN detection presumably depends on that, so
 * confirm before using it on signed data.
 *
 * The body is wrapped in do { ... } while ( 0 ) so that the expansion plus its
 * trailing semicolon is exactly one statement (safe inside if/else), and the
 * temporaries carry a macro-specific prefix so they can neither shadow nor
 * capture identifiers used at the expansion site.
 */
#define REMOVE_NAN_512S(reg_512) \
do \
{ \
    /* -0.0f has only the sign bit set. */ \
    const __m512 rnan_s_neg_zero = _mm512_set1_ps( -0.0f ); \
    \
    /* The per-lane sign bit of this product feeds the blend mask. */ \
    const __m512 rnan_s_prod = _mm512_mul_ps( ( reg_512 ), rnan_s_neg_zero ); \
    \
    /* The cast only reinterprets the register as integers; movepi32 then */ \
    /* gathers the sign bit of each 32-bit lane into a 16-bit mask. */ \
    const __mmask16 rnan_s_keep = \
        _mm512_movepi32_mask( _mm512_castps_si512( rnan_s_prod ) ); \
    \
    /* Lanes with a set mask bit keep their value, the rest get nan_repl. */ \
    ( reg_512 ) = _mm512_mask_blend_ps( rnan_s_keep, \
                                        _mm512_set1_ps( nan_repl ), \
                                        ( reg_512 ) ); \
    nan_repl = nan_repl - 1; \
} while ( 0 )
// return a mask which indicates either:
// - v1 > v2
@@ -151,10 +142,14 @@ void bli_samaxv_zen_int_avx512(
// *minus_one = -1
float *minus_one = PASTEMAC(s, m1); // bli_sm1()
// *zero_i = 0
dim_t *zero_i = PASTEMAC(i, 0); // bli_i0()
dim_t *zero_i = PASTEMAC(i, 0); // bli_i0()
// Used to replace NAN in registers. This value is decremented each time
// remove NAN is applied so as to keep the NAN value replacements unique.
float nan_repl = -1.0;
float fndMaxVal; // Max value will be stored in this
dim_t fndInd; // Max value's index will be stored in this
dim_t fndInd; // Max value's index will be stored in this
// Iterator for loops to keep continuity throughout the loops
dim_t i;
@@ -246,7 +241,7 @@ void bli_samaxv_zen_int_avx512(
// max_vector = abs(max_vector)
max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v);
// Remove nan and replace with -ve values
max_vec_1.v = remove_NAN_512_s(max_vec_1.v);
REMOVE_NAN_512S(max_vec_1.v);
// Increment x vector as we have loaded 16 values
x += num_vector_elements;
@@ -254,7 +249,7 @@ void bli_samaxv_zen_int_avx512(
maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15);
int i = 1;
dim_t i = 1;
for (; (i + 4) < num_iter; i += 5)
{
/*
@@ -276,7 +271,7 @@ void bli_samaxv_zen_int_avx512(
// Increment x vector as we have loaded 16 values
x += num_vector_elements;
// Remove nan and replace with -ve values
x_vec_1.v = remove_NAN_512_s(x_vec_1.v);
REMOVE_NAN_512S(x_vec_1.v);
// Mask Generation of 1st(can be previous max) and 2nd element
// mask = max_vector - x_vec_1
@@ -295,7 +290,7 @@ void bli_samaxv_zen_int_avx512(
// max_vec_2 = abs(max_vec_2)
max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v);
// Remove nan and replace with -ve values
max_vec_2.v = remove_NAN_512_s(max_vec_2.v);
REMOVE_NAN_512S(max_vec_2.v);
// Increment x vector as we have loaded 16 values
x += num_vector_elements;
// Increment the index vector to point to next indexes.
@@ -306,7 +301,7 @@ void bli_samaxv_zen_int_avx512(
// x_vec_2 = abs(x_vec_2)
x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v);
// Remove nan and replace with -ve values
x_vec_2.v = remove_NAN_512_s(x_vec_2.v);
REMOVE_NAN_512S(x_vec_2.v);
// Increment x vector as we have loaded 16 values
x += num_vector_elements;
// Increment the index vector to point to next indexes.
@@ -329,7 +324,7 @@ void bli_samaxv_zen_int_avx512(
// max_vec_3 = abs(max_vec_3)
max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v);
// Remove nan and replace with -ve values
max_vec_3.v = remove_NAN_512_s(max_vec_3.v);
REMOVE_NAN_512S(max_vec_3.v);
// Increment x vector as we have loaded 16 values
x += num_vector_elements;
// Increment the index vector to point to next indexes.
@@ -339,7 +334,7 @@ void bli_samaxv_zen_int_avx512(
// x_vec_3 = abs(x_vec_3)
x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v);
// Remove nan and replace with -ve values
x_vec_3.v = remove_NAN_512_s(x_vec_3.v);
REMOVE_NAN_512S(x_vec_3.v);
// Increment x vector as we have loaded 16 values
x += num_vector_elements;
// Increment the index vector to point to next indexes.
@@ -468,7 +463,7 @@ void bli_samaxv_zen_int_avx512(
// x_vec_1 = abs(x_vec_1)
x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v);
// Remove nan and replace with -ve values
x_vec_1.v = remove_NAN_512_s(x_vec_1.v);
REMOVE_NAN_512S(x_vec_1.v);
// Mask Generation
// mask = max_vec_1 - x_vec_1
@@ -618,7 +613,7 @@ void bli_samaxv_zen_int_avx512(
fndMaxVal = NAN;
}
// Finish off the remaining values using normal instructions
for (int i = n - num_remain; i < n; i++)
for (dim_t i = n - num_remain; i < n; i++)
{
float absval = fabsf(*(x));
if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal)))
@@ -643,32 +638,21 @@ void bli_samaxv_zen_int_avx512(
// -----------------------------------------------------------------------------
/* Converts all the NAN to a negative number less than previously encountered NANs*/
__m512d remove_NAN_512d(__m512d vec)
{
    // Call counter preserved across invocations (starts at 0): every call
    // uses a replacement value one lower than the call before it.
    // NOTE(review): being static, the counter is shared by all callers,
    // including concurrent ones, and is never reset.
    static int iter;
    iter = iter - 1;

    // -0.0 has only its sign bit set. This is a plain local: it is rebuilt
    // on every call, so there is no reason to keep it in static storage.
    const __m512d sign_mask = _mm512_set1_pd(-0.0);

    // The sign bit of each product decides which lanes are kept below.
    // NOTE(review): the visible call sites clear the sign bit of the input
    // before calling; the NaN detection presumably relies on that - confirm.
    const __m512d vec_mask = _mm512_mul_pd(vec, sign_mask);

    // Collect the sign bit of every 64-bit lane into an 8-bit mask. The cast
    // only reinterprets the register as integers.
    const __m512i int_mask_vec = _mm512_castpd_si512(vec_mask);
    const __mmask8 vec_mask8 = _mm512_movepi64_mask(int_mask_vec);

    // Lanes whose mask bit is set keep the input value; the remaining lanes
    // receive the negative replacement number.
    return _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec);
}
/*
 * REMOVE_NAN_512D( reg_512 )
 *
 * Rewrites the 8-double register lvalue reg_512 in place. Every lane whose
 * product with -0.0 has a clear sign bit is overwritten with the current
 * value of nan_repl; all other lanes are left untouched. nan_repl is then
 * decremented by 1 so that the next expansion uses a different replacement.
 *
 * Requirements at the expansion site:
 *  - a modifiable floating-point variable named nan_repl must be in scope;
 *  - reg_512 must be a modifiable __m512d lvalue. It is evaluated more than
 *    once, so it must not have side effects.
 *
 * NOTE(review): the visible call sites clear the sign bit of reg_512 before
 * expanding this macro; the NaN detection presumably depends on that, so
 * confirm before using it on signed data.
 *
 * The body is wrapped in do { ... } while ( 0 ) so that the expansion plus its
 * trailing semicolon is exactly one statement (safe inside if/else), and the
 * temporaries carry a macro-specific prefix so they can neither shadow nor
 * capture identifiers used at the expansion site (for example a local named
 * sign_mask).
 */
#define REMOVE_NAN_512D(reg_512) \
do \
{ \
    /* -0.0 has only the sign bit set. */ \
    const __m512d rnan_d_neg_zero = _mm512_set1_pd( -0.0 ); \
    \
    /* The per-lane sign bit of this product feeds the blend mask. */ \
    const __m512d rnan_d_prod = _mm512_mul_pd( ( reg_512 ), rnan_d_neg_zero ); \
    \
    /* The cast only reinterprets the register as integers; movepi64 then */ \
    /* gathers the sign bit of each 64-bit lane into an 8-bit mask. */ \
    const __mmask8 rnan_d_keep = \
        _mm512_movepi64_mask( _mm512_castpd_si512( rnan_d_prod ) ); \
    \
    /* Lanes with a set mask bit keep their value, the rest get nan_repl. */ \
    ( reg_512 ) = _mm512_mask_blend_pd( rnan_d_keep, \
                                        _mm512_set1_pd( nan_repl ), \
                                        ( reg_512 ) ); \
    nan_repl = nan_repl - 1; \
} while ( 0 )
//----------------------------------------------------------------------------------------------------
void bli_damaxv_zen_int_avx512(
@@ -679,6 +663,11 @@ void bli_damaxv_zen_int_avx512(
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
double *minus_one = PASTEMAC(d, m1);
// Used to replace NAN in registers. This value is decremented each time
// remove NAN is applied so as to keep the NAN value replacements unique.
double nan_repl = -1.0;
dim_t *zero_i = PASTEMAC(i, 0);
double chi1_r;
@@ -776,7 +765,7 @@ void bli_damaxv_zen_int_avx512(
// Taking the absolute value and removing the NAN
max_array.v = _mm512_andnot_pd(sign_mask, max_array.v);
max_array.v = remove_NAN_512d(max_array.v);
REMOVE_NAN_512D(max_array.v);
// Initializing the maximum index
max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0);
@@ -786,7 +775,7 @@ void bli_damaxv_zen_int_avx512(
//to point to the next 8 elements
zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v);
/* Loop unrolled by a factor of 4
/* Loop unrolled by a factor of 4
At the end of the loop max_array holds the largest element
in each corresponding vector index */
for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32)
@@ -797,25 +786,25 @@ void bli_damaxv_zen_int_avx512(
// with negative numbers
zmm0.v = _mm512_loadu_pd(x);
zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v);
zmm0.v = remove_NAN_512d(zmm0.v);
REMOVE_NAN_512D(zmm0.v);
x += vector_length;
zmm1.v = _mm512_loadu_pd(x);
zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v);
zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v);
zmm1.v = remove_NAN_512d(zmm1.v);
REMOVE_NAN_512D(zmm1.v);
x += vector_length;
zmm2.v = _mm512_loadu_pd(x);
zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v);
zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v);
zmm2.v = remove_NAN_512d(zmm2.v);
REMOVE_NAN_512D(zmm2.v);
x += vector_length;
zmm3.v = _mm512_loadu_pd(x);
zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v);
zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v);
zmm3.v = remove_NAN_512d(zmm3.v);
REMOVE_NAN_512D(zmm3.v);
x += vector_length;
/*Using sub function to generating the mask
@@ -872,7 +861,7 @@ void bli_damaxv_zen_int_avx512(
/* At the end of the loop max_array holds the largest element
in each corresponding vector index */
for (int i = 1; i < iterations; ++i)
for (dim_t i = 0; i < iterations; ++i)
{
// Taking 32 elements
// Taking only the absolute values of the registers
@@ -880,7 +869,7 @@ void bli_damaxv_zen_int_avx512(
// with negative numbers
zmm0.v = _mm512_loadu_pd(x);
zmm0.v = _mm512_abs_pd(zmm0.v);
zmm0.v = remove_NAN_512d(zmm0.v);
REMOVE_NAN_512D(zmm0.v);
//Generating mask for the intermediate max vector
mask_01 = _mm512_sub_pd(max_array.v, zmm0.v);
@@ -968,6 +957,12 @@ void bli_damaxv_zen_int_avx512(
}
}
// Issue vzeroupper instruction to clear upper lanes of ymm registers.
// This avoids a performance penalty caused by false dependencies when
// transitioning from AVX to SSE instructions (which may occur
// later, especially if BLIS is compiled with -mfpmath=sse).
_mm256_zeroupper();
// Return value
*i_max = i_max_l;