mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Added support for different types of zero-point in f32 eltwise APIs.
Description - Zero point support for <s32/s8/bf16/u8> datatype in element-wise postop only f32o<f32/s8/u8/s32/bf16> APIs. AMD-Internal: [SWLCSG-3390] Change-Id: I2fdb308b05c1393013294df7d8a03cdcd7978379
This commit is contained in:
@@ -746,10 +746,45 @@ POST_OPS_DOWNSCALE_5x64_OPS:
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
@@ -767,14 +802,45 @@ POST_OPS_DOWNSCALE_5x64_OPS:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
S32_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S32_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S32_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S32_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
U8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
U8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
U8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -861,14 +927,49 @@ POST_OPS_DOWNSCALE_5x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
BF16_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
S32_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
S8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
U8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -925,8 +1026,31 @@ POST_OPS_DOWNSCALE_5x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
}
|
||||
//c[4, 0-15]
|
||||
F32_SCL_MULRND(zmm24, selector1, zero_point0);
|
||||
@@ -2602,10 +2726,45 @@ POST_OPS_DOWNSCALE_4x64_OPS:
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
@@ -2623,14 +2782,45 @@ POST_OPS_DOWNSCALE_4x64_OPS:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
S32_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S32_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S32_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S32_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
U8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
U8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
U8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -2705,14 +2895,49 @@ POST_OPS_DOWNSCALE_4x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
BF16_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
S32_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
S8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
U8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -4157,10 +4382,45 @@ POST_OPS_DOWNSCALE_3x64_OPS:
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
@@ -4178,14 +4438,45 @@ POST_OPS_DOWNSCALE_3x64_OPS:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
S32_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S32_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S32_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S32_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
U8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
U8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
U8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -4245,12 +4536,43 @@ POST_OPS_DOWNSCALE_3x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -5423,10 +5745,45 @@ POST_OPS_DOWNSCALE_2x64_OPS:
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
@@ -5444,14 +5801,45 @@ POST_OPS_DOWNSCALE_2x64_OPS:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
S32_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S32_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S32_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S32_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
U8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
U8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
U8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -5496,10 +5884,37 @@ POST_OPS_DOWNSCALE_2x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -6400,10 +6815,45 @@ POST_OPS_DOWNSCALE_1x64_OPS:
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
@@ -6421,14 +6871,45 @@ POST_OPS_DOWNSCALE_1x64_OPS:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
S32_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S32_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S32_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S32_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
U8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
U8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
U8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -6458,8 +6939,31 @@ POST_OPS_DOWNSCALE_1x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
|
||||
@@ -882,10 +882,45 @@ POST_OPS_DOWNSCALE_6x64_OPS:
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
|
||||
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
}
|
||||
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
@@ -904,14 +939,45 @@ POST_OPS_DOWNSCALE_6x64_OPS:
|
||||
}
|
||||
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
S32_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S32_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S32_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S32_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
S8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
S8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
S8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_LOAD(zero_point0, k0, 0);
|
||||
U8_F32_ZP_LOAD(zero_point1, k1, 1);
|
||||
U8_F32_ZP_LOAD(zero_point2, k2, 2);
|
||||
U8_F32_ZP_LOAD(zero_point3, k3, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
|
||||
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
|
||||
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
|
||||
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -1010,14 +1076,49 @@ POST_OPS_DOWNSCALE_6x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
BF16_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
S32_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
S8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
|
||||
U8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
}
|
||||
|
||||
//c[0, 0-15]
|
||||
@@ -1077,10 +1178,37 @@ POST_OPS_DOWNSCALE_6x64_OPS:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
if ( post_ops_list_temp->zp_stor_type == BF16 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S32 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else if ( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
|
||||
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
|
||||
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
}
|
||||
}
|
||||
//c[4, 0-15]
|
||||
F32_SCL_MULRND(zmm24, selector1, zero_point0);
|
||||
|
||||
@@ -515,6 +515,176 @@
|
||||
reg = _mm512_mul_ps( reg, selector ); \
|
||||
reg = _mm512_add_ps( reg, zero_point ); \
|
||||
|
||||
//u8 zero point helper macros
|
||||
#define U8_F32_ZP_LOAD(scr,mask,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepu8_epi32 \
|
||||
( \
|
||||
_mm_maskz_loadu_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( int8_t* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//s8 zero point helper macros
|
||||
#define S8_F32_ZP_LOAD(scr,mask,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepi8_epi32 \
|
||||
( \
|
||||
_mm_maskz_loadu_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( int8_t* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//bf16 zero point helper macros
|
||||
#define BF16_F32_ZP_LOAD(scr,mask,n_ind) \
|
||||
scr = ( __m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_loadu_epi16 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//s32 zero point helper macros
|
||||
#define S32_F32_ZP_LOAD(scr,mask,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_maskz_loadu_epi32 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//u8 zero point helper macros
|
||||
#define U8_F32_ZP_BCAST(scr,mask,m_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepu8_epi32 \
|
||||
( \
|
||||
_mm_maskz_set1_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//s8 zero point helper macros
|
||||
#define S8_F32_ZP_BCAST(scr,mask,m_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepi8_epi32 \
|
||||
( \
|
||||
_mm_maskz_set1_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//bf16 zero point helper macros
|
||||
#define BF16_F32_ZP_BCAST(scr,mask,m_ind) \
|
||||
scr = ( __m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_set1_epi16 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//s32 zero point helper macros
|
||||
#define S32_F32_ZP_BCAST(scr,mask,m_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_maskz_set1_epi32 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//u8 zero point helper macros
|
||||
#define U8_F32_SCALAR_ZP_BCAST(scr,mask) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepu8_epi32 \
|
||||
( \
|
||||
_mm_maskz_set1_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( int8_t* )post_ops_list_temp->op_args1 ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//s8 zero point helper macros
|
||||
#define S8_F32_SCALAR_ZP_BCAST(scr,mask) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepi8_epi32 \
|
||||
( \
|
||||
_mm_maskz_set1_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( int8_t* )post_ops_list_temp->op_args1 ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//bf16 zero point helper macros
|
||||
#define BF16_F32_SCALAR_ZP_BCAST(scr,mask) \
|
||||
scr = ( __m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_set1_epi16 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( bfloat16* )post_ops_list_temp->op_args1 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
//s32 zero point helper macros
|
||||
#define S32_F32_SCALAR_ZP_BCAST(scr,mask) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_maskz_set1_epi32 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( int32_t* ) post_ops_list_temp->op_args1 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
#define CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,reg,mask,m_ind,n_ind)
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user