Bug fix in F32 eltwise API with post-ops (clip, swish, relu_scale).

Description
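1. For the clip, swish, and relu_scale post-ops, the constants were
   previously always loaded as float. However, these constants are
   stored in the C storage type (c_stor_type), so the handling has been
   adjusted: when the C storage type is an integer type (S32, S8, or
   U8), the constants are first loaded as int32 and then converted to
   float; otherwise they are loaded as float, as before.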

Change-Id: I176b805b69679df42be5745b6306f75e23de274d
This commit is contained in:
Deepak Negi
2025-04-09 18:01:37 +05:30
committed by Nallani Bhaskar
parent b9998a1d7f
commit f76f37cc11
4 changed files with 264 additions and 48 deletions

View File

@@ -379,6 +379,14 @@ GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32ou8)
GEN_CLIP_POST_OP_VAL_FLOAT(bf16of32)
GEN_CLIP_POST_OP_VAL_FLOAT(bf16obf16)
GEN_CLIP_POST_OP_VAL_FLOAT(f32of32)
GEN_CLIP_POST_OP_VAL_FLOAT(f32obf16)
GEN_CLIP_POST_OP_VAL_INT(f32os32)
GEN_CLIP_POST_OP_VAL_INT(f32os8)
GEN_CLIP_POST_OP_VAL_INT(f32ou8)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
@@ -521,17 +529,11 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
CLIP ) /* CLIP*/ \
{ \
temp_accum = \
min \
( \
max \
( \
temp_accum, \
*( ( ACCUM_type* ) \
( post_op->eltwise + ele_i )->algo.alpha ) \
), \
*( ( ACCUM_type* ) \
( post_op->eltwise + ele_i )->algo.beta) \
); \
GEN_FUNC_NAME(get_clip_post_op_val_,LP_SFX) \
( temp_accum, \
( post_op->eltwise + ele_i )->algo.alpha, \
( post_op->eltwise + ele_i )->algo.beta \
); \
ele_i += 1; \
} \
else \

View File

@@ -412,6 +412,7 @@ static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
return *( ( ACCUM_type* )post_op_bias_ptr + j ); \
} \
/* GELU TANH */
#define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \
static inline float GELU_TANH_post_op_ ## BLAS_SFX \
( \

View File

@@ -448,8 +448,19 @@ POST_OPS_RELU_5x64_OPS:
POST_OPS_RELU_SCALE_5x64_OPS:
{
zmm1 = _mm512_setzero_ps();
zmm2 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm2 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm2 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
@@ -650,8 +661,21 @@ POST_OPS_GELU_ERF_5x64_OPS:
}
POST_OPS_CLIP_5x64_OPS:
{
__m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
__m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
__m512 min, max;
if( post_ops_attr.c_stor_type == S32 ||
post_ops_attr.c_stor_type == S8 ||
post_ops_attr.c_stor_type == U8 )
{
min = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args2 ) ) );
max = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args3 ) ) );
}
else
{
min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
}
// c[0, 0-15]
CLIP_F32S_AVX512(zmm8, min, max)
@@ -1612,8 +1636,19 @@ POST_OPS_MATRIX_MUL_5x64_OPS:
}
POST_OPS_SWISH_5x64_OPS:
{
zmm1 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm1 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm1 = _mm512_set1_ps
( *( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 al_in, r, r2, z, dn;
__m512i ex_out;
@@ -2476,8 +2511,19 @@ POST_OPS_RELU_4x64_OPS:
POST_OPS_RELU_SCALE_4x64_OPS:
{
zmm1 = _mm512_setzero_ps();
zmm2 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm2 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm2 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
@@ -2642,8 +2688,21 @@ POST_OPS_GELU_ERF_4x64_OPS:
}
POST_OPS_CLIP_4x64_OPS:
{
__m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
__m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
__m512 min, max;
if( post_ops_attr.c_stor_type == S32 ||
post_ops_attr.c_stor_type == S8 ||
post_ops_attr.c_stor_type == U8 )
{
min = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args2 ) ) );
max = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args3 ) ) );
}
else
{
min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
}
// c[0, 0-15]
CLIP_F32S_AVX512(zmm8, min, max)
@@ -3458,8 +3517,19 @@ POST_OPS_MATRIX_MUL_4x64_OPS:
}
POST_OPS_SWISH_4x64_OPS:
{
zmm1 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm1 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm1 = _mm512_set1_ps
( *( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 al_in, r, r2, z, dn;
__m512i ex_out;
@@ -4180,8 +4250,19 @@ POST_OPS_RELU_3x64_OPS:
POST_OPS_RELU_SCALE_3x64_OPS:
{
zmm1 = _mm512_setzero_ps();
zmm2 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm2 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm2 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
@@ -4310,8 +4391,21 @@ POST_OPS_GELU_ERF_3x64_OPS:
}
POST_OPS_CLIP_3x64_OPS:
{
__m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
__m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
__m512 min, max;
if( post_ops_attr.c_stor_type == S32 ||
post_ops_attr.c_stor_type == S8 ||
post_ops_attr.c_stor_type == U8 )
{
min = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args2 ) ) );
max = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args3 ) ) );
}
else
{
min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
}
// c[0, 0-15]
CLIP_F32S_AVX512(zmm8, min, max)
@@ -5011,8 +5105,19 @@ POST_OPS_MATRIX_MUL_3x64_OPS:
}
POST_OPS_SWISH_3x64_OPS:
{
zmm1 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm1 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm1 = _mm512_set1_ps
( *( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 al_in, r, r2, z, dn;
__m512i ex_out;
@@ -5591,8 +5696,19 @@ POST_OPS_RELU_2x64_OPS:
POST_OPS_RELU_SCALE_2x64_OPS:
{
zmm1 = _mm512_setzero_ps();
zmm2 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm2 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm2 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
@@ -5685,8 +5801,21 @@ POST_OPS_GELU_ERF_2x64_OPS:
}
POST_OPS_CLIP_2x64_OPS:
{
__m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
__m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
__m512 min, max;
if( post_ops_attr.c_stor_type == S32 ||
post_ops_attr.c_stor_type == S8 ||
post_ops_attr.c_stor_type == U8 )
{
min = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args2 ) ) );
max = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args3 ) ) );
}
else
{
min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
}
// c[0, 0-15]
CLIP_F32S_AVX512(zmm8, min, max)
@@ -6271,8 +6400,19 @@ POST_OPS_MATRIX_MUL_2x64_OPS:
}
POST_OPS_SWISH_2x64_OPS:
{
zmm1 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm1 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm1 = _mm512_set1_ps
( *( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 al_in, r, r2, z, dn;
__m512i ex_out;
@@ -6709,8 +6849,19 @@ POST_OPS_RELU_1x64_OPS:
POST_OPS_RELU_SCALE_1x64_OPS:
{
zmm1 = _mm512_setzero_ps();
zmm2 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm2 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm2 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
@@ -6767,8 +6918,21 @@ POST_OPS_GELU_ERF_1x64_OPS:
}
POST_OPS_CLIP_1x64_OPS:
{
__m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
__m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
__m512 min, max;
if( post_ops_attr.c_stor_type == S32 ||
post_ops_attr.c_stor_type == S8 ||
post_ops_attr.c_stor_type == U8 )
{
min = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args2 ) ) );
max = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args3 ) ) );
}
else
{
min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
}
// c[0, 0-15]
CLIP_F32S_AVX512(zmm8, min, max)
@@ -7238,9 +7402,20 @@ POST_OPS_MATRIX_MUL_1x64_OPS:
}
POST_OPS_SWISH_1x64_OPS:
{
zmm1 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm1 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm1 = _mm512_set1_ps
( *( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 al_in, r, r2, z, dn;
__m512i ex_out;

View File

@@ -536,8 +536,19 @@ POST_OPS_RELU_6x64_OPS:
POST_OPS_RELU_SCALE_6x64_OPS:
{
zmm1 = _mm512_setzero_ps();
zmm2 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm2 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm2 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
@@ -774,9 +785,24 @@ POST_OPS_GELU_ERF_6x64_OPS:
}
POST_OPS_CLIP_6x64_OPS:
{
__m512 min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
__m512 max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
__m512 min, max;
if( post_ops_attr.c_stor_type == S32 ||
post_ops_attr.c_stor_type == S8 ||
post_ops_attr.c_stor_type == U8 )
{
min = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args2 ) ) );
max = _mm512_cvtepi32_ps
( _mm512_set1_epi32( *( ( int32_t* ) post_ops_list_temp->op_args3 ) ) );
}
else
{
min = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args2 );
max = _mm512_set1_ps( *( float* )post_ops_list_temp->op_args3 );
}
float arr[16];
// c[0, 0-15]
CLIP_F32S_AVX512(zmm8, min, max)
@@ -1857,8 +1883,20 @@ POST_OPS_MATRIX_MUL_6x64_OPS:
}
POST_OPS_SWISH_6x64_OPS:
{
zmm1 =
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
if ( (post_ops_attr.c_stor_type == S32 ) ||
(post_ops_attr.c_stor_type == U8 ) ||
(post_ops_attr.c_stor_type == S8 ) )
{
zmm1 = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
zmm1 = _mm512_set1_ps
( *( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 al_in, r, r2, z, dn;
__m512i ex_out;