diff --git a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c index 0e0186c40..9e32f955a 100644 --- a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c @@ -80,33 +80,24 @@ typedef union the times the function is called to ensure that bigger numbers are assigned for nan which showed up first.*/ -__m512 remove_NAN_512_s(__m512 vec) -{ - // Sign extraction mask - __m512 sign_mask; - // Temporary place to store vector's sign extracted 16xdouble word - __m512 vec_mask; - // k register to store the mask to do blend operation to remove NAN - __mmask16 vec_mask16; - // Static to preserve accross the function calls - static int iter = -1; - iter -= 1; - - // Extracting sign from the vec into int_mask_vec - // Sign is -0.f in IEEE754 is just signbit set, all others 0 - sign_mask = _mm512_set1_ps(-0.f); - // And with -0.f will keep just signbits, all others will be 0 - vec_mask = _mm512_mul_ps(vec, sign_mask); - // Typecast mask into int type no clock cycle is taken just to - // convince compiler. - __m512i int_mask_vec = _mm512_castps_si512(vec_mask); - // Extract the signbits and put it in a 16bit mask register - vec_mask16 = _mm512_movepi32_mask(int_mask_vec); - - // Swap NAN with -ve number - vec = _mm512_mask_blend_ps(vec_mask16, _mm512_set1_ps(iter), vec); - return vec; -} +#define REMOVE_NAN_512S(reg_512) \ + { \ + /* In IEEE754, -0.0f has only the sign bit set; all other bits are 0. */ \ + __m512 sign_mask = _mm512_set1_ps( -0.0f ); \ + \ + /* Multiply by -0.0f: finite lanes become a signed zero. */ \ + __m512 vec_mask = _mm512_mul_ps( reg_512, sign_mask ); \ + \ + /* Reinterpret the float vector as an integer vector; the cast only + * exists to satisfy the compiler's type checking. */ \ + __m512i int_mask_vec = _mm512_castps_si512( vec_mask ); \ + /* Extract the sign bit of each 32-bit lane into a 16-bit mask register. */ \ + __mmask16 vec_mask16 = _mm512_movepi32_mask( int_mask_vec ); \ + \ + /* Swap NAN with -ve number: lanes with a clear mask bit take nan_repl. 
*/ \ + reg_512 = _mm512_mask_blend_ps( vec_mask16, _mm512_set1_ps( nan_repl ), reg_512 ); \ + nan_repl = nan_repl - 1; \ + } // return a mask which indicates either: // - v1 > v2 @@ -151,10 +142,14 @@ void bli_samaxv_zen_int_avx512( // *minus_one = -1 float *minus_one = PASTEMAC(s, m1); // bli_sm1() // *zero_i = 0 - dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + + // Replacement value for NAN lanes. REMOVE_NAN_512S decrements it on each + // use so that successive NAN replacements stay distinct. + float nan_repl = -1.0; float fndMaxVal; // Max value will be stored in this - dim_t fndInd; // Max value's index will be stored in this + dim_t fndInd; // Max value's index will be stored in this // Iterator for loops to keep continuity throughout the loops dim_t i; @@ -246,7 +241,7 @@ void bli_samaxv_zen_int_avx512( // max_vector = abs(max_vector) max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v); // Remove nan and replace with -ve values - max_vec_1.v = remove_NAN_512_s(max_vec_1.v); + REMOVE_NAN_512S(max_vec_1.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; @@ -254,7 +249,7 @@ void bli_samaxv_zen_int_avx512( maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - int i = 1; + dim_t i = 1; for (; (i + 4) < num_iter; i += 5) { /* @@ -276,7 +271,7 @@ void bli_samaxv_zen_int_avx512( // Increment x vector as we have loaded 16 values x += num_vector_elements; // Remove nan and replace with -ve values - x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + REMOVE_NAN_512S(x_vec_1.v); // Mask Generation of 1st(can be previous max) and 2nd element // mask = max_vector - x_vec_1 @@ -295,7 +290,7 @@ void bli_samaxv_zen_int_avx512( // max_vec_2 = abs(max_vec_2) max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v); // Remove nan and replace with -ve values - max_vec_2.v = remove_NAN_512_s(max_vec_2.v); + REMOVE_NAN_512S(max_vec_2.v); // Increment x vector as we 
have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. @@ -306,7 +301,7 @@ void bli_samaxv_zen_int_avx512( // x_vec_2 = abs(x_vec_2) x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v); // Remove nan and replace with -ve values - x_vec_2.v = remove_NAN_512_s(x_vec_2.v); + REMOVE_NAN_512S(x_vec_2.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. @@ -329,7 +324,7 @@ void bli_samaxv_zen_int_avx512( // max_vec_3 = abs(max_vec_3) max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v); // Remove nan and replace with -ve values - max_vec_3.v = remove_NAN_512_s(max_vec_3.v); + REMOVE_NAN_512S(max_vec_3.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. @@ -339,7 +334,7 @@ void bli_samaxv_zen_int_avx512( // x_vec_3 = abs(x_vec_3) x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v); // Remove nan and replace with -ve values - x_vec_3.v = remove_NAN_512_s(x_vec_3.v); + REMOVE_NAN_512S(x_vec_3.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. 
@@ -468,7 +463,7 @@ void bli_samaxv_zen_int_avx512( // x_vec_1 = abs(x_vec_1) x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); // Remove nan and replace with -ve values - x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + REMOVE_NAN_512S(x_vec_1.v); // Mask Generation // mask = max_vec_1 - x_vec_1 @@ -618,7 +613,7 @@ void bli_samaxv_zen_int_avx512( fndMaxVal = NAN; } // Finish off the remaining values using normal instructions - for (int i = n - num_remain; i < n; i++) + for (dim_t i = n - num_remain; i < n; i++) { float absval = fabsf(*(x)); if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) @@ -643,32 +638,21 @@ void bli_samaxv_zen_int_avx512( // ----------------------------------------------------------------------------- /* Converts all the NAN to a negative number less than previously encountered NANs*/ -__m512d remove_NAN_512d(__m512d vec) -{ - - static int iter; - static __m512d sign_mask; - - __m512d vec_mask; - __m512i int_mask_vec; - __mmask8 vec_mask8; - - iter = iter - 1; - - sign_mask = _mm512_set1_pd(-0.f); - - //numbers other than NAN will become 0 - vec_mask = _mm512_mul_pd(vec, sign_mask); - - //producing an 8-bit mask - int_mask_vec = _mm512_castpd_si512(vec_mask); - vec_mask8 = _mm512_movepi64_mask(int_mask_vec); - - //replacing all the NAN with negative numbers - vec = _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec); - - return vec; -} +#define REMOVE_NAN_512D(reg_512) \ + { \ + __m512d sign_mask = _mm512_set1_pd( -0.0f ); \ + \ + /* Multiply by -0.0: finite lanes become a signed zero. */ \ + __m512d vec_mask = _mm512_mul_pd( reg_512, sign_mask ); \ + \ + /* Collect the sign bit of each 64-bit lane into an 8-bit mask. */ \ + __m512i int_mask_vec = _mm512_castpd_si512( vec_mask ); \ + __mmask8 vec_mask8 = _mm512_movepi64_mask( int_mask_vec ); \ + \ + /* Replace NAN with the negative value nan_repl: lanes with a clear mask bit take nan_repl. 
*/ \ + reg_512 = _mm512_mask_blend_pd( vec_mask8, _mm512_set1_pd( nan_repl ), reg_512 ); \ + nan_repl = nan_repl - 1; \ + } //---------------------------------------------------------------------------------------------------- void bli_damaxv_zen_int_avx512( @@ -679,6 +663,11 @@ void bli_damaxv_zen_int_avx512( { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) double *minus_one = PASTEMAC(d, m1); + + // Replacement value for NAN lanes. REMOVE_NAN_512D decrements it on each + // use so that successive NAN replacements stay distinct. + double nan_repl = -1.0; + dim_t *zero_i = PASTEMAC(i, 0); double chi1_r; @@ -776,7 +765,7 @@ void bli_damaxv_zen_int_avx512( // Taking the absolute value and removing the NAN max_array.v = _mm512_andnot_pd(sign_mask, max_array.v); - max_array.v = remove_NAN_512d(max_array.v); + REMOVE_NAN_512D(max_array.v); // Initializing the maximumum index max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); @@ -786,7 +775,7 @@ void bli_damaxv_zen_int_avx512( //to point to the next 8 elements zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); - /* Loop unrolled by a factor of 4 + /* Loop unrolled by a factor of 4 At the end of the loop max_array holds the largest element in each corresponding vector index */ for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32) @@ -797,25 +786,25 @@ void bli_damaxv_zen_int_avx512( // with negative numbers zmm0.v = _mm512_loadu_pd(x); zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v); - zmm0.v = remove_NAN_512d(zmm0.v); + REMOVE_NAN_512D(zmm0.v); x += vector_length; zmm1.v = _mm512_loadu_pd(x); zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v); - zmm1.v = remove_NAN_512d(zmm1.v); + REMOVE_NAN_512D(zmm1.v); x += vector_length; zmm2.v = _mm512_loadu_pd(x); zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v); zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v); - zmm2.v = remove_NAN_512d(zmm2.v); + REMOVE_NAN_512D(zmm2.v); x += vector_length; zmm3.v = 
_mm512_loadu_pd(x); zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v); zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v); - zmm3.v = remove_NAN_512d(zmm3.v); + REMOVE_NAN_512D(zmm3.v); x += vector_length; /*Using sub function to generating the mask @@ -872,7 +861,7 @@ void bli_damaxv_zen_int_avx512( /* At the end of the loop max_array holds the largest element in each corresponding vector index */ - for (int i = 1; i < iterations; ++i) + for (dim_t i = 0; i < iterations; ++i) /* NOTE(review): besides the type, the start index changes from 1 to 0 here; confirm this is intended. */ { // Taking 32 elements // Taking only the absolute values of the registers @@ -880,7 +869,7 @@ void bli_damaxv_zen_int_avx512( // with negative numbers zmm0.v = _mm512_loadu_pd(x); zmm0.v = _mm512_abs_pd(zmm0.v); - zmm0.v = remove_NAN_512d(zmm0.v); + REMOVE_NAN_512D(zmm0.v); //Generating mask for the intermediate max vector mask_01 = _mm512_sub_pd(max_array.v, zmm0.v); @@ -968,6 +957,12 @@ void bli_damaxv_zen_int_avx512( } } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + // Return value *i_max = i_max_l;