mirror of
https://github.com/amd/blis.git
synced 2026-05-24 10:24:34 +00:00
Updating reduction step of AVX512 DNRM2 API
- Updated the final reduction of partial sums to use scalar accumulation entirely, instead of using the _mm512_reduce_add_pd( ... ) intrinsic. This will in turn change the associativity and the rounding-off pattern in the reduction step. - Defined a union data-type to do the same, by having a 512-bit register and a double-precision array as its members. - Updated the declaration and usage of the register variable according to the union definition, for uniformity. AMD-Internal: [CPUPL-5472] Change-Id: I997464a6ec47e4054dca48a000fbd4ac0cfcc679
This commit is contained in:
@@ -34,6 +34,14 @@
|
||||
#include "immintrin.h"
|
||||
#include "blis.h"
|
||||
|
||||
// Union data structure to access AVX registers
|
||||
// One 512-bit AVX register holds 8 DP elements.
|
||||
// Overlays a single 512-bit vector with its eight double-precision lanes,
// so the same storage can be handed to the _mm512_* intrinsics (through v)
// or read one element at a time as plain scalars (through d[0..7]).
typedef union
|
||||
{
|
||||
// Vector view: the operand/result type of the AVX-512 intrinsics.
__m512d v;
|
||||
// Scalar view: the same 8 doubles, declared with 64-byte alignment.
double d[8] __attribute__( ( aligned( 64 ) ) );
|
||||
} v8df_t;
|
||||
|
||||
/*
|
||||
Optimized kernel that computes the Frobenius norm using AVX512 intrinsics.
|
||||
The kernel takes in the following input parameters :
|
||||
@@ -88,9 +96,9 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
{
|
||||
// AVX-512 code-section
|
||||
// Declaring registers for loading, accumulation, thresholds and scale factors
|
||||
__m512d x_vec[4], sum_sml_vec[4], sum_med_vec[4], sum_big_vec[4], temp[4];
|
||||
__m512d thresh_sml_vec, thresh_big_vec, scale_sml_vec, scale_big_vec;
|
||||
__m512d zero_reg;
|
||||
v8df_t x_vec[4], sum_sml_vec[4], sum_med_vec[4], sum_big_vec[4], temp[4];
|
||||
v8df_t thresh_sml_vec, thresh_big_vec, scale_sml_vec, scale_big_vec;
|
||||
v8df_t zero_reg;
|
||||
|
||||
// Masks to be used in computation
|
||||
__mmask8 k_mask[8];
|
||||
@@ -101,55 +109,55 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
unsigned char truth_val[4];
|
||||
|
||||
// Setting the thresholds and scaling factors
|
||||
thresh_sml_vec = _mm512_set1_pd( thresh_sml );
|
||||
thresh_big_vec = _mm512_set1_pd( thresh_big );
|
||||
scale_sml_vec = _mm512_set1_pd( scale_sml );
|
||||
scale_big_vec = _mm512_set1_pd( scale_big );
|
||||
thresh_sml_vec.v = _mm512_set1_pd( thresh_sml );
|
||||
thresh_big_vec.v = _mm512_set1_pd( thresh_big );
|
||||
scale_sml_vec.v = _mm512_set1_pd( scale_sml );
|
||||
scale_big_vec.v = _mm512_set1_pd( scale_big );
|
||||
|
||||
// Resetting the accumulators
|
||||
sum_sml_vec[0] = _mm512_setzero_pd();
|
||||
sum_sml_vec[1] = _mm512_setzero_pd();
|
||||
sum_sml_vec[2] = _mm512_setzero_pd();
|
||||
sum_sml_vec[3] = _mm512_setzero_pd();
|
||||
sum_sml_vec[0].v = _mm512_setzero_pd();
|
||||
sum_sml_vec[1].v = _mm512_setzero_pd();
|
||||
sum_sml_vec[2].v = _mm512_setzero_pd();
|
||||
sum_sml_vec[3].v = _mm512_setzero_pd();
|
||||
|
||||
sum_med_vec[0] = _mm512_setzero_pd();
|
||||
sum_med_vec[1] = _mm512_setzero_pd();
|
||||
sum_med_vec[2] = _mm512_setzero_pd();
|
||||
sum_med_vec[3] = _mm512_setzero_pd();
|
||||
sum_med_vec[0].v = _mm512_setzero_pd();
|
||||
sum_med_vec[1].v = _mm512_setzero_pd();
|
||||
sum_med_vec[2].v = _mm512_setzero_pd();
|
||||
sum_med_vec[3].v = _mm512_setzero_pd();
|
||||
|
||||
sum_big_vec[0] = _mm512_setzero_pd();
|
||||
sum_big_vec[1] = _mm512_setzero_pd();
|
||||
sum_big_vec[2] = _mm512_setzero_pd();
|
||||
sum_big_vec[3] = _mm512_setzero_pd();
|
||||
sum_big_vec[0].v = _mm512_setzero_pd();
|
||||
sum_big_vec[1].v = _mm512_setzero_pd();
|
||||
sum_big_vec[2].v = _mm512_setzero_pd();
|
||||
sum_big_vec[3].v = _mm512_setzero_pd();
|
||||
|
||||
zero_reg = _mm512_setzero_pd();
|
||||
zero_reg.v = _mm512_setzero_pd();
|
||||
|
||||
// Computing in blocks of 32
|
||||
for ( ; ( i + 32 ) <= n; i = i + 32 )
|
||||
{
|
||||
// Set temp[0..3] to zero
|
||||
temp[0] = _mm512_setzero_pd();
|
||||
temp[1] = _mm512_setzero_pd();
|
||||
temp[2] = _mm512_setzero_pd();
|
||||
temp[3] = _mm512_setzero_pd();
|
||||
temp[0].v = _mm512_setzero_pd();
|
||||
temp[1].v = _mm512_setzero_pd();
|
||||
temp[2].v = _mm512_setzero_pd();
|
||||
temp[3].v = _mm512_setzero_pd();
|
||||
|
||||
// Loading the vectors
|
||||
x_vec[0] = _mm512_loadu_pd( xt );
|
||||
x_vec[1] = _mm512_loadu_pd( xt + 8 );
|
||||
x_vec[2] = _mm512_loadu_pd( xt + 16 );
|
||||
x_vec[3] = _mm512_loadu_pd( xt + 24 );
|
||||
x_vec[0].v = _mm512_loadu_pd( xt );
|
||||
x_vec[1].v = _mm512_loadu_pd( xt + 8 );
|
||||
x_vec[2].v = _mm512_loadu_pd( xt + 16 );
|
||||
x_vec[3].v = _mm512_loadu_pd( xt + 24 );
|
||||
|
||||
// Comparing to check for NaN
|
||||
// Bits in the mask are set if NaN is encountered
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], x_vec[1], _CMP_UNORD_Q );
|
||||
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2], x_vec[2], _CMP_UNORD_Q );
|
||||
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3], x_vec[3], _CMP_UNORD_Q );
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, x_vec[1].v, _CMP_UNORD_Q );
|
||||
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2].v, x_vec[2].v, _CMP_UNORD_Q );
|
||||
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3].v, x_vec[3].v, _CMP_UNORD_Q );
|
||||
|
||||
// Checking if any bit in the masks are set
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
|
||||
// truth_val[1] = 0 if x_vec[2] or x_vec[3] has NaN
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN
|
||||
// truth_val[1] = 0 if x_vec[2].v or x_vec[3].v has NaN
|
||||
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[1] );
|
||||
truth_val[1] = _kortestz_mask8_u8( k_mask[2], k_mask[3] );
|
||||
|
||||
@@ -163,30 +171,30 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
}
|
||||
|
||||
// Getting the absolute values of elements in the vectors
|
||||
x_vec[0] = _mm512_abs_pd( x_vec[0] );
|
||||
x_vec[1] = _mm512_abs_pd( x_vec[1] );
|
||||
x_vec[2] = _mm512_abs_pd( x_vec[2] );
|
||||
x_vec[3] = _mm512_abs_pd( x_vec[3] );
|
||||
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
|
||||
x_vec[1].v = _mm512_abs_pd( x_vec[1].v );
|
||||
x_vec[2].v = _mm512_abs_pd( x_vec[2].v );
|
||||
x_vec[3].v = _mm512_abs_pd( x_vec[3].v );
|
||||
|
||||
// Setting the masks by comparing with thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
|
||||
// k_mask[1][i] = 1 if x_vec[1][i] > thresh_sml_vec
|
||||
// k_mask[2][i] = 1 if x_vec[2][i] > thresh_sml_vec
|
||||
// k_mask[3][i] = 1 if x_vec[3][i] > thresh_sml_vec
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], thresh_sml_vec, _CMP_GT_OS );
|
||||
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2], thresh_sml_vec, _CMP_GT_OS );
|
||||
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3], thresh_sml_vec, _CMP_GT_OS );
|
||||
// Setting the masks by comparing with thresh_sml_vec.v
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
|
||||
// k_mask[1][i] = 1 if x_vec[1].v[i] > thresh_sml_vec.v
|
||||
// k_mask[2][i] = 1 if x_vec[2].v[i] > thresh_sml_vec.v
|
||||
// k_mask[3][i] = 1 if x_vec[3].v[i] > thresh_sml_vec.v
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
k_mask[2] = _mm512_cmp_pd_mask( x_vec[2].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
k_mask[3] = _mm512_cmp_pd_mask( x_vec[3].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
|
||||
// Setting the masks by comparing with thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
|
||||
// k_mask[5][i] = 1 if x_vec[1][i] < thresh_big_vec
|
||||
// k_mask[6][i] = 1 if x_vec[2][i] < thresh_big_vec
|
||||
// k_mask[7][i] = 1 if x_vec[3][i] < thresh_big_vec
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
|
||||
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1], thresh_big_vec, _CMP_LT_OS );
|
||||
k_mask[6] = _mm512_cmp_pd_mask( x_vec[2], thresh_big_vec, _CMP_LT_OS );
|
||||
k_mask[7] = _mm512_cmp_pd_mask( x_vec[3], thresh_big_vec, _CMP_LT_OS );
|
||||
// Setting the masks by comparing with thresh_big_vec.v
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
|
||||
// k_mask[5][i] = 1 if x_vec[1].v[i] < thresh_big_vec.v
|
||||
// k_mask[6][i] = 1 if x_vec[2].v[i] < thresh_big_vec.v
|
||||
// k_mask[7][i] = 1 if x_vec[3].v[i] < thresh_big_vec.v
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
k_mask[6] = _mm512_cmp_pd_mask( x_vec[2].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
k_mask[7] = _mm512_cmp_pd_mask( x_vec[3].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
|
||||
// Setting the masks to filter only the elements within the thresholds
|
||||
// k_mask[0 ... 3] contain masks for elements > thresh_sml
|
||||
@@ -200,10 +208,10 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// Setting booleans to check for underflow/overflow handling
|
||||
// In case of having values outside threshold, the associated
|
||||
// bit in k_mask[4 ... 7] is 0.
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
|
||||
// truth_val[1] = 0 if x_vec[1] has elements outside thresholds
|
||||
// truth_val[2] = 0 if x_vec[2] has elements outside thresholds
|
||||
// truth_val[3] = 0 if x_vec[3] has elements outside thresholds
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
|
||||
// truth_val[1] = 0 if x_vec[1].v has elements outside thresholds
|
||||
// truth_val[2] = 0 if x_vec[2].v has elements outside thresholds
|
||||
// truth_val[3] = 0 if x_vec[3].v has elements outside thresholds
|
||||
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
|
||||
truth_val[1] = _kortestc_mask8_u8( k_mask[5], k_mask[5] );
|
||||
truth_val[2] = _kortestc_mask8_u8( k_mask[6], k_mask[6] );
|
||||
@@ -211,10 +219,10 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
|
||||
// Computing using masked fmadds, that carries over values from
|
||||
// accumulator register if the mask bit is 0
|
||||
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
|
||||
sum_med_vec[1] = _mm512_mask3_fmadd_pd( x_vec[1], x_vec[1], sum_med_vec[1], k_mask[5] );
|
||||
sum_med_vec[2] = _mm512_mask3_fmadd_pd( x_vec[2], x_vec[2], sum_med_vec[2], k_mask[6] );
|
||||
sum_med_vec[3] = _mm512_mask3_fmadd_pd( x_vec[3], x_vec[3], sum_med_vec[3], k_mask[7] );
|
||||
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
|
||||
sum_med_vec[1].v = _mm512_mask3_fmadd_pd( x_vec[1].v, x_vec[1].v, sum_med_vec[1].v, k_mask[5] );
|
||||
sum_med_vec[2].v = _mm512_mask3_fmadd_pd( x_vec[2].v, x_vec[2].v, sum_med_vec[2].v, k_mask[6] );
|
||||
sum_med_vec[3].v = _mm512_mask3_fmadd_pd( x_vec[3].v, x_vec[3].v, sum_med_vec[3].v, k_mask[7] );
|
||||
|
||||
// In case of having elements outside the threshold
|
||||
if( !( truth_val[0] && truth_val[1] && truth_val[2] && truth_val[3] ) )
|
||||
@@ -224,20 +232,20 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// k_mask[0 ... 3] contain masks for elements > thresh_sml. This would
|
||||
// include both elements < thresh_big and >= thresh_big
|
||||
// XOR on these will produce masks for elements >= thresh_big
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
|
||||
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
|
||||
// k_mask[6][i] = 1 if x_vec[2][i] >= thresh_big_vec
|
||||
// k_mask[7][i] = 1 if x_vec[3][i] >= thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
|
||||
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
|
||||
// k_mask[6][i] = 1 if x_vec[2].v[i] >= thresh_big_vec.v
|
||||
// k_mask[7][i] = 1 if x_vec[3].v[i] >= thresh_big_vec.v
|
||||
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
|
||||
k_mask[5] = _kxor_mask8( k_mask[1], k_mask[5] );
|
||||
k_mask[6] = _kxor_mask8( k_mask[2], k_mask[6] );
|
||||
k_mask[7] = _kxor_mask8( k_mask[3], k_mask[7] );
|
||||
|
||||
// Inverting k_mask[0 ... 3], to obtain masks for elements <= thresh_sml
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
|
||||
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
|
||||
// k_mask[2][i] = 1 if x_vec[2][i] <= thresh_sml_vec
|
||||
// k_mask[3][i] = 1 if x_vec[3][i] <= thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
|
||||
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
|
||||
// k_mask[2][i] = 1 if x_vec[2].v[i] <= thresh_sml_vec.v
|
||||
// k_mask[3][i] = 1 if x_vec[3].v[i] <= thresh_sml_vec.v
|
||||
k_mask[0] = _knot_mask8( k_mask[0] );
|
||||
k_mask[1] = _knot_mask8( k_mask[1] );
|
||||
k_mask[2] = _knot_mask8( k_mask[2] );
|
||||
@@ -245,8 +253,8 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
|
||||
// Checking whether we have values greater than thresh_big
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
|
||||
// truth_val[3] = 0 if x_vec[2] or x_vec[3] has elements >= thresh_big_vec
|
||||
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
|
||||
// truth_val[3] = 0 if x_vec[2].v or x_vec[3].v has elements >= thresh_big_vec.v
|
||||
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[5] );
|
||||
truth_val[3] = _kortestz_mask8_u8( k_mask[6], k_mask[7] );
|
||||
|
||||
@@ -261,16 +269,16 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are greater than thresh_big
|
||||
|
||||
// Scale the required elements in x_vec[0..3] by scale_big
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
|
||||
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[5], scale_big_vec, x_vec[1] );
|
||||
temp[2] = _mm512_mask_mul_pd( zero_reg, k_mask[6], scale_big_vec, x_vec[2] );
|
||||
temp[3] = _mm512_mask_mul_pd( zero_reg, k_mask[7], scale_big_vec, x_vec[3] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
|
||||
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[5], scale_big_vec.v, x_vec[1].v );
|
||||
temp[2].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[6], scale_big_vec.v, x_vec[2].v );
|
||||
temp[3].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[7], scale_big_vec.v, x_vec[3].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
|
||||
sum_big_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_big_vec[1] );
|
||||
sum_big_vec[2] = _mm512_fmadd_pd( temp[2], temp[2], sum_big_vec[2] );
|
||||
sum_big_vec[3] = _mm512_fmadd_pd( temp[3], temp[3], sum_big_vec[3] );
|
||||
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
|
||||
sum_big_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_big_vec[1].v );
|
||||
sum_big_vec[2].v = _mm512_fmadd_pd( temp[2].v, temp[2].v, sum_big_vec[2].v );
|
||||
sum_big_vec[3].v = _mm512_fmadd_pd( temp[3].v, temp[3].v, sum_big_vec[3].v );
|
||||
}
|
||||
else if( !isbig )
|
||||
{
|
||||
@@ -279,16 +287,16 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are lesser than thresh_sml, if needed
|
||||
|
||||
// Scale the required elements in x_vec[0..3] by scale_sml
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
|
||||
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[1], scale_sml_vec, x_vec[1] );
|
||||
temp[2] = _mm512_mask_mul_pd( zero_reg, k_mask[2], scale_sml_vec, x_vec[2] );
|
||||
temp[3] = _mm512_mask_mul_pd( zero_reg, k_mask[3], scale_sml_vec, x_vec[3] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
|
||||
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[1], scale_sml_vec.v, x_vec[1].v );
|
||||
temp[2].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[2], scale_sml_vec.v, x_vec[2].v );
|
||||
temp[3].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[3], scale_sml_vec.v, x_vec[3].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
|
||||
sum_sml_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_sml_vec[1] );
|
||||
sum_sml_vec[2] = _mm512_fmadd_pd( temp[2], temp[2], sum_sml_vec[2] );
|
||||
sum_sml_vec[3] = _mm512_fmadd_pd( temp[3], temp[3], sum_sml_vec[3] );
|
||||
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
|
||||
sum_sml_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_sml_vec[1].v );
|
||||
sum_sml_vec[2].v = _mm512_fmadd_pd( temp[2].v, temp[2].v, sum_sml_vec[2].v );
|
||||
sum_sml_vec[3].v = _mm512_fmadd_pd( temp[3].v, temp[3].v, sum_sml_vec[3].v );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,21 +308,21 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
for ( ; ( i + 16 ) <= n; i = i + 16 )
|
||||
{
|
||||
// Set temp[0..1] to zero
|
||||
temp[0] = _mm512_setzero_pd();
|
||||
temp[1] = _mm512_setzero_pd();
|
||||
temp[0].v = _mm512_setzero_pd();
|
||||
temp[1].v = _mm512_setzero_pd();
|
||||
|
||||
// Loading the vectors
|
||||
x_vec[0] = _mm512_loadu_pd( xt );
|
||||
x_vec[1] = _mm512_loadu_pd( xt + 8 );
|
||||
x_vec[0].v = _mm512_loadu_pd( xt );
|
||||
x_vec[1].v = _mm512_loadu_pd( xt + 8 );
|
||||
|
||||
// Comparing to check for NaN
|
||||
// Bits in the mask are set if NaN is encountered
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], x_vec[1], _CMP_UNORD_Q );
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, x_vec[1].v, _CMP_UNORD_Q );
|
||||
|
||||
// Checking if any bit in the masks are set
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v or x_vec[1].v has NaN
|
||||
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[1] );
|
||||
|
||||
// Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0
|
||||
@@ -327,20 +335,20 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
}
|
||||
|
||||
// Getting the absolute values of elements in the vectors
|
||||
x_vec[0] = _mm512_abs_pd( x_vec[0] );
|
||||
x_vec[1] = _mm512_abs_pd( x_vec[1] );
|
||||
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
|
||||
x_vec[1].v = _mm512_abs_pd( x_vec[1].v );
|
||||
|
||||
// Setting the masks by comparing with thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
|
||||
// k_mask[1][i] = 1 if x_vec[1][i] > thresh_sml_vec
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1], thresh_sml_vec, _CMP_GT_OS );
|
||||
// Setting the masks by comparing with thresh_sml_vec.v
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
|
||||
// k_mask[1][i] = 1 if x_vec[1].v[i] > thresh_sml_vec.v
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
k_mask[1] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
|
||||
// Setting the masks by comparing with thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
|
||||
// k_mask[5][i] = 1 if x_vec[1][i] < thresh_big_vec
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
|
||||
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1], thresh_big_vec, _CMP_LT_OS );
|
||||
// Setting the masks by comparing with thresh_big_vec.v
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
|
||||
// k_mask[5][i] = 1 if x_vec[1].v[i] < thresh_big_vec.v
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
k_mask[5] = _mm512_cmp_pd_mask( x_vec[1].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
|
||||
// Setting the masks to filter only the elements within the thresholds
|
||||
// k_mask[0 ... 1] contain masks for elements > thresh_sml
|
||||
@@ -352,15 +360,15 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// Setting booleans to check for underflow/overflow handling
|
||||
// In case of having values outside threshold, the associated
|
||||
// bit in k_mask[4 ... 7] is 0.
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
|
||||
// truth_val[1] = 0 if x_vec[1] has elements outside thresholds
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
|
||||
// truth_val[1] = 0 if x_vec[1].v has elements outside thresholds
|
||||
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
|
||||
truth_val[1] = _kortestc_mask8_u8( k_mask[5], k_mask[5] );
|
||||
|
||||
// Computing using masked fmadds, that carries over values from
|
||||
// accumulator register if the mask bit is 0
|
||||
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
|
||||
sum_med_vec[1] = _mm512_mask3_fmadd_pd( x_vec[1], x_vec[1], sum_med_vec[1], k_mask[5] );
|
||||
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
|
||||
sum_med_vec[1].v = _mm512_mask3_fmadd_pd( x_vec[1].v, x_vec[1].v, sum_med_vec[1].v, k_mask[5] );
|
||||
|
||||
// In case of having elements outside the threshold
|
||||
if( !( truth_val[0] && truth_val[1] ) )
|
||||
@@ -370,20 +378,20 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// k_mask[0 ... 1] contain masks for elements > thresh_sml. This would
|
||||
// include both elements < thresh_big and >= thresh_big
|
||||
// XOR on these will produce masks for elements >= thresh_big
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
|
||||
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
|
||||
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
|
||||
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
|
||||
k_mask[5] = _kxor_mask8( k_mask[1], k_mask[5] );
|
||||
|
||||
// Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
|
||||
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
|
||||
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
|
||||
k_mask[0] = _knot_mask8( k_mask[0] );
|
||||
k_mask[1] = _knot_mask8( k_mask[1] );
|
||||
|
||||
// Checking whether we have values greater than thresh_big
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
|
||||
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
|
||||
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[5] );
|
||||
|
||||
// In case of having values greater than thresh_big
|
||||
@@ -397,12 +405,12 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are greater than thresh_big
|
||||
|
||||
// Scale the required elements in x_vec[0..1] by scale_big
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
|
||||
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[5], scale_big_vec, x_vec[1] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
|
||||
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[5], scale_big_vec.v, x_vec[1].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
|
||||
sum_big_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_big_vec[1] );
|
||||
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
|
||||
sum_big_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_big_vec[1].v );
|
||||
}
|
||||
else if( !isbig )
|
||||
{
|
||||
@@ -411,12 +419,12 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are lesser than thresh_sml, if needed
|
||||
|
||||
// Scale the required elements in x_vec[0..1] by scale_sml
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
|
||||
temp[1] = _mm512_mask_mul_pd( zero_reg, k_mask[1], scale_sml_vec, x_vec[1] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
|
||||
temp[1].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[1], scale_sml_vec.v, x_vec[1].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
|
||||
sum_sml_vec[1] = _mm512_fmadd_pd( temp[1], temp[1], sum_sml_vec[1] );
|
||||
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
|
||||
sum_sml_vec[1].v = _mm512_fmadd_pd( temp[1].v, temp[1].v, sum_sml_vec[1].v );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,19 +433,19 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
}
|
||||
for ( ; ( i + 8 ) <= n; i = i + 8 )
|
||||
{
|
||||
// Set temp[0] to zero
|
||||
temp[0] = _mm512_setzero_pd();
|
||||
// Set temp[0].v to zero
|
||||
temp[0].v = _mm512_setzero_pd();
|
||||
|
||||
// Loading the vectors
|
||||
x_vec[0] = _mm512_loadu_pd( xt );
|
||||
x_vec[0].v = _mm512_loadu_pd( xt );
|
||||
|
||||
// Comparing to check for NaN
|
||||
// Bits in the mask are set if NaN is encountered
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
|
||||
|
||||
// Checking if any bit in the masks are set
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v has NaN
|
||||
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[0] );
|
||||
|
||||
// Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0
|
||||
@@ -450,15 +458,15 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
}
|
||||
|
||||
// Getting the absolute values of elements in the vectors
|
||||
x_vec[0] = _mm512_abs_pd( x_vec[0] );
|
||||
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
|
||||
|
||||
// Setting the masks by comparing with thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
|
||||
// Setting the masks by comparing with thresh_sml_vec.v
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
|
||||
// Setting the masks by comparing with thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
|
||||
// Setting the masks by comparing with thresh_big_vec.v
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
|
||||
// Setting the masks to filter only the elements within the thresholds
|
||||
// k_mask[0] contain masks for elements > thresh_sml
|
||||
@@ -469,12 +477,12 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// Setting booleans to check for underflow/overflow handling
|
||||
// In case of having values outside threshold, the associated
|
||||
// bit in k_mask[4] is 0.
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
|
||||
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
|
||||
|
||||
// Computing using masked fmadds, that carries over values from
|
||||
// accumulator register if the mask bit is 0
|
||||
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
|
||||
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
|
||||
|
||||
// In case of having elements outside the threshold
|
||||
if( !truth_val[0] )
|
||||
@@ -484,18 +492,18 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// k_mask[0 ... 1] contain masks for elements > thresh_sml. This would
|
||||
// include both elements < thresh_big and >= thresh_big
|
||||
// XOR on these will produce masks for elements >= thresh_big
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
|
||||
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
|
||||
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
|
||||
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
|
||||
|
||||
// Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
|
||||
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
|
||||
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
|
||||
k_mask[0] = _knot_mask8( k_mask[0] );
|
||||
|
||||
// Checking whether we have values greater than thresh_big
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
|
||||
// Thus, truth_val[2] = 0 if x_vec[0].v or x_vec[1].v has elements >= thresh_big_vec.v
|
||||
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[4] );
|
||||
|
||||
// In case of having values greater than thresh_big
|
||||
@@ -509,10 +517,10 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are greater than thresh_big
|
||||
|
||||
// Scale the required elements in x_vec[0] by scale_big
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
|
||||
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
|
||||
}
|
||||
else if( !isbig )
|
||||
{
|
||||
@@ -521,10 +529,10 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are lesser than thresh_sml, if needed
|
||||
|
||||
// Scale the required elements in x_vec[0] by scale_sml
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
|
||||
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -533,22 +541,22 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
}
|
||||
if( i < n )
|
||||
{
|
||||
// Set temp[0] to zero
|
||||
temp[0] = _mm512_setzero_pd();
|
||||
// Set temp[0].v to zero
|
||||
temp[0].v = _mm512_setzero_pd();
|
||||
|
||||
// Setting the mask to load
|
||||
k_mask[0] = ( 1 << ( n - i ) ) - 1;
|
||||
|
||||
// Loading the vectors
|
||||
x_vec[0] = _mm512_maskz_loadu_pd( k_mask[0], xt );
|
||||
x_vec[0].v = _mm512_maskz_loadu_pd( k_mask[0], xt );
|
||||
|
||||
// Comparing to check for NaN
|
||||
// Bits in the mask are set if NaN is encountered
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], x_vec[0], _CMP_UNORD_Q );
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, x_vec[0].v, _CMP_UNORD_Q );
|
||||
|
||||
// Checking if any bit in the masks are set
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] or x_vec[1] has NaN
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v has NaN
|
||||
truth_val[0] = _kortestz_mask8_u8( k_mask[0], k_mask[0] );
|
||||
|
||||
// Set norm to NaN and return early, if either truth_val[0] or truth_val[1] is set to 0
|
||||
@@ -561,15 +569,15 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
}
|
||||
|
||||
// Getting the absolute values of elements in the vectors
|
||||
x_vec[0] = _mm512_abs_pd( x_vec[0] );
|
||||
x_vec[0].v = _mm512_abs_pd( x_vec[0].v );
|
||||
|
||||
// Setting the masks by comparing with thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] > thresh_sml_vec
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0], thresh_sml_vec, _CMP_GT_OS );
|
||||
// Setting the masks by comparing with thresh_sml_vec.v
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] > thresh_sml_vec.v
|
||||
k_mask[0] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_sml_vec.v, _CMP_GT_OS );
|
||||
|
||||
// Setting the masks by comparing with thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] < thresh_big_vec
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0], thresh_big_vec, _CMP_LT_OS );
|
||||
// Setting the masks by comparing with thresh_big_vec.v
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] < thresh_big_vec.v
|
||||
k_mask[4] = _mm512_cmp_pd_mask( x_vec[0].v, thresh_big_vec.v, _CMP_LT_OS );
|
||||
|
||||
// Setting the masks to filter only the elements within the thresholds
|
||||
// k_mask[0] contain masks for elements > thresh_sml
|
||||
@@ -580,12 +588,12 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// Setting booleans to check for underflow/overflow handling
|
||||
// In case of having values outside threshold, the associated
|
||||
// bit in k_mask[4] is 0.
|
||||
// Thus, truth_val[0] = 0 if x_vec[0] has elements outside thresholds
|
||||
// Thus, truth_val[0] = 0 if x_vec[0].v has elements outside thresholds
|
||||
truth_val[0] = _kortestc_mask8_u8( k_mask[4], k_mask[4] );
|
||||
|
||||
// Computing using masked fmadds, that carries over values from
|
||||
// accumulator register if the mask bit is 0
|
||||
sum_med_vec[0] = _mm512_mask3_fmadd_pd( x_vec[0], x_vec[0], sum_med_vec[0], k_mask[4] );
|
||||
sum_med_vec[0].v = _mm512_mask3_fmadd_pd( x_vec[0].v, x_vec[0].v, sum_med_vec[0].v, k_mask[4] );
|
||||
|
||||
// In case of having elements outside the threshold
|
||||
if( !truth_val[0] )
|
||||
@@ -595,18 +603,18 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// k_mask[0 ... 1] contain masks for elements > thresh_sml. This would
|
||||
// include both elements < thresh_big and >= thresh_big
|
||||
// XOR on these will produce masks for elements >= thresh_big
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0][i] >= thresh_big_vec
|
||||
// k_mask[5][i] = 1 if x_vec[1][i] >= thresh_big_vec
|
||||
// That is, k_mask[4][i] = 1 if x_vec[0].v[i] >= thresh_big_vec.v
|
||||
// k_mask[5][i] = 1 if x_vec[1].v[i] >= thresh_big_vec.v
|
||||
k_mask[4] = _kxor_mask8( k_mask[0], k_mask[4] );
|
||||
|
||||
// Inverting k_mask[0 ... 1], to obtain masks for elements <= thresh_sml
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0][i] <= thresh_sml_vec
|
||||
// k_mask[1][i] = 1 if x_vec[1][i] <= thresh_sml_vec
|
||||
// That is, k_mask[0][i] = 1 if x_vec[0].v[i] <= thresh_sml_vec.v
|
||||
// k_mask[1][i] = 1 if x_vec[1].v[i] <= thresh_sml_vec.v
|
||||
k_mask[0] = _knot_mask8( k_mask[0] );
|
||||
|
||||
// Checking whether we have values greater than thresh_big
|
||||
// The truth_val is set to 0 if any bit in the mask is 1
|
||||
// Thus, truth_val[2] = 0 if x_vec[0] or x_vec[1] has elements >= thresh_big_vec
|
||||
// Thus, truth_val[2] = 0 if x_vec[0].v has elements >= thresh_big_vec.v
|
||||
truth_val[2] = _kortestz_mask8_u8( k_mask[4], k_mask[4] );
|
||||
|
||||
// In case of having values greater than thresh_big
|
||||
@@ -620,10 +628,10 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are greater than thresh_big
|
||||
|
||||
// Scale the required elements in x_vec[0] by scale_big
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[4], scale_big_vec, x_vec[0] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[4], scale_big_vec.v, x_vec[0].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_big_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_big_vec[0] );
|
||||
sum_big_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_big_vec[0].v );
|
||||
}
|
||||
else if( !isbig )
|
||||
{
|
||||
@@ -632,32 +640,35 @@ void bli_dnorm2fv_unb_var1_avx512
|
||||
// are lesser than thresh_sml, if needed
|
||||
|
||||
// Scale the required elements in x_vec[0] by scale_sml
|
||||
temp[0] = _mm512_mask_mul_pd( zero_reg, k_mask[0], scale_sml_vec, x_vec[0] );
|
||||
temp[0].v = _mm512_mask_mul_pd( zero_reg.v, k_mask[0], scale_sml_vec.v, x_vec[0].v );
|
||||
|
||||
// Square and add the elements to the accumulators
|
||||
sum_sml_vec[0] = _mm512_fmadd_pd( temp[0], temp[0], sum_sml_vec[0] );
|
||||
sum_sml_vec[0].v = _mm512_fmadd_pd( temp[0].v, temp[0].v, sum_sml_vec[0].v );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reduction step
|
||||
// Combining the results of accumulators for each category
|
||||
sum_med_vec[0] = _mm512_add_pd( sum_med_vec[0], sum_med_vec[1] );
|
||||
sum_med_vec[2] = _mm512_add_pd( sum_med_vec[2], sum_med_vec[3] );
|
||||
sum_med_vec[0] = _mm512_add_pd( sum_med_vec[0], sum_med_vec[2] );
|
||||
sum_med_vec[0].v = _mm512_add_pd( sum_med_vec[0].v, sum_med_vec[1].v );
|
||||
sum_med_vec[2].v = _mm512_add_pd( sum_med_vec[2].v, sum_med_vec[3].v );
|
||||
sum_med_vec[0].v = _mm512_add_pd( sum_med_vec[0].v, sum_med_vec[2].v );
|
||||
|
||||
sum_big_vec[0] = _mm512_add_pd( sum_big_vec[0], sum_big_vec[1] );
|
||||
sum_big_vec[2] = _mm512_add_pd( sum_big_vec[2], sum_big_vec[3] );
|
||||
sum_big_vec[0] = _mm512_add_pd( sum_big_vec[0], sum_big_vec[2] );
|
||||
sum_big_vec[0].v = _mm512_add_pd( sum_big_vec[0].v, sum_big_vec[1].v );
|
||||
sum_big_vec[2].v = _mm512_add_pd( sum_big_vec[2].v, sum_big_vec[3].v );
|
||||
sum_big_vec[0].v = _mm512_add_pd( sum_big_vec[0].v, sum_big_vec[2].v );
|
||||
|
||||
sum_sml_vec[0] = _mm512_add_pd( sum_sml_vec[0], sum_sml_vec[1] );
|
||||
sum_sml_vec[2] = _mm512_add_pd( sum_sml_vec[2], sum_sml_vec[3] );
|
||||
sum_sml_vec[0] = _mm512_add_pd( sum_sml_vec[0], sum_sml_vec[2] );
|
||||
sum_sml_vec[0].v = _mm512_add_pd( sum_sml_vec[0].v, sum_sml_vec[1].v );
|
||||
sum_sml_vec[2].v = _mm512_add_pd( sum_sml_vec[2].v, sum_sml_vec[3].v );
|
||||
sum_sml_vec[0].v = _mm512_add_pd( sum_sml_vec[0].v, sum_sml_vec[2].v );
|
||||
|
||||
// Final accumulation on the scalars
|
||||
sum_sml += _mm512_reduce_add_pd( sum_sml_vec[0] );
|
||||
sum_med += _mm512_reduce_add_pd( sum_med_vec[0] );
|
||||
sum_big += _mm512_reduce_add_pd( sum_big_vec[0] );
|
||||
sum_sml += sum_sml_vec[0].d[0] + sum_sml_vec[0].d[1] + sum_sml_vec[0].d[2] + sum_sml_vec[0].d[3]
|
||||
+ sum_sml_vec[0].d[4] + sum_sml_vec[0].d[5] + sum_sml_vec[0].d[6] + sum_sml_vec[0].d[7];
|
||||
sum_med += sum_med_vec[0].d[0] + sum_med_vec[0].d[1] + sum_med_vec[0].d[2] + sum_med_vec[0].d[3]
|
||||
+ sum_med_vec[0].d[4] + sum_med_vec[0].d[5] + sum_med_vec[0].d[6] + sum_med_vec[0].d[7];
|
||||
sum_big += sum_big_vec[0].d[0] + sum_big_vec[0].d[1] + sum_big_vec[0].d[2] + sum_big_vec[0].d[3]
|
||||
+ sum_big_vec[0].d[4] + sum_big_vec[0].d[5] + sum_big_vec[0].d[6] + sum_big_vec[0].d[7];
|
||||
}
|
||||
// Dealing with non-unit strided inputs
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user