mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Bug Fixes in LPGEMM for AVX512(SkyLake) machine (#24)
* Bug Fixes in LPGEMM for AVX512(SkyLake) machine
- B-matrix in bf16bf16f32obf16/f32 API is re-ordered. For machines that
don't support BF16 instructions, the BF16 input is un-reordered and
converted to FP32 to use FP32 kernels.
- For n = 1 and k = 1 sized matrices, re-ordering in BF16 is copying the
matrix to the re-ordered buffer array. But the un-reordering to FP32
requires the matrix to have size multiple of 16 along n and multiple
of 2 along k dimension.
- The entry condition to the above has been modified for AVX512 configuration.
- In bf16 API, the tiny path entry check has been modified to prevent
seg fault while AOCL_ENABLE_INSTRUCTIONS=AVX2 is set in BF16 supporting
machines.
- Modified existing store instructions in FP32 AVX512 kernels to support
execution in machines that have AVX512 support but not BF16/VNNI(SkyLake).
- Added Bf16 beta and store types in FP32 avx512_256 kernels
AMD Internal: [SWLCSG-3552]
* Bug Fixes in LPGEMM for AVX512(SkyLake) machine
- B-matrix in bf16bf16f32obf16/f32 API is re-ordered. For machines that
don't support BF16 instructions, the BF16 input is un-reordered and
converted to FP32 to use FP32 kernels.
- For n = 1 and k = 1 sized matrices, re-ordering in BF16 is copying the
matrix to the re-ordered buffer array. But the un-reordering to FP32
requires the matrix to have size multiple of 16 along n and multiple
of 2 along k dimension.
- The entry condition to the above has been modified for AVX512 configuration.
- In bf16 API, the tiny path entry check has been modified to prevent
seg fault while AOCL_ENABLE_INSTRUCTIONS=AVX2 is set in BF16 supporting
machines.
- Modified existing store instructions in FP32 AVX512 kernels to support
execution in machines that have AVX512 support but not BF16/VNNI(SkyLake).
- Added Bf16 beta and store types, along with BIAS and ZP in FP32 avx512_256
kernels
AMD Internal: [SWLCSG-3552]
* Bug Fixes in LPGEMM for AVX512(SkyLake) machine
- Support added in FP32 512_256 kernels for: Beta, BIAS, Zero-point and
BF16 store types for bf16bf16f32obf16 API execution in AVX2 mode.
- B-matrix in bf16bf16f32obf16/f32 API is re-ordered. For machines that
don't support BF16 instructions, the BF16 input is un-reordered and
converted to FP32 type to use FP32 kernels.
- For n = 1 and k = 1 sized matrices, re-ordering in BF16 is copying the
matrix to the re-ordered buffer array. But the un-reordering to FP32
requires the matrix to have size multiple of 16 along n and multiple
of 2 along k dimension. The entry condition here has been modified for
AVX512 configuration.
- Fix for seg fault with AOCL_ENABLE_INSTRUCTIONS=AVX2 mode in BF16/VNNI
ISA supporting configurations:
- BF16 tiny path entry check has been modified to take into account arch_id
to prevent improper entry into the tiny kernel.
- The store in BF16->FP32 col-major for m = 1 conditions were updated to
the correct storage pattern.
- BF16 beta load macro was modified to account for data in unaligned memory.
- Modified existing store instructions in FP32 AVX512 kernels to support
execution in machines that have AVX512 support but not BF16/VNNI(SkyLake).
AMD Internal: [SWLCSG-3552]
---------
Co-authored-by: VarshaV <varshav2@amd.com>
This commit is contained in:
@@ -669,7 +669,7 @@ void cvt_bf16_f32_col_major
|
||||
SHUFFLE_8x8_AVX2
|
||||
PERMUTE_8x8_AVX2
|
||||
GET_STORE_MASK(4, store_mask);
|
||||
MASKED_STORE_2COLS_AVX2(store_mask);
|
||||
_mm256_maskstore_ps( ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr ), store_mask, b_reg[0] ); \
|
||||
}
|
||||
for( ; ( kr + 1 ) < KC; kr += 2 )
|
||||
{
|
||||
@@ -683,7 +683,7 @@ void cvt_bf16_f32_col_major
|
||||
SHUFFLE_8x8_AVX2
|
||||
PERMUTE_8x8_AVX2
|
||||
GET_STORE_MASK(2, store_mask);
|
||||
MASKED_STORE_2COLS_AVX2(store_mask);
|
||||
_mm256_maskstore_ps( ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr ), store_mask, b_reg[0] );
|
||||
}
|
||||
for( ; kr < KC; kr += 1 )
|
||||
{
|
||||
@@ -695,7 +695,7 @@ void cvt_bf16_f32_col_major
|
||||
SHUFFLE_8x8_AVX2
|
||||
PERMUTE_8x8_AVX2
|
||||
GET_STORE_MASK(1, store_mask);
|
||||
MASKED_STORE_2COLS_AVX2(store_mask);
|
||||
_mm256_maskstore_ps( ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr ), store_mask, b_reg[0] ); \
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
( \
|
||||
_mm256_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm_load_si128 \
|
||||
_mm_loadu_si128 \
|
||||
( \
|
||||
( __m128i const* )( \
|
||||
( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
|
||||
@@ -76,12 +76,6 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m)
|
||||
&&POST_OPS_SIGMOID_6x16F
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
uint64_t n_left = n0 % NR; //n0 is expected to be n0<=NR
|
||||
// First check whether this is a edge case in the n dimension.
|
||||
// If so, dispatch other 6x?m kernels, as needed.
|
||||
|
||||
@@ -790,7 +790,7 @@ POST_OPS_DOWNSCALE_5x64F:
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
}
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
@@ -919,7 +919,7 @@ POST_OPS_DOWNSCALE_5x64F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
|
||||
}
|
||||
@@ -932,7 +932,7 @@ POST_OPS_DOWNSCALE_5x64F:
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
@@ -992,15 +992,15 @@ POST_OPS_DOWNSCALE_5x64F:
|
||||
{
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_first_k == TRUE ) )
|
||||
{
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
}
|
||||
}
|
||||
//c[4, 0-15]
|
||||
F32_SCL_MULRND(zmm24, selector1, zero_point0);
|
||||
@@ -1454,33 +1454,38 @@ POST_OPS_SIGMOID_5x64F:
|
||||
|
||||
POST_OPS_5x64F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm23, 3, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm27, 4, 3);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm23, 3, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm27, 4, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -2266,7 +2271,7 @@ POST_OPS_DOWNSCALE_4x64F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
|
||||
}
|
||||
@@ -2279,7 +2284,7 @@ POST_OPS_DOWNSCALE_4x64F:
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
@@ -2703,29 +2708,33 @@ POST_OPS_SIGMOID_4x64F:
|
||||
}
|
||||
POST_OPS_4x64F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm23, 3, 3);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm23, 3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -3378,7 +3387,7 @@ POST_OPS_DOWNSCALE_3x64F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
|
||||
}
|
||||
else
|
||||
@@ -3388,7 +3397,7 @@ POST_OPS_DOWNSCALE_3x64F:
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
@@ -3745,25 +3754,28 @@ POST_OPS_SIGMOID_3x64F:
|
||||
}
|
||||
POST_OPS_3x64F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -4281,14 +4293,14 @@ POST_OPS_DOWNSCALE_2x64F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
@@ -4573,21 +4585,23 @@ POST_OPS_SIGMOID_2x64F:
|
||||
}
|
||||
POST_OPS_2x64F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -4970,7 +4984,7 @@ POST_OPS_DOWNSCALE_1x64F:
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
@@ -5184,17 +5198,18 @@ POST_OPS_SIGMOID_1x64F:
|
||||
}
|
||||
POST_OPS_1x64F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -5929,7 +5944,7 @@ POST_OPS_DOWNSCALE_5x48F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
|
||||
}
|
||||
@@ -5990,9 +6005,9 @@ POST_OPS_DOWNSCALE_5x48F:
|
||||
{
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_first_k == TRUE ) )
|
||||
{
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 4, zp_mask)
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -6394,33 +6409,33 @@ POST_OPS_SIGMOID_5x48F:
|
||||
}
|
||||
POST_OPS_5x48F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -7072,7 +7087,7 @@ POST_OPS_DOWNSCALE_4x48F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point3, 3, zp_mask)
|
||||
}
|
||||
@@ -7454,29 +7469,29 @@ POST_OPS_SIGMOID_4x48F:
|
||||
}
|
||||
POST_OPS_4x48F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -8019,7 +8034,7 @@ POST_OPS_DOWNSCALE_3x48F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point2, 2, zp_mask)
|
||||
}
|
||||
else
|
||||
@@ -8333,25 +8348,25 @@ POST_OPS_SIGMOID_3x48F:
|
||||
}
|
||||
POST_OPS_3x48F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -8789,7 +8804,7 @@ POST_OPS_DOWNSCALE_2x48F:
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_COL_BCST(zero_point0, 0, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
BF16_F32_ZP_COL_BCST(zero_point1, 1, zp_mask)
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -9042,21 +9057,21 @@ POST_OPS_SIGMOID_2x48F:
|
||||
}
|
||||
POST_OPS_2x48F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -9149,7 +9164,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48)
|
||||
zmm1 = _mm512_loadu_ps(cbuf + 16);
|
||||
zmm8 = _mm512_fmadd_ps(zmm0, zmm3, zmm8);
|
||||
zmm9 = _mm512_fmadd_ps(zmm1, zmm3, zmm9);
|
||||
|
||||
|
||||
zmm0 = _mm512_loadu_ps(cbuf + 32);
|
||||
zmm10 = _mm512_fmadd_ps(zmm0, zmm3, zmm10);
|
||||
}
|
||||
@@ -9580,17 +9595,17 @@ POST_OPS_SIGMOID_1x48F:
|
||||
}
|
||||
POST_OPS_1x48F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -10565,28 +10580,28 @@ POST_OPS_SIGMOID_5x32F:
|
||||
}
|
||||
POST_OPS_5x32F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -11439,25 +11454,25 @@ POST_OPS_SIGMOID_4x32F:
|
||||
}
|
||||
POST_OPS_4x32F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -12190,22 +12205,22 @@ POST_OPS_SIGMOID_3x32F:
|
||||
}
|
||||
POST_OPS_3x32F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -12778,19 +12793,19 @@ POST_OPS_SIGMOID_2x32F:
|
||||
}
|
||||
POST_OPS_2x32F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -13240,16 +13255,16 @@ POST_OPS_SIGMOID_1x32F:
|
||||
}
|
||||
POST_OPS_1x32F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -13257,4 +13272,4 @@ POST_OPS_1x32F_DISABLE:
|
||||
_mm512_storeu_ps(cbuf + 16, zmm9);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
@@ -222,6 +222,19 @@
|
||||
mask_all1, (__m256i) _mm512_cvtneps_pbh( reg ) \
|
||||
) \
|
||||
|
||||
#define CVT_STORE_F32_BF16_MASK_AVX512( reg, m_ind, n_ind ) \
|
||||
{ \
|
||||
_mm512_storeu_ps((float*)temp, reg); \
|
||||
dest = ( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ); \
|
||||
for(i = 0; i < 16; i++) \
|
||||
{ \
|
||||
tlsb = ( temp[i] & ( uint32_t )0x00010000 ) > 16; \
|
||||
rounded = temp[i] + ( uint32_t )0x00007FFF + tlsb; \
|
||||
memcpy( (dest+i), ((char *)(&rounded))+2, sizeof(bfloat16)); \
|
||||
} \
|
||||
}
|
||||
// BF16 bias helper macros.
|
||||
#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \
|
||||
scr = ( __m512)( _mm512_sllv_epi32 \
|
||||
|
||||
@@ -45,9 +45,9 @@
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m)
|
||||
{
|
||||
//Call RD kernels if B is transposed
|
||||
if(rs_b == 1)
|
||||
if(rs_b == 1 && n0 != 1 )
|
||||
{
|
||||
lpgemm_rowvar_f32f32f32of32_avx512_6x64m_rd
|
||||
lpgemm_rowvar_f32f32f32of32_avx512_6x64m_rd
|
||||
(
|
||||
m0, n0, k0,
|
||||
a, rs_a, cs_a, ps_a,
|
||||
@@ -77,7 +77,6 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m)
|
||||
};
|
||||
|
||||
uint64_t n_left = n0 % 64; //n0 is expected to be n0<=NR
|
||||
|
||||
// First check whether this is a edge case in the n dimension.
|
||||
// If so, dispatch other 12x?m kernels, as needed.
|
||||
if ( n_left )
|
||||
@@ -1829,43 +1828,42 @@ POST_OPS_SIGMOID_6x64F:
|
||||
POST_OPS_6x64F_DISABLE:
|
||||
;
|
||||
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm11, 0, 3);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm15, 1, 3);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm11, 0, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm19, 2, 3);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm15, 1, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm23, 3, 3);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm19, 2, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm27, 4, 3);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm23, 3, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm28, 5, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm29, 5, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm30, 5, 2);
|
||||
CVT_STORE_F32_BF16_MASK(zmm31, 5, 3);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm27, 4, 3);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm28, 5, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm29, 5, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm30, 5, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm31, 5, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1904,6 +1902,7 @@ POST_OPS_6x64F_DISABLE:
|
||||
}//mloop
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
// Handle edge cases in the m dimension, if they exist.
|
||||
if( m_left )
|
||||
{
|
||||
@@ -3315,37 +3314,37 @@ POST_OPS_SIGMOID_6x48F:
|
||||
|
||||
POST_OPS_6x48F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm10, 0, 2);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm14, 1, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm10, 0, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm18, 2, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm14, 1, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm22, 3, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm18, 2, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm26, 4, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm22, 3, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm28, 5, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm29, 5, 1);
|
||||
CVT_STORE_F32_BF16_MASK(zmm30, 5, 2);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm26, 4, 2);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm28, 5, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm29, 5, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm30, 5, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -4520,31 +4519,31 @@ POST_OPS_SIGMOID_6x32F:
|
||||
}
|
||||
POST_OPS_6x32F_DISABLE:
|
||||
;
|
||||
// Generate a mask16 of all 1's.
|
||||
__m512i selector_a = _mm512_setzero_epi32();
|
||||
__m512i selector_b = _mm512_set1_epi32( 10 );
|
||||
__mmask16 mask_all1 = _mm512_cmplt_epi32_mask( selector_a, selector_b );
|
||||
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
CVT_STORE_F32_BF16_MASK(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm9, 0, 1);
|
||||
uint32_t tlsb, rounded, temp[16] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm13, 1, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm8, 0, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm9, 0, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm17, 2, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm12, 1, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm13, 1, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm21, 3, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm16, 2, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm17, 2, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm25, 4, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm20, 3, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm21, 3, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK(zmm28, 5, 0);
|
||||
CVT_STORE_F32_BF16_MASK(zmm29, 5, 1);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm24, 4, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm25, 4, 1);
|
||||
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm28, 5, 0);
|
||||
CVT_STORE_F32_BF16_MASK_AVX512(zmm29, 5, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -4605,4 +4604,4 @@ POST_OPS_6x32F_DISABLE:
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -169,46 +169,82 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_256_6x32m)
|
||||
{
|
||||
ymm4 = _mm256_broadcast_ss(&(beta));
|
||||
|
||||
_cbuf = cbuf;
|
||||
//load c and multiply with beta and
|
||||
//add to accumulator and store back
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm8)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm9)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm10)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm11)
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_first_k == TRUE ) )
|
||||
{
|
||||
BF16_F32_C_BNZ_8(0, 0, ymm0, ymm4, ymm8)
|
||||
BF16_F32_C_BNZ_8(0, 1, ymm1, ymm4, ymm9)
|
||||
BF16_F32_C_BNZ_8(0, 2, ymm2, ymm4, ymm10)
|
||||
BF16_F32_C_BNZ_8(0, 3, ymm3, ymm4, ymm11)
|
||||
|
||||
_cbuf += rs_c;
|
||||
BF16_F32_C_BNZ_8(1, 0, ymm0, ymm4, ymm12)
|
||||
BF16_F32_C_BNZ_8(1, 1, ymm1, ymm4, ymm13)
|
||||
BF16_F32_C_BNZ_8(1, 2, ymm2, ymm4, ymm14)
|
||||
BF16_F32_C_BNZ_8(1, 3, ymm3, ymm4, ymm15)
|
||||
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm12)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm13)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm14)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm15)
|
||||
BF16_F32_C_BNZ_8(2, 0, ymm0, ymm4, ymm16)
|
||||
BF16_F32_C_BNZ_8(2, 1, ymm1, ymm4, ymm17)
|
||||
BF16_F32_C_BNZ_8(2, 2, ymm2, ymm4, ymm18)
|
||||
BF16_F32_C_BNZ_8(2, 3, ymm3, ymm4, ymm19)
|
||||
|
||||
_cbuf += rs_c;
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm16)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm17)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm18)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm19)
|
||||
BF16_F32_C_BNZ_8(3, 0, ymm0, ymm4, ymm20)
|
||||
BF16_F32_C_BNZ_8(3, 1, ymm1, ymm4, ymm21)
|
||||
BF16_F32_C_BNZ_8(3, 2, ymm2, ymm4, ymm22)
|
||||
BF16_F32_C_BNZ_8(3, 3, ymm3, ymm4, ymm23)
|
||||
|
||||
_cbuf += rs_c;
|
||||
BF16_F32_C_BNZ_8(4, 0, ymm0, ymm4, ymm24)
|
||||
BF16_F32_C_BNZ_8(4, 1, ymm1, ymm4, ymm25)
|
||||
BF16_F32_C_BNZ_8(4, 2, ymm2, ymm4, ymm26)
|
||||
BF16_F32_C_BNZ_8(4, 3, ymm3, ymm4, ymm27)
|
||||
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm20)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm21)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm22)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm23)
|
||||
BF16_F32_C_BNZ_8(5, 0, ymm0, ymm4, ymm28)
|
||||
BF16_F32_C_BNZ_8(5, 1, ymm1, ymm4, ymm29)
|
||||
BF16_F32_C_BNZ_8(5, 2, ymm2, ymm4, ymm30)
|
||||
BF16_F32_C_BNZ_8(5, 3, ymm3, ymm4, ymm31)
|
||||
}
|
||||
else
|
||||
{
|
||||
_cbuf = cbuf;
|
||||
//load c and multiply with beta and
|
||||
//add to accumulator and store back
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm8)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm9)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm10)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm11)
|
||||
|
||||
_cbuf += rs_c;
|
||||
_cbuf += rs_c;
|
||||
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm24)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm25)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm26)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm27)
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm12)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm13)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm14)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm15)
|
||||
|
||||
_cbuf += rs_c;
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm28)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm29)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm30)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm31)
|
||||
_cbuf += rs_c;
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm16)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm17)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm18)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm19)
|
||||
|
||||
_cbuf += rs_c;
|
||||
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm20)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm21)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm22)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm23)
|
||||
|
||||
_cbuf += rs_c;
|
||||
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm24)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm25)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm26)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm27)
|
||||
|
||||
_cbuf += rs_c;
|
||||
F32_C_BNZ_8(_cbuf,rs_c,ymm0,ymm4,ymm28)
|
||||
F32_C_BNZ_8(_cbuf+8,rs_c,ymm1,ymm4,ymm29)
|
||||
F32_C_BNZ_8(_cbuf+16,rs_c,ymm2,ymm4,ymm30)
|
||||
F32_C_BNZ_8(_cbuf+24,rs_c,ymm3,ymm4,ymm31)
|
||||
}
|
||||
}
|
||||
|
||||
// Post Ops
|
||||
@@ -220,14 +256,24 @@ POST_OPS_BIAS_6x32F:
|
||||
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
{
|
||||
ymm0 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
ymm1 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
|
||||
ymm2 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
|
||||
ymm3 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_BIAS_LOAD_AVX2( ymm0, 0 )
|
||||
BF16_F32_BIAS_LOAD_AVX2( ymm1, 1 )
|
||||
BF16_F32_BIAS_LOAD_AVX2( ymm2, 2 )
|
||||
BF16_F32_BIAS_LOAD_AVX2( ymm3, 3 )
|
||||
}
|
||||
else
|
||||
{
|
||||
ymm0 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
ymm1 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
|
||||
ymm2 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
|
||||
ymm3 = _mm256_loadu_ps( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
|
||||
}
|
||||
|
||||
ymm8 = _mm256_add_ps(ymm8, ymm0);
|
||||
ymm9 = _mm256_add_ps(ymm9, ymm1);
|
||||
@@ -261,18 +307,30 @@ POST_OPS_BIAS_6x32F:
|
||||
}
|
||||
else
|
||||
{
|
||||
ymm0 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
ymm1 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 );
|
||||
ymm2 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
ymm3 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
ymm4 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
ymm5 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 );
|
||||
if( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_BIAS_BCAST_AVX2(ymm0,0)
|
||||
BF16_F32_BIAS_BCAST_AVX2(ymm1,1)
|
||||
BF16_F32_BIAS_BCAST_AVX2(ymm2,2)
|
||||
BF16_F32_BIAS_BCAST_AVX2(ymm3,3)
|
||||
BF16_F32_BIAS_BCAST_AVX2(ymm4,4)
|
||||
BF16_F32_BIAS_BCAST_AVX2(ymm5,5)
|
||||
}
|
||||
else
|
||||
{
|
||||
ymm0 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
ymm1 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 );
|
||||
ymm2 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
ymm3 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
ymm4 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
ymm5 = _mm256_broadcast_ss( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 );
|
||||
}
|
||||
|
||||
ymm8 = _mm256_add_ps(ymm8, ymm0);
|
||||
ymm9 = _mm256_add_ps(ymm9, ymm0);
|
||||
@@ -506,6 +564,10 @@ POST_OPS_DOWNSCALE_6x32F:
|
||||
__m256 zero_point4 = _mm256_setzero_ps();
|
||||
__m256 zero_point5 = _mm256_setzero_ps();
|
||||
|
||||
bool is_bf16 = ( post_ops_list_temp->stor_type == BF16 ) ||
|
||||
( ( post_ops_list_temp->stor_type == NONE ) &&
|
||||
( post_ops_attr.c_stor_type == BF16 ) );
|
||||
|
||||
if( post_ops_list_temp->scale_factor_len == 1 )
|
||||
{
|
||||
selector1 =
|
||||
@@ -524,12 +586,24 @@ POST_OPS_DOWNSCALE_6x32F:
|
||||
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if( is_bf16 == TRUE )
|
||||
{
|
||||
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point0)
|
||||
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point1)
|
||||
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point2)
|
||||
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point3)
|
||||
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point4)
|
||||
BF16_F32_ZP_SCALAR_BCAST_AVX2(zero_point5)
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
@@ -548,16 +622,25 @@ POST_OPS_DOWNSCALE_6x32F:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
zero_point1 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
|
||||
zero_point2 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
|
||||
zero_point3 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
|
||||
if( is_bf16 == TRUE )
|
||||
{
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point0,0)
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point1,1)
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point2,2)
|
||||
BF16_F32_ZP_VECTOR_LOAD_AVX2(zero_point3,3)
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
zero_point1 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 8 ) );
|
||||
zero_point2 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 8 ) );
|
||||
zero_point3 = _mm256_loadu_ps ( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 8 ) );
|
||||
}
|
||||
}
|
||||
|
||||
F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0)
|
||||
F32_SCL_MULRND_AVX2(ymm9, selector2, zero_point1)
|
||||
F32_SCL_MULRND_AVX2(ymm10, selector3, zero_point2)
|
||||
@@ -613,18 +696,30 @@ POST_OPS_DOWNSCALE_6x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
if( is_bf16 == TRUE )
|
||||
{
|
||||
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point0,0)
|
||||
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point1,1)
|
||||
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point2,2)
|
||||
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point3,3)
|
||||
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point4,4)
|
||||
BF16_F32_ZP_VECTOR_BCAST_AVX2(zero_point5,5)
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
}
|
||||
}
|
||||
|
||||
F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0)
|
||||
@@ -1006,51 +1101,91 @@ POST_OPS_SIGMOID_6x32F:
|
||||
}
|
||||
POST_OPS_6x32F_DISABLE:
|
||||
{
|
||||
_mm256_storeu_ps(cbuf, ymm8);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm9);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm10);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm11);
|
||||
uint32_t tlsb, rounded, temp[8] = {0};
|
||||
int i;
|
||||
bfloat16* dest;
|
||||
|
||||
cbuf += rs_c;
|
||||
if ( ( post_ops_attr.buf_downscale != NULL ) &&
|
||||
( post_ops_attr.is_last_k == TRUE ) )
|
||||
{
|
||||
STORE_F32_BF16_YMM(ymm8, 0, 0)
|
||||
STORE_F32_BF16_YMM(ymm9, 0, 1)
|
||||
STORE_F32_BF16_YMM(ymm10, 0, 2)
|
||||
STORE_F32_BF16_YMM(ymm11, 0, 3)
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm12);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm13);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm14);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm15);
|
||||
STORE_F32_BF16_YMM(ymm12, 1, 0)
|
||||
STORE_F32_BF16_YMM(ymm13, 1, 1)
|
||||
STORE_F32_BF16_YMM(ymm14, 1, 2)
|
||||
STORE_F32_BF16_YMM(ymm15, 1, 3)
|
||||
|
||||
cbuf += rs_c;
|
||||
STORE_F32_BF16_YMM(ymm16, 2, 0)
|
||||
STORE_F32_BF16_YMM(ymm17, 2, 1)
|
||||
STORE_F32_BF16_YMM(ymm18, 2, 2)
|
||||
STORE_F32_BF16_YMM(ymm19, 2, 3)
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm16);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm17);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm18);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm19);
|
||||
STORE_F32_BF16_YMM(ymm20, 3, 0)
|
||||
STORE_F32_BF16_YMM(ymm21, 3, 1)
|
||||
STORE_F32_BF16_YMM(ymm22, 3, 2)
|
||||
STORE_F32_BF16_YMM(ymm23, 3, 3)
|
||||
|
||||
cbuf += rs_c;
|
||||
STORE_F32_BF16_YMM(ymm24, 4, 0)
|
||||
STORE_F32_BF16_YMM(ymm25, 4, 1)
|
||||
STORE_F32_BF16_YMM(ymm26, 4, 2)
|
||||
STORE_F32_BF16_YMM(ymm27, 4, 3)
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm20);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm21);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm22);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm23);
|
||||
STORE_F32_BF16_YMM(ymm28, 5, 0)
|
||||
STORE_F32_BF16_YMM(ymm29, 5, 1)
|
||||
STORE_F32_BF16_YMM(ymm30, 5, 2)
|
||||
STORE_F32_BF16_YMM(ymm31, 5, 3)
|
||||
}
|
||||
else
|
||||
{
|
||||
_mm256_storeu_ps(cbuf, ymm8);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm9);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm10);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm11);
|
||||
|
||||
cbuf += rs_c;
|
||||
cbuf += rs_c;
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm24);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm25);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm26);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm27);
|
||||
_mm256_storeu_ps(cbuf, ymm12);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm13);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm14);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm15);
|
||||
|
||||
cbuf += rs_c;
|
||||
cbuf += rs_c;
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm28);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm29);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm30);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm31);
|
||||
_mm256_storeu_ps(cbuf, ymm16);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm17);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm18);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm19);
|
||||
|
||||
cbuf += rs_c;
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm20);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm21);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm22);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm23);
|
||||
|
||||
cbuf += rs_c;
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm24);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm25);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm26);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm27);
|
||||
|
||||
cbuf += rs_c;
|
||||
|
||||
_mm256_storeu_ps(cbuf, ymm28);
|
||||
_mm256_storeu_ps(cbuf + 8, ymm29);
|
||||
_mm256_storeu_ps(cbuf + 16, ymm30);
|
||||
_mm256_storeu_ps(cbuf + 24, ymm31);
|
||||
}
|
||||
}
|
||||
post_ops_attr.post_op_c_i += MR;
|
||||
}//mloop
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
|
||||
// Handle edge cases in the m dimension, if they exist.
|
||||
if ( m_left )
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user