Bug Fixes in LPGEMM for AVX512 (SkyLake) machines (#24)

* Bug Fixes in LPGEMM for AVX512 (SkyLake) machines

 - Support added in the FP32 avx512_256 kernels for Beta, BIAS, zero-point
   and BF16 store types, for bf16bf16f32obf16 API execution in AVX2 mode.

 - The B matrix in the bf16bf16f32obf16/f32 APIs is re-ordered. On machines
   that do not support BF16 instructions, the BF16 input is un-reordered
   and converted to FP32 so that the FP32 kernels can be used (a conversion
   sketch follows this list).

 - For matrices with n = 1 or k = 1, BF16 re-ordering simply copies the
   matrix into the re-ordered buffer, but un-reordering to FP32 requires
   the matrix size to be a multiple of 16 along the n dimension and a
   multiple of 2 along the k dimension. The entry condition for this path
   has been modified for the AVX512 configuration.

 - Fixes for a seg fault seen with AOCL_ENABLE_INSTRUCTIONS=AVX2 on
   BF16/VNNI-capable configurations:
   - The BF16 tiny-path entry check now takes arch_id into account, to
     prevent improper entry into the tiny kernel.
   - The stores in the BF16->FP32 col-major conversion for the m = 1 case
     were updated to use the correct storage pattern.
   - The BF16 beta load macro was modified to handle data in unaligned
     memory.

 - Modified the existing store instructions in the FP32 AVX512 kernels to
   support execution on machines that have AVX512 but not BF16/VNNI
   (e.g. SkyLake).
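
For reference, a minimal scalar sketch of the two conversions involved
(assuming, as in the LPGEMM code, that bfloat16 is a 16-bit integer holding
the upper half of an IEEE-754 binary32 value): widening BF16 to FP32 for the
un-reorder path, and rounding FP32 to BF16 (round-to-nearest-even) for the
store path.

#include <stdint.h>
#include <string.h>

typedef uint16_t bfloat16;

/* Widen: a bf16 value is the upper 16 bits of a binary32 float. */
static float bf16_to_f32( bfloat16 b )
{
    uint32_t u = ( uint32_t )b << 16;
    float f;
    memcpy( &f, &u, sizeof( f ) );
    return f;
}

/* Narrow with round-to-nearest-even, the same rounding the new
   AVX512F-safe store macros apply lane by lane. */
static bfloat16 f32_to_bf16( float f )
{
    uint32_t u;
    memcpy( &u, &f, sizeof( u ) );
    uint32_t tlsb    = ( u >> 16 ) & 1;        /* LSB of the retained half */
    uint32_t rounded = u + 0x00007FFFu + tlsb; /* ties round to even */
    return ( bfloat16 )( rounded >> 16 );
}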

AMD Internal: [SWLCSG-3552]

---------

Co-authored-by: VarshaV <varshav2@amd.com>
commit 532eab12d3 (parent 62d4fcb398)
Author: V, Varsha (committed by GitHub)
Date: 2025-05-30 17:22:49 +05:30
12 changed files with 1390 additions and 738 deletions

@@ -669,7 +669,7 @@ void cvt_bf16_f32_col_major
SHUFFLE_8x8_AVX2
PERMUTE_8x8_AVX2
GET_STORE_MASK(4, store_mask);
MASKED_STORE_2COLS_AVX2(store_mask);
_mm256_maskstore_ps( ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr ), store_mask, b_reg[0] ); \
}
for( ; ( kr + 1 ) < KC; kr += 2 )
{
@@ -683,7 +683,7 @@ void cvt_bf16_f32_col_major
SHUFFLE_8x8_AVX2
PERMUTE_8x8_AVX2
GET_STORE_MASK(2, store_mask);
MASKED_STORE_2COLS_AVX2(store_mask);
_mm256_maskstore_ps( ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr ), store_mask, b_reg[0] );
}
for( ; kr < KC; kr += 1 )
{
@@ -695,7 +695,7 @@ void cvt_bf16_f32_col_major
SHUFFLE_8x8_AVX2
PERMUTE_8x8_AVX2
GET_STORE_MASK(1, store_mask);
MASKED_STORE_2COLS_AVX2(store_mask);
_mm256_maskstore_ps( ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr ), store_mask, b_reg[0] ); \
}
}
}
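
The change above replaces the MASKED_STORE_2COLS_AVX2 macro with a direct
_mm256_maskstore_ps. As a self-contained sketch (the mask construction here
is assumed; GET_STORE_MASK's definition is not shown in this diff),
_mm256_maskstore_ps writes only the lanes whose 32-bit mask element has its
sign bit set:

#include <immintrin.h>
#include <stdio.h>

int main( void )
{
    float src[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    float dst[8] = { 0 };

    /* Store only the first 4 lanes; leave the last 4 untouched.
       _mm256_set_epi32 takes arguments from high lane to low lane. */
    __m256i store_mask = _mm256_set_epi32( 0, 0, 0, 0, -1, -1, -1, -1 );
    __m256  b_reg      = _mm256_loadu_ps( src );

    _mm256_maskstore_ps( dst, store_mask, b_reg );

    for ( int i = 0; i < 8; i++ ) printf( "%g ", dst[i] ); /* 1 2 3 4 0 0 0 0 */
    printf( "\n" );
    return 0;
}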

@@ -195,7 +195,7 @@ multiply with Beta, and add to alpha*A*B*/
( \
_mm256_cvtepi16_epi32 \
( \
_mm_load_si128 \
_mm_loadu_si128 \
( \
( __m128i const* )( \
( bfloat16* )post_ops_attr.buf_downscale + \
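
_mm_load_si128 requires its address to be 16-byte aligned and can fault
otherwise; _mm_loadu_si128 carries no such requirement. Since the bf16 rows
in buf_downscale are not guaranteed to sit on 16-byte boundaries, the
beta-load macro switches to the unaligned form. A minimal sketch of the
load-and-widen pattern (the helper name is illustrative, not from the
source):

#include <immintrin.h>
#include <stdint.h>

/* Load 8 bf16 values from an arbitrary (possibly unaligned) address and
   widen them to 8 packed floats: extend 16->32 bits, then shift the
   payload into the high half of each lane. */
static __m256 load8_bf16_as_f32( const uint16_t* p )
{
    __m128i raw  = _mm_loadu_si128( ( const __m128i* )p ); /* no alignment assumed */
    __m256i wide = _mm256_cvtepu16_epi32( raw );           /* 8 x u16 -> 8 x u32 */
    return _mm256_castsi256_ps( _mm256_slli_epi32( wide, 16 ) );
}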

@@ -76,12 +76,6 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m)
&&POST_OPS_SIGMOID_6x16F
};
uint64_t n_left = n0 % NR; //n0 is expected to be n0<=NR
// First check whether this is a edge case in the n dimension.
// If so, dispatch other 6x?m kernels, as needed.

@@ -790,7 +790,7 @@ POST_OPS_DOWNSCALE_5x64F:
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
@@ -919,7 +919,7 @@ POST_OPS_DOWNSCALE_5x64F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
}
@@ -932,7 +932,7 @@ POST_OPS_DOWNSCALE_5x64F:
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
post_ops_attr.post_op_c_i + 3 ) );
}
}
//c[0, 0-15]
@@ -992,15 +992,15 @@ POST_OPS_DOWNSCALE_5x64F:
{
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_first_k == TRUE ) )
{
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
}
}
//c[4, 0-15]
F32_SCL_MULRND(zmm24, selector1, zero_point0);
@@ -1454,33 +1454,38 @@ POST_OPS_SIGMOID_5x64F:
POST_OPS_5x64F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK(zmm23, 3, 3);
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK(zmm27, 4, 3);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm23, 3, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm27, 4, 3);
}
else
{
@@ -2266,7 +2271,7 @@ POST_OPS_DOWNSCALE_4x64F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
}
@@ -2279,7 +2284,7 @@ POST_OPS_DOWNSCALE_4x64F:
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
post_ops_attr.post_op_c_i + 3 ) );
}
}
//c[0, 0-15]
@@ -2703,29 +2708,33 @@ POST_OPS_SIGMOID_4x64F:
}
POST_OPS_4x64F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK(zmm23, 3, 3);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm23, 3, 3);
}
else
{
@@ -3378,7 +3387,7 @@ POST_OPS_DOWNSCALE_3x64F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
}
else
@@ -3388,7 +3397,7 @@ POST_OPS_DOWNSCALE_3x64F:
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
post_ops_attr.post_op_c_i + 2 ) );
}
}
//c[0, 0-15]
@@ -3745,25 +3754,28 @@ POST_OPS_SIGMOID_3x64F:
}
POST_OPS_3x64F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
}
else
{
@@ -4281,14 +4293,14 @@ POST_OPS_DOWNSCALE_2x64F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
post_ops_attr.post_op_c_i + 1 ) );
}
}
//c[0, 0-15]
@@ -4573,21 +4585,23 @@ POST_OPS_SIGMOID_2x64F:
}
POST_OPS_2x64F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
}
else
{
@@ -4970,7 +4984,7 @@ POST_OPS_DOWNSCALE_1x64F:
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
post_ops_attr.post_op_c_i + 0 ) );
}
}
//c[0, 0-15]
@@ -5184,17 +5198,18 @@ POST_OPS_SIGMOID_1x64F:
}
POST_OPS_1x64F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
}
else
{
@@ -5929,7 +5944,7 @@ POST_OPS_DOWNSCALE_5x48F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
}
@@ -5990,9 +6005,9 @@ POST_OPS_DOWNSCALE_5x48F:
{
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_first_k == TRUE ) )
{
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
}
else
{
@@ -6394,33 +6409,33 @@ POST_OPS_SIGMOID_5x48F:
}
POST_OPS_5x48F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
}
else
{
@@ -7072,7 +7087,7 @@ POST_OPS_DOWNSCALE_4x48F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
}
@@ -7454,29 +7469,29 @@ POST_OPS_SIGMOID_4x48F:
}
POST_OPS_4x48F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
}
else
{
@@ -8019,7 +8034,7 @@ POST_OPS_DOWNSCALE_3x48F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
}
else
@@ -8333,25 +8348,25 @@ POST_OPS_SIGMOID_3x48F:
}
POST_OPS_3x48F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
}
else
{
@@ -8789,7 +8804,7 @@ POST_OPS_DOWNSCALE_2x48F:
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
}
else
{
@@ -9042,21 +9057,21 @@ POST_OPS_SIGMOID_2x48F:
}
POST_OPS_2x48F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
}
else
{
@@ -9149,7 +9164,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48)
zmm1 = _mm512_loadu_ps(cbuf + 16);
zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8);
zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9);
zmm0 = _mm512_loadu_ps(cbuf + 32);
zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10);
}
@@ -9580,17 +9595,17 @@ POST_OPS_SIGMOID_1x48F:
}
POST_OPS_1x48F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
}
else
{
@@ -10565,28 +10580,28 @@ POST_OPS_SIGMOID_5x32F:
}
POST_OPS_5x32F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
}
else
{
@@ -11439,25 +11454,25 @@ POST_OPS_SIGMOID_4x32F:
}
POST_OPS_4x32F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
}
else
{
@@ -12190,22 +12205,22 @@ POST_OPS_SIGMOID_3x32F:
}
POST_OPS_3x32F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
}
else
{
@@ -12778,19 +12793,19 @@ POST_OPS_SIGMOID_2x32F:
}
POST_OPS_2x32F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
}
else
{
@@ -13240,16 +13255,16 @@ POST_OPS_SIGMOID_1x32F:
}
POST_OPS_1x32F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
}
else
{
@@ -13257,4 +13272,4 @@ POST_OPS_1x32F_DISABLE:
_mm512_storeu_ps(cbuf + 16, zmm9);
}
}
#endif
#endif
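
The CVT_STORE_F32_BF16_MASK macro being replaced throughout this file
converts with _mm512_cvtneps_pbh, an AVX512_BF16 instruction that is illegal
on AVX512-only parts such as SkyLake; the _AVX512 variant instead spills the
vector and rounds in scalar code using nothing beyond AVX512F. A sketch of
the two paths (guarded at compile time here for brevity; the library itself
dispatches at runtime):

#include <immintrin.h>
#include <stdint.h>
#include <string.h>

typedef uint16_t bfloat16;

static void store16_f32_as_bf16( bfloat16* dest, __m512 reg )
{
#if defined(__AVX512BF16__)
    /* Fast path: single instruction, needs AVX512_BF16 (absent on SkyLake). */
    _mm256_storeu_si256( ( __m256i* )dest,
                         ( __m256i )_mm512_cvtneps_pbh( reg ) );
#else
    /* AVX512F-only fallback: spill and round lane by lane, as the new
       CVT_STORE_F32_BF16_MASK_AVX512 macro does. */
    uint32_t temp[16];
    _mm512_storeu_ps( ( float* )temp, reg );
    for ( int i = 0; i < 16; i++ )
    {
        uint32_t tlsb    = ( temp[i] >> 16 ) & 1;
        uint32_t rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb;
        /* little-endian: the high half lives at byte offset 2 */
        memcpy( dest + i, ( ( char* )&rounded ) + 2, sizeof( bfloat16 ) );
    }
#endif
}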

File diff suppressed because it is too large.

@@ -222,6 +222,19 @@
mask_all1, (__m256i) _mm512_cvtneps_pbh( reg ) \
) \
#define CVT_STORE_F32_BF16_MASK_AVX512( reg, m_ind, n_ind ) \
{ \
_mm512_storeu_ps((float*)temp, reg); \
dest = ( bfloat16* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ); \
for(i = 0; i < 16; i++) \
{ \
tlsb = ( temp[i] & ( uint32_t )0x00010000 ) > 16; \
rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb; \
memcpy( (dest+i), ((char *)(&rounded))+2, sizeof(bfloat16)); \
} \
}
// BF16 bias helper macros.
#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \
scr = ( __m512)( _mm512_sllv_epi32 \
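
The BF16 bias helpers above follow the same widening idiom: extend each
16-bit value to 32 bits, then shift left by 16 so the bf16 payload becomes
the high half of a float. A minimal AVX512 sketch (the function name is
illustrative; sllv with a constant count behaves like a plain slli here):

#include <immintrin.h>
#include <stdint.h>

/* Load 16 bf16 bias values and widen to 16 packed floats, mirroring the
   BF16_F32_BIAS_LOAD pattern. */
static __m512 load16_bf16_bias_as_f32( const uint16_t* bias )
{
    __m256i raw  = _mm256_loadu_si256( ( const __m256i* )bias );
    __m512i wide = _mm512_cvtepu16_epi32( raw );
    return _mm512_castsi512_ps(
        _mm512_sllv_epi32( wide, _mm512_set1_epi32( 16 ) ) );
}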

@@ -45,9 +45,9 @@
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m)
{
//Call RD kernels if B is transposed
if(rs_b == 1)
if(rs_b == 1 && n0 != 1 )
{
lpgemm_rowvar_f32f32f32of32_avx512_6x64m_rd
lpgemm_rowvar_f32f32f32of32_avx512_6x64m_rd
(
m0, n0, k0,
a, rs_a, cs_a, ps_a,
@@ -77,7 +77,6 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m)
};
uint64_t n_left = n0 % 64; //n0 is expected to be n0<=NR
// First check whether this is a edge case in the n dimension.
// If so, dispatch other 12x?m kernels, as needed.
if ( n_left )
@@ -1829,43 +1828,42 @@ POST_OPS_SIGMOID_6x64F:
POST_OPS_6x64F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK(zmm23, 3, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK(zmm27, 4, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm23, 3, 3);
CVT_STORE_F32_BF16_MASK(zmm28, 5, 0);
CVT_STORE_F32_BF16_MASK(zmm29, 5, 1);
CVT_STORE_F32_BF16_MASK(zmm30, 5, 2);
CVT_STORE_F32_BF16_MASK(zmm31, 5, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm27, 4, 3);
CVT_STORE_F32_BF16_MASK_AVX512(zmm28, 5, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm29, 5, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm30, 5, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm31, 5, 3);
}
else
{
@@ -1904,6 +1902,7 @@ POST_OPS_6x64F_DISABLE:
}//mloop
consider_edge_cases:
// Handle edge cases in the m dimension, if they exist.
if( m_left )
{
@@ -3315,37 +3314,37 @@ POST_OPS_SIGMOID_6x48F:
POST_OPS_6x48F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
CVT_STORE_F32_BF16_MASK(zmm28, 5, 0);
CVT_STORE_F32_BF16_MASK(zmm29, 5, 1);
CVT_STORE_F32_BF16_MASK(zmm30, 5, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
CVT_STORE_F32_BF16_MASK_AVX512(zmm28, 5, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm29, 5, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm30, 5, 2);
}
else
{
@@ -4520,31 +4519,31 @@ POST_OPS_SIGMOID_6x32F:
}
POST_OPS_6x32F_DISABLE:
;
// Generate a mask16 of all 1's.
__m512i selector_a = _mm512_setzero_epi32();
__m512i selector_b = _mm512_set1_epi32( 10 );
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
uint32_t tlsb, rounded, temp[16] = {0};
int i;
bfloat16* dest;
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
CVT_STORE_F32_BF16_MASK(zmm28, 5, 0);
CVT_STORE_F32_BF16_MASK(zmm29, 5, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
CVT_STORE_F32_BF16_MASK_AVX512(zmm28, 5, 0);
CVT_STORE_F32_BF16_MASK_AVX512(zmm29, 5, 1);
}
else
{
@@ -4605,4 +4604,4 @@ POST_OPS_6x32F_DISABLE:
return;
}
}
#endif
#endif

@@ -169,46 +169,82 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_256_6x32m)
{
ymm4 = _mm256_broadcast_ss(&(beta));
_cbuf = cbuf;
//load c and multiply with beta and
//add to accumulator and store back
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm8)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm9)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm10)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm11)
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_first_k == TRUE ) )
{
BF16_F32_C_BNZ_8(0, 0, ymm0, ymm4, ymm8)
BF16_F32_C_BNZ_8(0, 1, ymm1, ymm4, ymm9)
BF16_F32_C_BNZ_8(0, 2, ymm2, ymm4, ymm10)
BF16_F32_C_BNZ_8(0, 3, ymm3, ymm4, ymm11)
_cbuf += rs_c;
BF16_F32_C_BNZ_8(1, 0, ymm0, ymm4, ymm12)
BF16_F32_C_BNZ_8(1, 1, ymm1, ymm4, ymm13)
BF16_F32_C_BNZ_8(1, 2, ymm2, ymm4, ymm14)
BF16_F32_C_BNZ_8(1, 3, ymm3, ymm4, ymm15)
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm12)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm13)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm14)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm15)
BF16_F32_C_BNZ_8(2, 0, ymm0, ymm4, ymm16)
BF16_F32_C_BNZ_8(2, 1, ymm1, ymm4, ymm17)
BF16_F32_C_BNZ_8(2, 2, ymm2, ymm4, ymm18)
BF16_F32_C_BNZ_8(2, 3, ymm3, ymm4, ymm19)
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm16)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm17)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm18)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm19)
BF16_F32_C_BNZ_8(3, 0, ymm0, ymm4, ymm20)
BF16_F32_C_BNZ_8(3, 1, ymm1, ymm4, ymm21)
BF16_F32_C_BNZ_8(3, 2, ymm2, ymm4, ymm22)
BF16_F32_C_BNZ_8(3, 3, ymm3, ymm4, ymm23)
_cbuf += rs_c;
BF16_F32_C_BNZ_8(4, 0, ymm0, ymm4, ymm24)
BF16_F32_C_BNZ_8(4, 1, ymm1, ymm4, ymm25)
BF16_F32_C_BNZ_8(4, 2, ymm2, ymm4, ymm26)
BF16_F32_C_BNZ_8(4, 3, ymm3, ymm4, ymm27)
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm20)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm21)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm22)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm23)
BF16_F32_C_BNZ_8(5, 0, ymm0, ymm4, ymm28)
BF16_F32_C_BNZ_8(5, 1, ymm1, ymm4, ymm29)
BF16_F32_C_BNZ_8(5, 2, ymm2, ymm4, ymm30)
BF16_F32_C_BNZ_8(5, 3, ymm3, ymm4, ymm31)
}
else
{
_cbuf = cbuf;
//load c and multiply with beta and
//add to accumulator and store back
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm8)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm9)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm10)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm11)
_cbuf += rs_c;
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm24)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm25)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm26)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm27)
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm12)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm13)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm14)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm15)
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm28)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm29)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm30)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm31)
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm16)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm17)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm18)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm19)
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm20)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm21)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm22)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm23)
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm24)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm25)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm26)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm27)
_cbuf += rs_c;
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm28)
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm29)
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm30)
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm31)
}
}
// Post Ops
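
When C is held as bf16 (buf_downscale) and beta is non-zero, the first-k
pass must read bf16 C elements, widen them, and fold them into the f32
accumulators; that is what the new BF16_F32_C_BNZ_8 calls above do. A
plausible expansion of one such call (a sketch; the real macro's arguments
and addressing differ):

#include <immintrin.h>
#include <stdint.h>

/* Sketch: load 8 bf16 elements of C, widen to f32, and compute
   acc += beta * c. */
static __m256 bf16_c_bnz_8( const uint16_t* c_row, __m256 beta_v, __m256 acc )
{
    __m128i raw   = _mm_loadu_si128( ( const __m128i* )c_row );
    __m256i wide  = _mm256_cvtepu16_epi32( raw );
    __m256  c_f32 = _mm256_castsi256_ps( _mm256_slli_epi32( wide, 16 ) );
    return _mm256_fmadd_ps( c_f32, beta_v, acc );
}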
@@ -220,14 +256,24 @@ POST_OPS_BIAS_6x32F:
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
ymm0 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
ymm1 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
ymm2 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
ymm3 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
if( post_ops_list_temp->stor_type == BF16 )
{
BF16_F32_BIAS_LOAD_AVX2( ymm0, 0 )
BF16_F32_BIAS_LOAD_AVX2( ymm1, 1 )
BF16_F32_BIAS_LOAD_AVX2( ymm2, 2 )
BF16_F32_BIAS_LOAD_AVX2( ymm3, 3 )
}
else
{
ymm0 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
ymm1 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
ymm2 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
ymm3 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
}
ymm8 = _mm256_add_ps(ymm8, ymm0);
ymm9 = _mm256_add_ps(ymm9, ymm1);
@@ -261,18 +307,30 @@ POST_OPS_BIAS_6x32F:
}
else
{
ymm0 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 );
ymm1 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 );
ymm2 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 );
ymm3 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 );
ymm4 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 );
ymm5 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 );
if( post_ops_list_temp->stor_type == BF16 )
{
BF16_F32_BIAS_BCAST_AVX2(ymm0,0)
BF16_F32_BIAS_BCAST_AVX2(ymm1,1)
BF16_F32_BIAS_BCAST_AVX2(ymm2,2)
BF16_F32_BIAS_BCAST_AVX2(ymm3,3)
BF16_F32_BIAS_BCAST_AVX2(ymm4,4)
BF16_F32_BIAS_BCAST_AVX2(ymm5,5)
}
else
{
ymm0 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 );
ymm1 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 );
ymm2 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 );
ymm3 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 );
ymm4 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 );
ymm5 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 );
}
ymm8 = _mm256_add_ps(ymm8, ymm0);
ymm9 = _mm256_add_ps(ymm9, ymm0);
@@ -506,6 +564,10 @@ POST_OPS_DOWNSCALE_6x32F:
__m256 zero_point4 = _mm256_setzero_ps();
__m256 zero_point5 = _mm256_setzero_ps();
bool is_bf16 = ( post_ops_list_temp->stor_type == BF16 ) ||
( ( post_ops_list_temp->stor_type == NONE ) &&
( post_ops_attr.c_stor_type == BF16 ) );
if( post_ops_list_temp->scale_factor_len == 1 )
{
selector1 =
@@ -524,12 +586,24 @@ POST_OPS_DOWNSCALE_6x32F:
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if( is_bf16 == TRUE )
{
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point0)
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point1)
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point2)
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point3)
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point4)
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point5)
}
else
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
@@ -548,16 +622,25 @@ POST_OPS_DOWNSCALE_6x32F:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
zero_point1 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
zero_point2 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
zero_point3 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
if( is_bf16 == TRUE )
{
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point0,0)
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point1,1)
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point2,2)
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point3,3)
}
else
{
zero_point0 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
zero_point1 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
zero_point2 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
zero_point3 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
}
}
F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0)
F32_SCL_MULRND_AVX2(ymm9, selector2, zero_point1)
F32_SCL_MULRND_AVX2(ymm10, selector3, zero_point2)
@@ -613,18 +696,30 @@ POST_OPS_DOWNSCALE_6x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 ) );
if( is_bf16 == TRUE )
{
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point0,0)
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point1,1)
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point2,2)
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point3,3)
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point4,4)
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point5,5)
}
else
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 ) );
}
}
F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0)
@@ -1006,51 +1101,91 @@ POST_OPS_SIGMOID_6x32F:
}
POST_OPS_6x32F_DISABLE:
{
_mm256_storeu_ps(cbuf, ymm8);
_mm256_storeu_ps(cbuf + 8, ymm9);
_mm256_storeu_ps(cbuf + 16, ymm10);
_mm256_storeu_ps(cbuf + 24, ymm11);
uint32_t tlsb, rounded, temp[8] = {0};
int i;
bfloat16* dest;
cbuf += rs_c;
if ( ( post_ops_attr.buf_downscale != NULL ) &&
( post_ops_attr.is_last_k == TRUE ) )
{
STORE_F32_BF16_YMM(ymm8, 0, 0)
STORE_F32_BF16_YMM(ymm9, 0, 1)
STORE_F32_BF16_YMM(ymm10, 0, 2)
STORE_F32_BF16_YMM(ymm11, 0, 3)
_mm256_storeu_ps(cbuf, ymm12);
_mm256_storeu_ps(cbuf + 8, ymm13);
_mm256_storeu_ps(cbuf + 16, ymm14);
_mm256_storeu_ps(cbuf + 24, ymm15);
STORE_F32_BF16_YMM(ymm12, 1, 0)
STORE_F32_BF16_YMM(ymm13, 1, 1)
STORE_F32_BF16_YMM(ymm14, 1, 2)
STORE_F32_BF16_YMM(ymm15, 1, 3)
cbuf += rs_c;
STORE_F32_BF16_YMM(ymm16, 2, 0)
STORE_F32_BF16_YMM(ymm17, 2, 1)
STORE_F32_BF16_YMM(ymm18, 2, 2)
STORE_F32_BF16_YMM(ymm19, 2, 3)
_mm256_storeu_ps(cbuf, ymm16);
_mm256_storeu_ps(cbuf + 8, ymm17);
_mm256_storeu_ps(cbuf + 16, ymm18);
_mm256_storeu_ps(cbuf + 24, ymm19);
STORE_F32_BF16_YMM(ymm20, 3, 0)
STORE_F32_BF16_YMM(ymm21, 3, 1)
STORE_F32_BF16_YMM(ymm22, 3, 2)
STORE_F32_BF16_YMM(ymm23, 3, 3)
cbuf += rs_c;
STORE_F32_BF16_YMM(ymm24, 4, 0)
STORE_F32_BF16_YMM(ymm25, 4, 1)
STORE_F32_BF16_YMM(ymm26, 4, 2)
STORE_F32_BF16_YMM(ymm27, 4, 3)
_mm256_storeu_ps(cbuf, ymm20);
_mm256_storeu_ps(cbuf + 8, ymm21);
_mm256_storeu_ps(cbuf + 16, ymm22);
_mm256_storeu_ps(cbuf + 24, ymm23);
STORE_F32_BF16_YMM(ymm28, 5, 0)
STORE_F32_BF16_YMM(ymm29, 5, 1)
STORE_F32_BF16_YMM(ymm30, 5, 2)
STORE_F32_BF16_YMM(ymm31, 5, 3)
}
else
{
_mm256_storeu_ps(cbuf, ymm8);
_mm256_storeu_ps(cbuf + 8, ymm9);
_mm256_storeu_ps(cbuf + 16, ymm10);
_mm256_storeu_ps(cbuf + 24, ymm11);
cbuf += rs_c;
cbuf += rs_c;
_mm256_storeu_ps(cbuf, ymm24);
_mm256_storeu_ps(cbuf + 8, ymm25);
_mm256_storeu_ps(cbuf + 16, ymm26);
_mm256_storeu_ps(cbuf + 24, ymm27);
_mm256_storeu_ps(cbuf, ymm12);
_mm256_storeu_ps(cbuf + 8, ymm13);
_mm256_storeu_ps(cbuf + 16, ymm14);
_mm256_storeu_ps(cbuf + 24, ymm15);
cbuf += rs_c;
cbuf += rs_c;
_mm256_storeu_ps(cbuf, ymm28);
_mm256_storeu_ps(cbuf + 8, ymm29);
_mm256_storeu_ps(cbuf + 16, ymm30);
_mm256_storeu_ps(cbuf + 24, ymm31);
_mm256_storeu_ps(cbuf, ymm16);
_mm256_storeu_ps(cbuf + 8, ymm17);
_mm256_storeu_ps(cbuf + 16, ymm18);
_mm256_storeu_ps(cbuf + 24, ymm19);
cbuf += rs_c;
_mm256_storeu_ps(cbuf, ymm20);
_mm256_storeu_ps(cbuf + 8, ymm21);
_mm256_storeu_ps(cbuf + 16, ymm22);
_mm256_storeu_ps(cbuf + 24, ymm23);
cbuf += rs_c;
_mm256_storeu_ps(cbuf, ymm24);
_mm256_storeu_ps(cbuf + 8, ymm25);
_mm256_storeu_ps(cbuf + 16, ymm26);
_mm256_storeu_ps(cbuf + 24, ymm27);
cbuf += rs_c;
_mm256_storeu_ps(cbuf, ymm28);
_mm256_storeu_ps(cbuf + 8, ymm29);
_mm256_storeu_ps(cbuf + 16, ymm30);
_mm256_storeu_ps(cbuf + 24, ymm31);
}
}
post_ops_attr.post_op_c_i += MR;
}//mloop
consider_edge_cases:
// Handle edge cases in the m dimension, if they exist.
if ( m_left )
{