Added support for zero-points of different data types in the f32 eltwise APIs.

Description
 - Added zero-point support for the <s32/s8/bf16/u8> datatypes in the
   element-wise, post-op-only f32o<f32/s8/u8/s32/bf16> APIs.

 AMD-Internal: [SWLCSG-3390]

Change-Id: I2fdb308b05c1393013294df7d8a03cdcd7978379
This commit is contained in:
Deepak Negi
2025-02-24 23:50:57 +05:30
parent 7394aafd1e
commit cc321fb95d
9 changed files with 1152 additions and 192 deletions

View File

@@ -746,10 +746,45 @@ POST_OPS_DOWNSCALE_5x64_OPS:
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
@@ -767,14 +802,45 @@ POST_OPS_DOWNSCALE_5x64_OPS:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
S32_F32_ZP_LOAD(zero_point0, k0, 0);
S32_F32_ZP_LOAD(zero_point1, k1, 1);
S32_F32_ZP_LOAD(zero_point2, k2, 2);
S32_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
S8_F32_ZP_LOAD(zero_point0, k0, 0);
S8_F32_ZP_LOAD(zero_point1, k1, 1);
S8_F32_ZP_LOAD(zero_point2, k2, 2);
S8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
U8_F32_ZP_LOAD(zero_point0, k0, 0);
U8_F32_ZP_LOAD(zero_point1, k1, 1);
U8_F32_ZP_LOAD(zero_point2, k2, 2);
U8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -861,14 +927,49 @@ POST_OPS_DOWNSCALE_5x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
BF16_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
S32_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
S8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
U8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -925,8 +1026,31 @@ POST_OPS_DOWNSCALE_5x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
// Broadcast the per-row zero-point for row 4 of the tile (c[4, 0-15]),
// converting from the stored zero-point type to f32 when it is not f32.
// Every branch must select row offset 4, as the f32 fallback does; the
// typed branches previously passed 0 and so picked up row 0's zero-point.
// NOTE(review): assumes the last *_ZP_BCAST argument is the row offset
// added to post_ops_attr.post_op_c_i (the 4-row paths pass 0..3 alongside
// f32 offsets 0..3) -- confirm against the macro definitions.
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	BF16_F32_ZP_BCAST(zero_point0, zp_mask, 4);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	S32_F32_ZP_BCAST(zero_point0, zp_mask, 4);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	S8_F32_ZP_BCAST(zero_point0, zp_mask, 4);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	U8_F32_ZP_BCAST(zero_point0, zp_mask, 4);
}
else
{
	// Zero-point already stored as f32: broadcast element (c_i + 4) directly.
	zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
					post_ops_attr.post_op_c_i + 4 ) );
}
}
//c[4, 0-15]
F32_SCL_MULRND(zmm24, selector1, zero_point0);
@@ -2602,10 +2726,45 @@ POST_OPS_DOWNSCALE_4x64_OPS:
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
@@ -2623,14 +2782,45 @@ POST_OPS_DOWNSCALE_4x64_OPS:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
S32_F32_ZP_LOAD(zero_point0, k0, 0);
S32_F32_ZP_LOAD(zero_point1, k1, 1);
S32_F32_ZP_LOAD(zero_point2, k2, 2);
S32_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
S8_F32_ZP_LOAD(zero_point0, k0, 0);
S8_F32_ZP_LOAD(zero_point1, k1, 1);
S8_F32_ZP_LOAD(zero_point2, k2, 2);
S8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
U8_F32_ZP_LOAD(zero_point0, k0, 0);
U8_F32_ZP_LOAD(zero_point1, k1, 1);
U8_F32_ZP_LOAD(zero_point2, k2, 2);
U8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -2705,14 +2895,49 @@ POST_OPS_DOWNSCALE_4x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
BF16_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
S32_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
S8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
U8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -4157,10 +4382,45 @@ POST_OPS_DOWNSCALE_3x64_OPS:
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
@@ -4178,14 +4438,45 @@ POST_OPS_DOWNSCALE_3x64_OPS:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
S32_F32_ZP_LOAD(zero_point0, k0, 0);
S32_F32_ZP_LOAD(zero_point1, k1, 1);
S32_F32_ZP_LOAD(zero_point2, k2, 2);
S32_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
S8_F32_ZP_LOAD(zero_point0, k0, 0);
S8_F32_ZP_LOAD(zero_point1, k1, 1);
S8_F32_ZP_LOAD(zero_point2, k2, 2);
S8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
U8_F32_ZP_LOAD(zero_point0, k0, 0);
U8_F32_ZP_LOAD(zero_point1, k1, 1);
U8_F32_ZP_LOAD(zero_point2, k2, 2);
U8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -4245,12 +4536,43 @@ POST_OPS_DOWNSCALE_3x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -5423,10 +5745,45 @@ POST_OPS_DOWNSCALE_2x64_OPS:
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
@@ -5444,14 +5801,45 @@ POST_OPS_DOWNSCALE_2x64_OPS:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
S32_F32_ZP_LOAD(zero_point0, k0, 0);
S32_F32_ZP_LOAD(zero_point1, k1, 1);
S32_F32_ZP_LOAD(zero_point2, k2, 2);
S32_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
S8_F32_ZP_LOAD(zero_point0, k0, 0);
S8_F32_ZP_LOAD(zero_point1, k1, 1);
S8_F32_ZP_LOAD(zero_point2, k2, 2);
S8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
U8_F32_ZP_LOAD(zero_point0, k0, 0);
U8_F32_ZP_LOAD(zero_point1, k1, 1);
U8_F32_ZP_LOAD(zero_point2, k2, 2);
U8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -5496,10 +5884,37 @@ POST_OPS_DOWNSCALE_2x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -6400,10 +6815,45 @@ POST_OPS_DOWNSCALE_1x64_OPS:
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
@@ -6421,14 +6871,45 @@ POST_OPS_DOWNSCALE_1x64_OPS:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
S32_F32_ZP_LOAD(zero_point0, k0, 0);
S32_F32_ZP_LOAD(zero_point1, k1, 1);
S32_F32_ZP_LOAD(zero_point2, k2, 2);
S32_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
S8_F32_ZP_LOAD(zero_point0, k0, 0);
S8_F32_ZP_LOAD(zero_point1, k1, 1);
S8_F32_ZP_LOAD(zero_point2, k2, 2);
S8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
U8_F32_ZP_LOAD(zero_point0, k0, 0);
U8_F32_ZP_LOAD(zero_point1, k1, 1);
U8_F32_ZP_LOAD(zero_point2, k2, 2);
U8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -6458,8 +6939,31 @@ POST_OPS_DOWNSCALE_1x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);

View File

@@ -882,10 +882,45 @@ POST_OPS_DOWNSCALE_6x64_OPS:
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
BF16_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S32_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
S8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_SCALAR_ZP_BCAST(zero_point0, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point1, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point2, zp_mask);
U8_F32_SCALAR_ZP_BCAST(zero_point3, zp_mask);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
@@ -904,14 +939,45 @@ POST_OPS_DOWNSCALE_6x64_OPS:
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
BF16_F32_ZP_LOAD(zero_point0, k0, 0);
BF16_F32_ZP_LOAD(zero_point1, k1, 1);
BF16_F32_ZP_LOAD(zero_point2, k2, 2);
BF16_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
S32_F32_ZP_LOAD(zero_point0, k0, 0);
S32_F32_ZP_LOAD(zero_point1, k1, 1);
S32_F32_ZP_LOAD(zero_point2, k2, 2);
S32_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
S8_F32_ZP_LOAD(zero_point0, k0, 0);
S8_F32_ZP_LOAD(zero_point1, k1, 1);
S8_F32_ZP_LOAD(zero_point2, k2, 2);
S8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
U8_F32_ZP_LOAD(zero_point0, k0, 0);
U8_F32_ZP_LOAD(zero_point1, k1, 1);
U8_F32_ZP_LOAD(zero_point2, k2, 2);
U8_F32_ZP_LOAD(zero_point3, k3, 3);
}
else
{
zero_point0 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 16 ) );
zero_point1 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 1 * 16 ) );
zero_point2 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 2 * 16 ) );
zero_point3 = _mm512_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -1010,14 +1076,49 @@ POST_OPS_DOWNSCALE_6x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
BF16_F32_ZP_BCAST(zero_point0, zp_mask, 0);
BF16_F32_ZP_BCAST(zero_point1, zp_mask, 1);
BF16_F32_ZP_BCAST(zero_point2, zp_mask, 2);
BF16_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S32_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S32_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S32_F32_ZP_BCAST(zero_point2, zp_mask, 2);
S32_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
S8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
S8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
S8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
S8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
U8_F32_ZP_BCAST(zero_point0, zp_mask, 0);
U8_F32_ZP_BCAST(zero_point1, zp_mask, 1);
U8_F32_ZP_BCAST(zero_point2, zp_mask, 2);
U8_F32_ZP_BCAST(zero_point3, zp_mask, 3);
}
else
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
}
//c[0, 0-15]
@@ -1077,10 +1178,37 @@ POST_OPS_DOWNSCALE_6x64_OPS:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 ) );
// Per-row zero points for rows 4 and 5 of the tile. These rows reuse
// zero_point0/zero_point1, so every storage type must read the values at
// post_op_c_i + 4 and post_op_c_i + 5, exactly like the f32 fallback in
// the final else branch. Using indices 0 and 1 here would re-apply the
// zero points of rows 0 and 1 to rows 4 and 5.
if ( post_ops_list_temp->zp_stor_type == BF16 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	BF16_F32_ZP_BCAST(zero_point0, zp_mask, 4);
	BF16_F32_ZP_BCAST(zero_point1, zp_mask, 5);
}
else if ( post_ops_list_temp->zp_stor_type == S32 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	S32_F32_ZP_BCAST(zero_point0, zp_mask, 4);
	S32_F32_ZP_BCAST(zero_point1, zp_mask, 5);
}
else if ( post_ops_list_temp->zp_stor_type == S8 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	S8_F32_ZP_BCAST(zero_point0, zp_mask, 4);
	S8_F32_ZP_BCAST(zero_point1, zp_mask, 5);
}
else if ( post_ops_list_temp->zp_stor_type == U8 )
{
	__mmask16 zp_mask = _cvtu32_mask16( 0xFFFF );
	U8_F32_ZP_BCAST(zero_point0, zp_mask, 4);
	U8_F32_ZP_BCAST(zero_point1, zp_mask, 5);
}
else
{
	// Any other storage type is read as f32.
	zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
					post_ops_attr.post_op_c_i + 4 ) );
	zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
					post_ops_attr.post_op_c_i + 5 ) );
}
}
//c[4, 0-15]
F32_SCL_MULRND(zmm24, selector1, zero_point0);

View File

@@ -515,6 +515,176 @@
reg = _mm512_mul_ps( reg, selector ); \
reg = _mm512_add_ps( reg, zero_point ); \
//u8 zero point helper macros
// Loads up to 16 u8 zero-point values and converts them to f32 in `scr`.
//   scr   : __m512 lvalue that receives the converted zero points.
//   mask  : lane mask for the byte load; masked-out lanes are zeroed.
//   n_ind : index of the 16-column group; elements are read starting at
//           op_args1 + post_op_c_j + n_ind * 16.
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site. The bytes are zero-extended (cvtepu8), so the int8_t
// pointer cast only serves as byte-sized pointer arithmetic.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define U8_F32_ZP_LOAD(scr,mask,n_ind) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_cvtepu8_epi32 \
	  ( \
	    _mm_maskz_loadu_epi8 \
	    ( \
	      ( mask ), \
	      ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
	      post_ops_attr.post_op_c_j + ( ( n_ind ) * 16 ) \
	    ) \
	  ) \
	);
//s8 zero point helper macros
// Loads up to 16 s8 zero-point values and converts them to f32 in `scr`.
//   scr   : __m512 lvalue that receives the converted zero points.
//   mask  : lane mask for the byte load; masked-out lanes are zeroed.
//   n_ind : index of the 16-column group; elements are read starting at
//           op_args1 + post_op_c_j + n_ind * 16.
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site. The bytes are sign-extended (cvtepi8) before the
// int32 -> f32 conversion.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define S8_F32_ZP_LOAD(scr,mask,n_ind) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_cvtepi8_epi32 \
	  ( \
	    _mm_maskz_loadu_epi8 \
	    ( \
	      ( mask ), \
	      ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
	      post_ops_attr.post_op_c_j + ( ( n_ind ) * 16 ) \
	    ) \
	  ) \
	);
//bf16 zero point helper macros
// Loads up to 16 bf16 zero-point values and widens them to f32 in `scr`.
//   scr   : __m512 lvalue that receives the converted zero points.
//   mask  : lane mask for the 16-bit load; masked-out lanes are zeroed.
//   n_ind : index of the 16-column group; elements are read starting at
//           op_args1 + post_op_c_j + n_ind * 16.
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site. Each 16-bit value is widened to 32 bits and shifted
// left by 16 so it lands in the upper half of the f32 bit pattern; the
// integer vector is then reinterpreted with _mm512_castsi512_ps rather
// than a compiler-specific vector cast.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define BF16_F32_ZP_LOAD(scr,mask,n_ind) \
	scr = _mm512_castsi512_ps \
	( \
	  _mm512_sllv_epi32 \
	  ( \
	    _mm512_cvtepi16_epi32 \
	    ( \
	      _mm256_maskz_loadu_epi16 \
	      ( \
	        ( mask ), \
	        ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
	        post_ops_attr.post_op_c_j + ( ( n_ind ) * 16 ) \
	      ) \
	    ), _mm512_set1_epi32( 16 ) \
	  ) \
	);
//s32 zero point helper macros
// Loads up to 16 s32 zero-point values and converts them to f32 in `scr`.
//   scr   : __m512 lvalue that receives the converted zero points.
//   mask  : lane mask for the 32-bit load; masked-out lanes are zeroed.
//   n_ind : index of the 16-column group; elements are read starting at
//           op_args1 + post_op_c_j + n_ind * 16.
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define S32_F32_ZP_LOAD(scr,mask,n_ind) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_maskz_loadu_epi32 \
	  ( \
	    ( mask ), \
	    ( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
	    post_ops_attr.post_op_c_j + ( ( n_ind ) * 16 ) \
	  ) \
	);
//u8 zero point helper macros
// Broadcasts one u8 zero-point value to all selected lanes of `scr`
// as f32.
//   scr   : __m512 lvalue that receives the broadcast zero point.
//   mask  : lane mask for the broadcast; masked-out lanes are zeroed.
//   m_ind : row offset; the value is read from
//           op_args1[ post_op_c_i + m_ind ].
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site. The byte is fetched through an int8_t pointer but is
// zero-extended afterwards (cvtepu8), so it is interpreted as unsigned.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define U8_F32_ZP_BCAST(scr,mask,m_ind) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_cvtepu8_epi32 \
	  ( \
	    _mm_maskz_set1_epi8 \
	    ( \
	      ( mask ), \
	      *( ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
	      post_ops_attr.post_op_c_i + ( m_ind ) ) \
	    ) \
	  ) \
	);
//s8 zero point helper macros
// Broadcasts one s8 zero-point value to all selected lanes of `scr`
// as f32.
//   scr   : __m512 lvalue that receives the broadcast zero point.
//   mask  : lane mask for the broadcast; masked-out lanes are zeroed.
//   m_ind : row offset; the value is read from
//           op_args1[ post_op_c_i + m_ind ].
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site. The byte is sign-extended (cvtepi8) before the
// int32 -> f32 conversion.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define S8_F32_ZP_BCAST(scr,mask,m_ind) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_cvtepi8_epi32 \
	  ( \
	    _mm_maskz_set1_epi8 \
	    ( \
	      ( mask ), \
	      *( ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
	      post_ops_attr.post_op_c_i + ( m_ind ) ) \
	    ) \
	  ) \
	);
//bf16 zero point helper macros
// Broadcasts one bf16 zero-point value to all selected lanes of `scr`
// as f32.
//   scr   : __m512 lvalue that receives the broadcast zero point.
//   mask  : lane mask for the broadcast; masked-out lanes are zeroed.
//   m_ind : row offset; the value is read from
//           op_args1[ post_op_c_i + m_ind ].
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site. The 16-bit value is widened to 32 bits and shifted
// left by 16 so it lands in the upper half of the f32 bit pattern; the
// integer vector is then reinterpreted with _mm512_castsi512_ps rather
// than a compiler-specific vector cast.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define BF16_F32_ZP_BCAST(scr,mask,m_ind) \
	scr = _mm512_castsi512_ps \
	( \
	  _mm512_sllv_epi32 \
	  ( \
	    _mm512_cvtepi16_epi32 \
	    ( \
	      _mm256_maskz_set1_epi16 \
	      ( \
	        ( mask ), \
	        *( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
	        post_ops_attr.post_op_c_i + ( m_ind ) ) \
	      ) \
	    ), _mm512_set1_epi32( 16 ) \
	  ) \
	);
//s32 zero point helper macros
// Broadcasts one s32 zero-point value to all selected lanes of `scr`
// as f32.
//   scr   : __m512 lvalue that receives the broadcast zero point.
//   mask  : lane mask for the broadcast; masked-out lanes are zeroed.
//   m_ind : row offset; the value is read from
//           op_args1[ post_op_c_i + m_ind ].
// Expects `post_ops_list_temp` and `post_ops_attr` to be in scope at the
// expansion site.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define S32_F32_ZP_BCAST(scr,mask,m_ind) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_maskz_set1_epi32 \
	  ( \
	    ( mask ), \
	    *( ( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
	    post_ops_attr.post_op_c_i + ( m_ind ) ) \
	  ) \
	);
//u8 zero point helper macros
// Broadcasts the single u8 zero point stored at op_args1[0] to all
// selected lanes of `scr` as f32.
//   scr  : __m512 lvalue that receives the broadcast zero point.
//   mask : lane mask for the broadcast; masked-out lanes are zeroed.
// Expects `post_ops_list_temp` to be in scope at the expansion site.
// The byte is fetched through an int8_t pointer but is zero-extended
// afterwards (cvtepu8), so it is interpreted as unsigned.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define U8_F32_SCALAR_ZP_BCAST(scr,mask) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_cvtepu8_epi32 \
	  ( \
	    _mm_maskz_set1_epi8 \
	    ( \
	      ( mask ), \
	      *( ( int8_t* )post_ops_list_temp->op_args1 ) \
	    ) \
	  ) \
	);
//s8 zero point helper macros
// Broadcasts the single s8 zero point stored at op_args1[0] to all
// selected lanes of `scr` as f32.
//   scr  : __m512 lvalue that receives the broadcast zero point.
//   mask : lane mask for the broadcast; masked-out lanes are zeroed.
// Expects `post_ops_list_temp` to be in scope at the expansion site.
// The byte is sign-extended (cvtepi8) before the int32 -> f32
// conversion.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define S8_F32_SCALAR_ZP_BCAST(scr,mask) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_cvtepi8_epi32 \
	  ( \
	    _mm_maskz_set1_epi8 \
	    ( \
	      ( mask ), \
	      *( ( int8_t* )post_ops_list_temp->op_args1 ) \
	    ) \
	  ) \
	);
//bf16 zero point helper macros
// Broadcasts the single bf16 zero point stored at op_args1[0] to all
// selected lanes of `scr` as f32.
//   scr  : __m512 lvalue that receives the broadcast zero point.
//   mask : lane mask for the broadcast; masked-out lanes are zeroed.
// Expects `post_ops_list_temp` to be in scope at the expansion site.
// The 16-bit value is widened to 32 bits and shifted left by 16 so it
// lands in the upper half of the f32 bit pattern; the integer vector is
// then reinterpreted with _mm512_castsi512_ps rather than a
// compiler-specific vector cast.
// The definition ends without a line continuation so that the line
// following the macro can never be spliced into it.
#define BF16_F32_SCALAR_ZP_BCAST(scr,mask) \
	scr = _mm512_castsi512_ps \
	( \
	  _mm512_sllv_epi32 \
	  ( \
	    _mm512_cvtepi16_epi32 \
	    ( \
	      _mm256_maskz_set1_epi16 \
	      ( \
	        ( mask ), \
	        *( ( bfloat16* )post_ops_list_temp->op_args1 ) \
	      ) \
	    ), _mm512_set1_epi32( 16 ) \
	  ) \
	);
//s32 zero point helper macros
// Broadcasts the single s32 zero point stored at op_args1[0] to all
// selected lanes of `scr` as f32.
//   scr  : __m512 lvalue that receives the broadcast zero point.
//   mask : lane mask for the broadcast; masked-out lanes are zeroed.
// Expects `post_ops_list_temp` to be in scope at the expansion site.
// The definition ends without a line continuation: a trailing
// backslash here would splice the preprocessor conditional that
// follows into the macro body.
#define S32_F32_SCALAR_ZP_BCAST(scr,mask) \
	scr = _mm512_cvtepi32_ps \
	( \
	  _mm512_maskz_set1_epi32 \
	  ( \
	    ( mask ), \
	    *( ( int32_t* ) post_ops_list_temp->op_args1 ) \
	  ) \
	);
#ifdef LPGEMM_BF16_JIT
#define CVT_STORE_F32_BF16_POST_OPS_MASK(ir,jr,reg,mask,m_ind,n_ind)
#else