Updating reduction step of AVX512 DNRM2 API

- Updated the final reduction of partial sums to use scalar accumulation
  entirely, instead of using the _mm512_reduce_add_pd( ... ) intrinsic.
  This will in turn change the associativity and the rounding-off
  pattern in the reduction step.

- Defined a union data-type to do the same, by having a 512-bit
  register and a double-precision array as its members.

- Updated the declaration and usage of the register variable according
  to the union definition, for uniformity.

AMD-Internal: [CPUPL-5472]
Change-Id: I997464a6ec47e4054dca48a000fbd4ac0cfcc679
This commit is contained in:
Vignesh Balasubramanian
2024-08-01 16:59:22 +05:30
parent 75f21182bd
commit 4ec2bad744

View File

@@ -34,6 +34,14 @@
#include "immintrin.h"
#include "blis.h"
// Union data structure to access AVX registers
// One 512-bit AVX register holds 8 DP elements.
typedef union
{
__m512d v;
double d[8] __attribute__( ( aligned( 64 ) ) );
} v8df_t;
/*
Optimized kernel that computes the Frobenius norm using AVX512 intrinsics.
The kernel takes in the following input parameters :
@@ -88,9 +96,9 @@ void bli_dnorm2fv_unb_var1_avx512
{
// AVX-512 code-section
// Declaring registers for loading, accumulation, thresholds and scale factors
__m512d x_vec[4], sum_sml_vec[4], sum_med_vec[4], sum_big_vec[4], temp[4];
__m512d thresh_sml_vec, thresh_big_vec, scale_sml_vec, scale_big_vec;
__m512d zero_reg;
v8df_t x_vec[4], sum_sml_vec[4], sum_med_vec[4], sum_big_vec[4], temp[4];
v8df_t thresh_sml_vec, thresh_big_vec, scale_sml_vec, scale_big_vec;
v8df_t zero_reg;
// Masks to be used in computation
__mmask8 k_mask[8];
@@ -101,55 +109,55 @@ void bli_dnorm2fv_unb_var1_avx512
unsigned char truth_val[4];
// Setting the thresholds and scaling factors
thresh_sml_vec = _mm512_set1_pd( thresh_sml );
thresh_big_vec = _mm512_set1_pd( thresh_big );
scale_sml_vec = _mm512_set1_pd( scale_sml );
scale_big_vec = _mm512_set1_pd( scale_big );
thresh_sml_vec.v = _mm512_set1_pd( thresh_sml );
thresh_big_vec.v = _mm512_set1_pd( thresh_big );
scale_sml_vec.v = _mm512_set1_pd( scale_sml );
scale_big_vec.v = _mm512_set1_pd( scale_big );
// Resetting the accumulators
sum_sml_vec[0] = _mm512_setzero_pd();
sum_sml_vec[1] = _mm512_setzero_pd();
sum_sml_vec[2] = _mm512_setzero_pd();
sum_sml_vec[3] = _mm512_setzero_pd();
sum_sml_vec[0].v = _mm512_setzero_pd();
sum_sml_vec[1].v = _mm512_setzero_pd();
sum_sml_vec[2].v = _mm512_setzero_pd();
sum_sml_vec[3].v = _mm512_setzero_pd();
sum_med_vec[0] = _mm512_setzero_pd();
sum_med_vec[1] = _mm512_setzero_pd();
sum_med_vec[2] = _mm512_setzero_pd();
sum_med_vec[3] = _mm512_setzero_pd();
sum_med_vec[0].v = _mm512_setzero_pd();
sum_med_vec[1].v = _mm512_setzero_pd();
sum_med_vec[2].v = _mm512_setzero_pd();
sum_med_vec[3].v = _mm512_setzero_pd();
sum_big_vec[0] = _mm512_setzero_pd();
sum_big_vec[1] = _mm512_setzero_pd();
sum_big_vec[2] = _mm512_setzero_pd();
sum_big_vec[3] = _mm512_setzero_pd();
sum_big_vec[0].v = _mm512_setzero_pd();
sum_big_vec[1].v = _mm512_setzero_pd();
sum_big_vec[2].v = _mm512_setzero_pd();
sum_big_vec[3].v = _mm512_setzero_pd();
zero_reg = _mm512_setzero_pd();
zero_reg.v = _mm512_setzero_pd();
// Computing in blocks of 32
for ( ; ( i + 32 ) <= n; i = i + 32 )
{
// Set temp[0..3] to zero
temp[0] = _mm512_setzero_pd();
temp[1] = _mm512_setzero_pd();
temp[2] = _mm512_setzero_pd();
temp[3] = _mm512_setzero_pd();
temp[0].v = _mm512_setzero_pd();
temp[1].v = _mm512_setzero_pd();
temp[2].v = _mm512_setzero_pd();
temp[3].v = _mm512_setzero_pd();
// Loading the vectors
x_vec[0] = _mm512_loadu_pd( xt );
x_vec[1] = _mm512_loadu_pd( xt + 8 );
x_vec[2] = _mm512_loadu_pd( xt + 16 );
x_vec[3] = _mm512_loadu_pd( xt + 24 );
x_vec[0].v = _mm512_loadu_pd( xt );
x_vec[1].v = _mm512_loadu_pd( xt + 8 );
x_vec[2].v = _mm512_loadu_pd( xt + 16 );
x_vec[3].v = _mm512_loadu_pd( xt + 24 );
// Comparing to check for NaN
// Bits in the mask are set if NaN is encountered
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], x_vec[1], _CMP_UNORD_Q );
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2], x_vec[2], _CMP_UNORD_Q );
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3], x_vec[3], _CMP_UNORD_Q );
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, x_vec[1].v, _CMP_UNORD_Q );
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2].v, x_vec[2].v, _CMP_UNORD_Q );
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3].v, x_vec[3].v, _CMP_UNORD_Q );
// Checking if any bit in the masks are set
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
// truth_val[1] = 0 if x_vec[2] or x_vec[3] has NaN
// Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN
// truth_val[1] = 0 if x_vec[2].v or x_vec[3].v has NaN
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[1] );
truth_val[1] = _kortestz_mask8_u8( k_mask[2], k_mask[3] );
@@ -163,30 +171,30 @@ void bli_dnorm2fv_unb_var1_avx512
}
// Getting the absoulte values of elements in the vectors
x_vec[0] = _mm512_abs_pd( x_vec[0] );
x_vec[1] = _mm512_abs_pd( x_vec[1] );
x_vec[2] = _mm512_abs_pd( x_vec[2] );
x_vec[3] = _mm512_abs_pd( x_vec[3] );
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
x_vec[1].v = _mm512_abs_pd( x_vec[1].v );
x_vec[2].v = _mm512_abs_pd( x_vec[2].v );
x_vec[3].v = _mm512_abs_pd( x_vec[3].v );
// Setting the masks by comparing with thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
// k_mask[1][i] = 1 if x_vec[1][i] > thresh_sml_vec
// k_mask[2][i] = 1 if x_vec[2][i] > thresh_sml_vec
// k_mask[3][i] = 1 if x_vec[3][i] > thresh_sml_vec
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], thresh_sml_vec, _CMP_GT_OS );
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2], thresh_sml_vec, _CMP_GT_OS );
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3], thresh_sml_vec, _CMP_GT_OS );
// Setting the masks by comparing with thresh_sml_vec.v
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
// k_mask[1][i] = 1 if x_vec[1].v[i] > thresh_sml_vec.v
// k_mask[2][i] = 1 if x_vec[2].v[i] > thresh_sml_vec.v
// k_mask[3][i] = 1 if x_vec[3].v[i] > thresh_sml_vec.v
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_sml_vec.v, _CMP_GT_OS );
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2].v, thresh_sml_vec.v, _CMP_GT_OS );
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3].v, thresh_sml_vec.v, _CMP_GT_OS );
// Setting the masks by comparing with thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
// k_mask[5][i] = 1 if x_vec[1][i] < thresh_big_vec
// k_mask[6][i] = 1 if x_vec[2][i] < thresh_big_vec
// k_mask[7][i] = 1 if x_vec[3][i] < thresh_big_vec
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1], thresh_big_vec, _CMP_LT_OS );
k_mask[6] = _mm512_cmp_pd_mask( x_vec[2], thresh_big_vec, _CMP_LT_OS );
k_mask[7] = _mm512_cmp_pd_mask( x_vec[3], thresh_big_vec, _CMP_LT_OS );
// Setting the masks by comparing with thresh_big_vec.v
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
// k_mask[5][i] = 1 if x_vec[1].v[i] < thresh_big_vec.v
// k_mask[6][i] = 1 if x_vec[2].v[i] < thresh_big_vec.v
// k_mask[7][i] = 1 if x_vec[3].v[i] < thresh_big_vec.v
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_big_vec.v, _CMP_LT_OS );
k_mask[6] = _mm512_cmp_pd_mask( x_vec[2].v, thresh_big_vec.v, _CMP_LT_OS );
k_mask[7] = _mm512_cmp_pd_mask( x_vec[3].v, thresh_big_vec.v, _CMP_LT_OS );
// Setting the masks to filter only the elements within the thresholds
// k_mask[0 ... 3] contain masks for elements > thresh_sml
@@ -200,10 +208,10 @@ void bli_dnorm2fv_unb_var1_avx512
// Setting booleans to check for underflow/overflow handling
// In case of having values outside threshold, the associated
// bit in k_mask[4 ... 7] is 0.
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
// truth_val[1] = 0 if x_vec[1] has elements outside thresholds
// truth_val[2] = 0 if x_vec[2] has elements outside thresholds
// truth_val[3] = 0 if x_vec[3] has elements outside thresholds
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
// truth_val[1] = 0 if x_vec[1].v has elements outside thresholds
// truth_val[2] = 0 if x_vec[2].v has elements outside thresholds
// truth_val[3] = 0 if x_vec[3].v has elements outside thresholds
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
truth_val[1] = _kortestc_mask8_u8( k_mask[5], k_mask[5] );
truth_val[2] = _kortestc_mask8_u8( k_mask[6], k_mask[6] );
@@ -211,10 +219,10 @@ void bli_dnorm2fv_unb_var1_avx512
// Computing using masked fmadds, that carries over values from
// accumulator register if the mask bit is 0
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
sum_med_vec[1] = _mm512_mask3_fmadd_pd( x_vec[1], x_vec[1], sum_med_vec[1], k_mask[5] );
sum_med_vec[2] = _mm512_mask3_fmadd_pd( x_vec[2], x_vec[2], sum_med_vec[2], k_mask[6] );
sum_med_vec[3] = _mm512_mask3_fmadd_pd( x_vec[3], x_vec[3], sum_med_vec[3], k_mask[7] );
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
sum_med_vec[1].v = _mm512_mask3_fmadd_pd( x_vec[1].v, x_vec[1].v, sum_med_vec[1].v, k_mask[5] );
sum_med_vec[2].v = _mm512_mask3_fmadd_pd( x_vec[2].v, x_vec[2].v, sum_med_vec[2].v, k_mask[6] );
sum_med_vec[3].v = _mm512_mask3_fmadd_pd( x_vec[3].v, x_vec[3].v, sum_med_vec[3].v, k_mask[7] );
// In case of having elements outside the threshold
if( !( truth_val[0] && truth_val[1] && truth_val[2] && truth_val[3] ) )
@@ -224,20 +232,20 @@ void bli_dnorm2fv_unb_var1_avx512
// k_mask[0 ... 3] contain masks for elements > thresh_sml. This would
// include both elements < thresh_big and >= thresh_big
// XOR on these will produce masks for elements >= thresh_big
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
// k_mask[6][i] = 1 if x_vec[2][i] >= thresh_big_vec
// k_mask[7][i] = 1 if x_vec[3][i] >= thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
// k_mask[6][i] = 1 if x_vec[2].v[i] >= thresh_big_vec.v
// k_mask[7][i] = 1 if x_vec[3].v[i] >= thresh_big_vec.v
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
k_mask[5] = _kxor_mask8( k_mask[1], k_mask[5] );
k_mask[6] = _kxor_mask8( k_mask[2], k_mask[6] );
k_mask[7] = _kxor_mask8( k_mask[3], k_mask[7] );
// Inverting k_mask[0 ... 3], to obtain masks for elements <= thresh_sml
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
// k_mask[2][i] = 1 if x_vec[2][i] <= thresh_sml_vec
// k_mask[3][i] = 1 if x_vec[3][i] <= thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
// k_mask[2][i] = 1 if x_vec[2].v[i] <= thresh_sml_vec.v
// k_mask[3][i] = 1 if x_vec[3].v[i] <= thresh_sml_vec.v
k_mask[0] = _knot_mask8( k_mask[0] );
k_mask[1] = _knot_mask8( k_mask[1] );
k_mask[2] = _knot_mask8( k_mask[2] );
@@ -245,8 +253,8 @@ void bli_dnorm2fv_unb_var1_avx512
// Checking whether we have values greater than thresh_big
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
// truth_val[3] = 0 if x_vec[2] or x_vec[3] has elements >= thresh_big_vec
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
// truth_val[3] = 0 if x_vec[2].v or x_vec[3].v has elements >= thresh_big_vec.v
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[5] );
truth_val[3] = _kortestz_mask8_u8( k_mask[6], k_mask[7] );
@@ -261,16 +269,16 @@ void bli_dnorm2fv_unb_var1_avx512
// are greater than thresh_big
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[5], scale_big_vec, x_vec[1] );
temp[2] = _mm512_mask_mul_pd( zero_reg, k_mask[6], scale_big_vec, x_vec[2] );
temp[3] = _mm512_mask_mul_pd( zero_reg, k_mask[7], scale_big_vec, x_vec[3] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[5], scale_big_vec.v, x_vec[1].v );
temp[2].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[6], scale_big_vec.v, x_vec[2].v );
temp[3].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[7], scale_big_vec.v, x_vec[3].v );
// Square and add the elements to the accumulators
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
sum_big_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_big_vec[1] );
sum_big_vec[2] = _mm512_fmadd_pd( temp[2], temp[2], sum_big_vec[2] );
sum_big_vec[3] = _mm512_fmadd_pd( temp[3], temp[3], sum_big_vec[3] );
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
sum_big_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_big_vec[1].v );
sum_big_vec[2].v = _mm512_fmadd_pd( temp[2].v, temp[2].v, sum_big_vec[2].v );
sum_big_vec[3].v = _mm512_fmadd_pd( temp[3].v, temp[3].v, sum_big_vec[3].v );
}
else if( !isbig )
{
@@ -279,16 +287,16 @@ void bli_dnorm2fv_unb_var1_avx512
// are lesser than thresh_sml, if needed
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[1], scale_sml_vec, x_vec[1] );
temp[2] = _mm512_mask_mul_pd( zero_reg, k_mask[2], scale_sml_vec, x_vec[2] );
temp[3] = _mm512_mask_mul_pd( zero_reg, k_mask[3], scale_sml_vec, x_vec[3] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[1], scale_sml_vec.v, x_vec[1].v );
temp[2].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[2], scale_sml_vec.v, x_vec[2].v );
temp[3].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[3], scale_sml_vec.v, x_vec[3].v );
// Square and add the elements to the accumulators
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
sum_sml_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_sml_vec[1] );
sum_sml_vec[2] = _mm512_fmadd_pd( temp[2], temp[2], sum_sml_vec[2] );
sum_sml_vec[3] = _mm512_fmadd_pd( temp[3], temp[3], sum_sml_vec[3] );
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
sum_sml_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_sml_vec[1].v );
sum_sml_vec[2].v = _mm512_fmadd_pd( temp[2].v, temp[2].v, sum_sml_vec[2].v );
sum_sml_vec[3].v = _mm512_fmadd_pd( temp[3].v, temp[3].v, sum_sml_vec[3].v );
}
}
@@ -300,21 +308,21 @@ void bli_dnorm2fv_unb_var1_avx512
for ( ; ( i + 16 ) <= n; i = i + 16 )
{
// Set temp[0..1] to zero
temp[0] = _mm512_setzero_pd();
temp[1] = _mm512_setzero_pd();
temp[0].v = _mm512_setzero_pd();
temp[1].v = _mm512_setzero_pd();
// Loading the vectors
x_vec[0] = _mm512_loadu_pd( xt );
x_vec[1] = _mm512_loadu_pd( xt + 8 );
x_vec[0].v = _mm512_loadu_pd( xt );
x_vec[1].v = _mm512_loadu_pd( xt + 8 );
// Comparing to check for NaN
// Bits in the mask are set if NaN is encountered
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], x_vec[1], _CMP_UNORD_Q );
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, x_vec[1].v, _CMP_UNORD_Q );
// Checking if any bit in the masks are set
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
// Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[1] );
// Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0
@@ -327,20 +335,20 @@ void bli_dnorm2fv_unb_var1_avx512
}
// Getting the absoulte values of elements in the vectors
x_vec[0] = _mm512_abs_pd( x_vec[0] );
x_vec[1] = _mm512_abs_pd( x_vec[1] );
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
x_vec[1].v = _mm512_abs_pd( x_vec[1].v );
// Setting the masks by comparing with thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
// k_mask[1][i] = 1 if x_vec[1][i] > thresh_sml_vec
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], thresh_sml_vec, _CMP_GT_OS );
// Setting the masks by comparing with thresh_sml_vec.v
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
// k_mask[1][i] = 1 if x_vec[1].v[i] > thresh_sml_vec.v
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_sml_vec.v, _CMP_GT_OS );
// Setting the masks by comparing with thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
// k_mask[5][i] = 1 if x_vec[1][i] < thresh_big_vec
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1], thresh_big_vec, _CMP_LT_OS );
// Setting the masks by comparing with thresh_big_vec.v
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
// k_mask[5][i] = 1 if x_vec[1].v[i] < thresh_big_vec.v
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_big_vec.v, _CMP_LT_OS );
// Setting the masks to filter only the elements within the thresholds
// k_mask[0 ... 1] contain masks for elements > thresh_sml
@@ -352,15 +360,15 @@ void bli_dnorm2fv_unb_var1_avx512
// Setting booleans to check for underflow/overflow handling
// In case of having values outside threshold, the associated
// bit in k_mask[4 ... 7] is 0.
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
// truth_val[1] = 0 if x_vec[1] has elements outside thresholds
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
// truth_val[1] = 0 if x_vec[1].v has elements outside thresholds
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
truth_val[1] = _kortestc_mask8_u8( k_mask[5], k_mask[5] );
// Computing using masked fmadds, that carries over values from
// accumulator register if the mask bit is 0
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
sum_med_vec[1] = _mm512_mask3_fmadd_pd( x_vec[1], x_vec[1], sum_med_vec[1], k_mask[5] );
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
sum_med_vec[1].v = _mm512_mask3_fmadd_pd( x_vec[1].v, x_vec[1].v, sum_med_vec[1].v, k_mask[5] );
// In case of having elements outside the threshold
if( !( truth_val[0] && truth_val[1] ) )
@@ -370,20 +378,20 @@ void bli_dnorm2fv_unb_var1_avx512
// k_mask[0 ... 1] contain masks for elements > thresh_sml. This would
// include both elements < thresh_big and >= thresh_big
// XOR on these will produce masks for elements >= thresh_big
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
k_mask[5] = _kxor_mask8( k_mask[1], k_mask[5] );
// Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
k_mask[0] = _knot_mask8( k_mask[0] );
k_mask[1] = _knot_mask8( k_mask[1] );
// Checking whether we have values greater than thresh_big
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[5] );
// In case of having values greater than thresh_big
@@ -397,12 +405,12 @@ void bli_dnorm2fv_unb_var1_avx512
// are greater than thresh_big
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[5], scale_big_vec, x_vec[1] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[5], scale_big_vec.v, x_vec[1].v );
// Square and add the elements to the accumulators
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
sum_big_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_big_vec[1] );
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
sum_big_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_big_vec[1].v );
}
else if( !isbig )
{
@@ -411,12 +419,12 @@ void bli_dnorm2fv_unb_var1_avx512
// are lesser than thresh_sml, if needed
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[1], scale_sml_vec, x_vec[1] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[1], scale_sml_vec.v, x_vec[1].v );
// Square and add the elements to the accumulators
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
sum_sml_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_sml_vec[1] );
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
sum_sml_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_sml_vec[1].v );
}
}
@@ -425,19 +433,19 @@ void bli_dnorm2fv_unb_var1_avx512
}
for ( ; ( i + 8 ) <= n; i = i + 8 )
{
// Set temp[0] to zero
temp[0] = _mm512_setzero_pd();
// Set temp[0].v to zero
temp[0].v = _mm512_setzero_pd();
// Loading the vectors
x_vec[0] = _mm512_loadu_pd( xt );
x_vec[0].v = _mm512_loadu_pd( xt );
// Comparing to check for NaN
// Bits in the mask are set if NaN is encountered
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
// Checking if any bit in the masks are set
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
// Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[0] );
// Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0
@@ -450,15 +458,15 @@ void bli_dnorm2fv_unb_var1_avx512
}
// Getting the absoulte values of elements in the vectors
x_vec[0] = _mm512_abs_pd( x_vec[0] );
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
// Setting the masks by comparing with thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
// Setting the masks by comparing with thresh_sml_vec.v
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
// Setting the masks by comparing with thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
// Setting the masks by comparing with thresh_big_vec.v
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
// Setting the masks to filter only the elements within the thresholds
// k_mask[0] contain masks for elements > thresh_sml
@@ -469,12 +477,12 @@ void bli_dnorm2fv_unb_var1_avx512
// Setting booleans to check for underflow/overflow handling
// In case of having values outside threshold, the associated
// bit in k_mask[4] is 0.
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
// Computing using masked fmadds, that carries over values from
// accumulator register if the mask bit is 0
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
// In case of having elements outside the threshold
if( !truth_val[0] )
@@ -484,18 +492,18 @@ void bli_dnorm2fv_unb_var1_avx512
// k_mask[0 ... 1] contain masks for elements > thresh_sml. This would
// include both elements < thresh_big and >= thresh_big
// XOR on these will produce masks for elements >= thresh_big
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
// Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
k_mask[0] = _knot_mask8( k_mask[0] );
// Checking whether we have values greater than thresh_big
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[4] );
// In case of having values greater than thresh_big
@@ -509,10 +517,10 @@ void bli_dnorm2fv_unb_var1_avx512
// are greater than thresh_big
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
// Square and add the elements to the accumulators
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
}
else if( !isbig )
{
@@ -521,10 +529,10 @@ void bli_dnorm2fv_unb_var1_avx512
// are lesser than thresh_sml, if needed
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
// Square and add the elements to the accumulators
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
}
}
@@ -533,22 +541,22 @@ void bli_dnorm2fv_unb_var1_avx512
}
if( i < n )
{
// Set temp[0] to zero
temp[0] = _mm512_setzero_pd();
// Set temp[0].v to zero
temp[0].v = _mm512_setzero_pd();
// Setting the mask to load
k_mask[0] = ( 1 << ( n - i ) ) - 1;
// Loading the vectors
x_vec[0] = _mm512_maskz_loadu_pd( k_mask[0], xt );
x_vec[0].v = _mm512_maskz_loadu_pd( k_mask[0], xt );
// Comparing to check for NaN
// Bits in the mask are set if NaN is encountered
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
// Checking if any bit in the masks are set
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
// Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[0] );
// Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0
@@ -561,15 +569,15 @@ void bli_dnorm2fv_unb_var1_avx512
}
// Getting the absoulte values of elements in the vectors
x_vec[0] = _mm512_abs_pd( x_vec[0] );
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
// Setting the masks by comparing with thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
// Setting the masks by comparing with thresh_sml_vec.v
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
// Setting the masks by comparing with thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
// Setting the masks by comparing with thresh_big_vec.v
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
// Setting the masks to filter only the elements within the thresholds
// k_mask[0] contain masks for elements > thresh_sml
@@ -580,12 +588,12 @@ void bli_dnorm2fv_unb_var1_avx512
// Setting booleans to check for underflow/overflow handling
// In case of having values outside threshold, the associated
// bit in k_mask[4] is 0.
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
// Computing using masked fmadds, that carries over values from
// accumulator register if the mask bit is 0
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
// In case of having elements outside the threshold
if( !truth_val[0] )
@@ -595,18 +603,18 @@ void bli_dnorm2fv_unb_var1_avx512
// k_mask[0 ... 1] contain masks for elements > thresh_sml. This would
// include both elements < thresh_big and >= thresh_big
// XOR on these will produce masks for elements >= thresh_big
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
// Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
k_mask[0] = _knot_mask8( k_mask[0] );
// Checking whether we have values greater than thresh_big
// The truth_val is set to 0 if any bit in the mask is 1
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[4] );
// In case of having values greater than thresh_big
@@ -620,10 +628,10 @@ void bli_dnorm2fv_unb_var1_avx512
// are greater than thresh_big
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
// Square and add the elements to the accumulators
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
}
else if( !isbig )
{
@@ -632,32 +640,35 @@ void bli_dnorm2fv_unb_var1_avx512
// are lesser than thresh_sml, if needed
// Scale the required elements in x_vec[0..3] by scale_smal
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
// Square and add the elements to the accumulators
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
}
}
}
// Reduction step
// Combining the results of accumulators for each category
sum_med_vec[0] = _mm512_add_pd( sum_med_vec[0], sum_med_vec[1] );
sum_med_vec[2] = _mm512_add_pd( sum_med_vec[2], sum_med_vec[3] );
sum_med_vec[0] = _mm512_add_pd( sum_med_vec[0], sum_med_vec[2] );
sum_med_vec[0].v = _mm512_add_pd( sum_med_vec[0].v, sum_med_vec[1].v );
sum_med_vec[2].v = _mm512_add_pd( sum_med_vec[2].v, sum_med_vec[3].v );
sum_med_vec[0].v = _mm512_add_pd( sum_med_vec[0].v, sum_med_vec[2].v );
sum_big_vec[0] = _mm512_add_pd( sum_big_vec[0], sum_big_vec[1] );
sum_big_vec[2] = _mm512_add_pd( sum_big_vec[2], sum_big_vec[3] );
sum_big_vec[0] = _mm512_add_pd( sum_big_vec[0], sum_big_vec[2] );
sum_big_vec[0].v = _mm512_add_pd( sum_big_vec[0].v, sum_big_vec[1].v );
sum_big_vec[2].v = _mm512_add_pd( sum_big_vec[2].v, sum_big_vec[3].v );
sum_big_vec[0].v = _mm512_add_pd( sum_big_vec[0].v, sum_big_vec[2].v );
sum_sml_vec[0] = _mm512_add_pd( sum_sml_vec[0], sum_sml_vec[1] );
sum_sml_vec[2] = _mm512_add_pd( sum_sml_vec[2], sum_sml_vec[3] );
sum_sml_vec[0] = _mm512_add_pd( sum_sml_vec[0], sum_sml_vec[2] );
sum_sml_vec[0].v = _mm512_add_pd( sum_sml_vec[0].v, sum_sml_vec[1].v );
sum_sml_vec[2].v = _mm512_add_pd( sum_sml_vec[2].v, sum_sml_vec[3].v );
sum_sml_vec[0].v = _mm512_add_pd( sum_sml_vec[0].v, sum_sml_vec[2].v );
// Final accumulation on the scalars
sum_sml += _mm512_reduce_add_pd( sum_sml_vec[0] );
sum_med += _mm512_reduce_add_pd( sum_med_vec[0] );
sum_big += _mm512_reduce_add_pd( sum_big_vec[0] );
sum_sml += sum_sml_vec[0].d[0] + sum_sml_vec[0].d[1] + sum_sml_vec[0].d[2] + sum_sml_vec[0].d[3]
+ sum_sml_vec[0].d[4] + sum_sml_vec[0].d[5] + sum_sml_vec[0].d[6] + sum_sml_vec[0].d[7];
sum_med += sum_med_vec[0].d[0] + sum_med_vec[0].d[1] + sum_med_vec[0].d[2] + sum_med_vec[0].d[3]
+ sum_med_vec[0].d[4] + sum_med_vec[0].d[5] + sum_med_vec[0].d[6] + sum_med_vec[0].d[7];
sum_big += sum_big_vec[0].d[0] + sum_big_vec[0].d[1] + sum_big_vec[0].d[2] + sum_big_vec[0].d[3]
+ sum_big_vec[0].d[4] + sum_big_vec[0].d[5] + sum_big_vec[0].d[6] + sum_big_vec[0].d[7];
}
// Dealing with non-unit strided inputs
else