mirror of
https://github.com/amd/blis.git
synced 2026-05-11 09:39:59 +00:00
DAMAXV AVX512 micro kernel bug fix.
- DAMAXV AVX512 gives wrong results when the max element is present at an index in [n-u, n), where u < 32. This is a fallout of using the wrong start offset for the non-loop-unrolled code. - The functions for replacing NaN with negative numbers are replaced with macros to avoid function call overhead and to remove the static variables used for the stateful NaN replacement numbers. AMD-Internal: [CPUPL-2190] Change-Id: Ie1435c38b264a271f869782793d0b52bbe6e1b2a
This commit is contained in:
@@ -80,33 +80,24 @@ typedef union
|
||||
the times the function is called to ensure that
|
||||
bigger numbers are assigned for nan which showed
|
||||
up first.*/
|
||||
__m512 remove_NAN_512_s(__m512 vec)
|
||||
{
|
||||
// Sign extraction mask
|
||||
__m512 sign_mask;
|
||||
// Temporary place to store vector's sign extracted 16xdouble word
|
||||
__m512 vec_mask;
|
||||
// k register to store the mask to do blend operation to remove NAN
|
||||
__mmask16 vec_mask16;
|
||||
// Static so that the value is preserved across function calls
|
||||
static int iter = -1;
|
||||
iter -= 1;
|
||||
|
||||
// Extracting sign from the vec into int_mask_vec
|
||||
// In IEEE 754, -0.f has only the sign bit set; all other bits are 0
|
||||
sign_mask = _mm512_set1_ps(-0.f);
|
||||
// Multiply by -0.f; the sign bits of the product are used below to build the blend mask
|
||||
vec_mask = _mm512_mul_ps(vec, sign_mask);
|
||||
// Reinterpret the mask as an integer vector; the cast takes no clock
|
||||
// cycles and is only needed to satisfy the compiler.
|
||||
__m512i int_mask_vec = _mm512_castps_si512(vec_mask);
|
||||
// Extract the signbits and put it in a 16bit mask register
|
||||
vec_mask16 = _mm512_movepi32_mask(int_mask_vec);
|
||||
|
||||
// Swap NAN with -ve number
|
||||
vec = _mm512_mask_blend_ps(vec_mask16, _mm512_set1_ps(iter), vec);
|
||||
return vec;
|
||||
}
|
||||
#define REMOVE_NAN_512S(reg_512) \
|
||||
{ \
|
||||
/* In IEEE 754, -0.f has only the sign bit set; all other bits are 0. */ \
|
||||
__m512 sign_mask = _mm512_set1_ps( -0.0f ); \
|
||||
\
|
||||
/* Numbers other than NAN will become 0. */ \
|
||||
__m512 vec_mask = _mm512_mul_ps( reg_512, sign_mask ); \
|
||||
\
|
||||
/* Typecast mask into int type no clock cycle is taken just to
|
||||
* convince compiler. */ \
|
||||
__m512i int_mask_vec = _mm512_castps_si512( vec_mask ); \
|
||||
/* Extract the signbits and put it in a 16bit mask register. */ \
|
||||
__mmask16 vec_mask16 = _mm512_movepi32_mask( int_mask_vec ); \
|
||||
\
|
||||
/* Swap NAN with -ve number. */ \
|
||||
reg_512 = _mm512_mask_blend_ps( vec_mask16, _mm512_set1_ps( nan_repl ), reg_512 ); \
|
||||
nan_repl = nan_repl - 1; \
|
||||
}
|
||||
|
||||
// return a mask which indicates either:
|
||||
// - v1 > v2
|
||||
@@ -151,10 +142,14 @@ void bli_samaxv_zen_int_avx512(
|
||||
// *minus_one = -1
|
||||
float *minus_one = PASTEMAC(s, m1); // bli_sm1()
|
||||
// *zero_i = 0
|
||||
dim_t *zero_i = PASTEMAC(i, 0); // bli_i0()
|
||||
dim_t *zero_i = PASTEMAC(i, 0); // bli_i0()
|
||||
|
||||
// Used to replace NAN in registers. This value is decremented each time
|
||||
// remove NAN is applied so as to keep the NAN value replacements unique.
|
||||
float nan_repl = -1.0;
|
||||
|
||||
float fndMaxVal; // Max value will be stored in this
|
||||
dim_t fndInd; // Max value's index will be stored in this
|
||||
dim_t fndInd; // Max value's index will be stored in this
|
||||
// Iterator for loops to keep continuity throughout the loops
|
||||
dim_t i;
|
||||
|
||||
@@ -246,7 +241,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// max_vector = abs(max_vector)
|
||||
max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v);
|
||||
// Remove nan and replace with -ve values
|
||||
max_vec_1.v = remove_NAN_512_s(max_vec_1.v);
|
||||
REMOVE_NAN_512S(max_vec_1.v);
|
||||
|
||||
// Increment x vector as we have loaded 16 values
|
||||
x += num_vector_elements;
|
||||
@@ -254,7 +249,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
9, 10, 11, 12, 13, 14, 15);
|
||||
|
||||
int i = 1;
|
||||
dim_t i = 1;
|
||||
for (; (i + 4) < num_iter; i += 5)
|
||||
{
|
||||
/*
|
||||
@@ -276,7 +271,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// Increment x vector as we have loaded 16 values
|
||||
x += num_vector_elements;
|
||||
// Remove nan and replace with -ve values
|
||||
x_vec_1.v = remove_NAN_512_s(x_vec_1.v);
|
||||
REMOVE_NAN_512S(x_vec_1.v);
|
||||
|
||||
// Mask Generation of 1st(can be previous max) and 2nd element
|
||||
// mask = max_vector - x_vec_1
|
||||
@@ -295,7 +290,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// max_vec_2 = abs(max_vec_2)
|
||||
max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v);
|
||||
// Remove nan and replace with -ve values
|
||||
max_vec_2.v = remove_NAN_512_s(max_vec_2.v);
|
||||
REMOVE_NAN_512S(max_vec_2.v);
|
||||
// Increment x vector as we have loaded 16 values
|
||||
x += num_vector_elements;
|
||||
// Increment the index vector to point to next indexes.
|
||||
@@ -306,7 +301,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// x_vec_2 = abs(x_vec_2)
|
||||
x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v);
|
||||
// Remove nan and replace with -ve values
|
||||
x_vec_2.v = remove_NAN_512_s(x_vec_2.v);
|
||||
REMOVE_NAN_512S(x_vec_2.v);
|
||||
// Increment x vector as we have loaded 16 values
|
||||
x += num_vector_elements;
|
||||
// Increment the index vector to point to next indexes.
|
||||
@@ -329,7 +324,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// max_vec_3 = abs(max_vec_3)
|
||||
max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v);
|
||||
// Remove nan and replace with -ve values
|
||||
max_vec_3.v = remove_NAN_512_s(max_vec_3.v);
|
||||
REMOVE_NAN_512S(max_vec_3.v);
|
||||
// Increment x vector as we have loaded 16 values
|
||||
x += num_vector_elements;
|
||||
// Increment the index vector to point to next indexes.
|
||||
@@ -339,7 +334,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// x_vec_3 = abs(x_vec_3)
|
||||
x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v);
|
||||
// Remove nan and replace with -ve values
|
||||
x_vec_3.v = remove_NAN_512_s(x_vec_3.v);
|
||||
REMOVE_NAN_512S(x_vec_3.v);
|
||||
// Increment x vector as we have loaded 16 values
|
||||
x += num_vector_elements;
|
||||
// Increment the index vector to point to next indexes.
|
||||
@@ -468,7 +463,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
// x_vec_1 = abs(x_vec_1)
|
||||
x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v);
|
||||
// Remove nan and replace with -ve values
|
||||
x_vec_1.v = remove_NAN_512_s(x_vec_1.v);
|
||||
REMOVE_NAN_512S(x_vec_1.v);
|
||||
|
||||
// Mask Generation
|
||||
// mask = max_vec_1 - x_vec_1
|
||||
@@ -618,7 +613,7 @@ void bli_samaxv_zen_int_avx512(
|
||||
fndMaxVal = NAN;
|
||||
}
|
||||
// Finish off the remaining values using normal instructions
|
||||
for (int i = n - num_remain; i < n; i++)
|
||||
for (dim_t i = n - num_remain; i < n; i++)
|
||||
{
|
||||
float absval = fabsf(*(x));
|
||||
if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal)))
|
||||
@@ -643,32 +638,21 @@ void bli_samaxv_zen_int_avx512(
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
/* Converts all the NAN to a negative number less than previously encountered NANs*/
|
||||
__m512d remove_NAN_512d(__m512d vec)
|
||||
{
|
||||
|
||||
static int iter;
|
||||
static __m512d sign_mask;
|
||||
|
||||
__m512d vec_mask;
|
||||
__m512i int_mask_vec;
|
||||
__mmask8 vec_mask8;
|
||||
|
||||
iter = iter - 1;
|
||||
|
||||
sign_mask = _mm512_set1_pd(-0.f);
|
||||
|
||||
//numbers other than NAN will become 0
|
||||
vec_mask = _mm512_mul_pd(vec, sign_mask);
|
||||
|
||||
//producing an 8-bit mask
|
||||
int_mask_vec = _mm512_castpd_si512(vec_mask);
|
||||
vec_mask8 = _mm512_movepi64_mask(int_mask_vec);
|
||||
|
||||
//replacing all the NAN with negative numbers
|
||||
vec = _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec);
|
||||
|
||||
return vec;
|
||||
}
|
||||
#define REMOVE_NAN_512D(reg_512) \
|
||||
{ \
|
||||
__m512d sign_mask = _mm512_set1_pd( -0.0f ); \
|
||||
\
|
||||
/* Numbers other than NAN will become 0. */ \
|
||||
__m512d vec_mask = _mm512_mul_pd( reg_512, sign_mask ); \
|
||||
\
|
||||
/* Producing an 8-bit mask. */ \
|
||||
__m512i int_mask_vec = _mm512_castpd_si512( vec_mask ); \
|
||||
__mmask8 vec_mask8 = _mm512_movepi64_mask( int_mask_vec ); \
|
||||
\
|
||||
/* Replacing all the NAN with negative numbers. */ \
|
||||
reg_512 = _mm512_mask_blend_pd( vec_mask8, _mm512_set1_pd( nan_repl ), reg_512 ); \
|
||||
nan_repl = nan_repl - 1; \
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------------
|
||||
void bli_damaxv_zen_int_avx512(
|
||||
@@ -679,6 +663,11 @@ void bli_damaxv_zen_int_avx512(
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
|
||||
double *minus_one = PASTEMAC(d, m1);
|
||||
|
||||
// Used to replace NAN in registers. This value is decremented each time
|
||||
// remove NAN is applied so as to keep the NAN value replacements unique.
|
||||
double nan_repl = -1.0;
|
||||
|
||||
dim_t *zero_i = PASTEMAC(i, 0);
|
||||
|
||||
double chi1_r;
|
||||
@@ -776,7 +765,7 @@ void bli_damaxv_zen_int_avx512(
|
||||
|
||||
// Taking the absolute value and removing the NAN
|
||||
max_array.v = _mm512_andnot_pd(sign_mask, max_array.v);
|
||||
max_array.v = remove_NAN_512d(max_array.v);
|
||||
REMOVE_NAN_512D(max_array.v);
|
||||
|
||||
// Initializing the maximum index
|
||||
max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0);
|
||||
@@ -786,7 +775,7 @@ void bli_damaxv_zen_int_avx512(
|
||||
//to point to the next 8 elements
|
||||
zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v);
|
||||
|
||||
/* Loop unrolled by a factor of 4
|
||||
/* Loop unrolled by a factor of 4
|
||||
At the end of the loop max_array holds the largest element
|
||||
in each corresponding vector index */
|
||||
for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32)
|
||||
@@ -797,25 +786,25 @@ void bli_damaxv_zen_int_avx512(
|
||||
// with negative numbers
|
||||
zmm0.v = _mm512_loadu_pd(x);
|
||||
zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v);
|
||||
zmm0.v = remove_NAN_512d(zmm0.v);
|
||||
REMOVE_NAN_512D(zmm0.v);
|
||||
x += vector_length;
|
||||
|
||||
zmm1.v = _mm512_loadu_pd(x);
|
||||
zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v);
|
||||
zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v);
|
||||
zmm1.v = remove_NAN_512d(zmm1.v);
|
||||
REMOVE_NAN_512D(zmm1.v);
|
||||
x += vector_length;
|
||||
|
||||
zmm2.v = _mm512_loadu_pd(x);
|
||||
zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v);
|
||||
zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v);
|
||||
zmm2.v = remove_NAN_512d(zmm2.v);
|
||||
REMOVE_NAN_512D(zmm2.v);
|
||||
x += vector_length;
|
||||
|
||||
zmm3.v = _mm512_loadu_pd(x);
|
||||
zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v);
|
||||
zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v);
|
||||
zmm3.v = remove_NAN_512d(zmm3.v);
|
||||
REMOVE_NAN_512D(zmm3.v);
|
||||
x += vector_length;
|
||||
|
||||
/*Using sub function to generating the mask
|
||||
@@ -872,7 +861,7 @@ void bli_damaxv_zen_int_avx512(
|
||||
|
||||
/* At the end of the loop max_array holds the largest element
|
||||
in each corresponding vector index */
|
||||
for (int i = 1; i < iterations; ++i)
|
||||
for (dim_t i = 0; i < iterations; ++i)
|
||||
{
|
||||
// Taking 32 elements
|
||||
// Taking only the absolute values of the registers
|
||||
@@ -880,7 +869,7 @@ void bli_damaxv_zen_int_avx512(
|
||||
// with negative numbers
|
||||
zmm0.v = _mm512_loadu_pd(x);
|
||||
zmm0.v = _mm512_abs_pd(zmm0.v);
|
||||
zmm0.v = remove_NAN_512d(zmm0.v);
|
||||
REMOVE_NAN_512D(zmm0.v);
|
||||
|
||||
//Generating mask for the intermediate max vector
|
||||
mask_01 = _mm512_sub_pd(max_array.v, zmm0.v);
|
||||
@@ -968,6 +957,12 @@ void bli_damaxv_zen_int_avx512(
|
||||
}
|
||||
}
|
||||
|
||||
// Issue vzeroupper instruction to clear upper lanes of ymm registers.
|
||||
// This avoids a performance penalty caused by false dependencies when
|
||||
// transitioning from AVX to SSE instructions (which may occur
|
||||
// later, especially if BLIS is compiled with -mfpmath=sse).
|
||||
_mm256_zeroupper();
|
||||
|
||||
// Return value
|
||||
*i_max = i_max_l;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user