mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Bug Fixes for Accuracy issues in Int8 API (#62)
- In the U8 GEMV n=1 kernels, the default (fallback) zero-point branch used the S8 ZP type, which led to accuracy issues when the u8s8s32u8 API is used; the fallback now uses the U8 ZP type. - A few modifications were made in the bench code so that it takes the correct path for the accuracy check.
This commit is contained in:
@@ -184,7 +184,6 @@ static inline void lpgemm_free( void* p )
|
||||
bli_free_user(p);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_integerAPI_avx512( char* api_name )
|
||||
{
|
||||
if ( ( strcmp( api_name, "u8s8s32of32" ) == 0) || ( strcmp( api_name, "u8s8s32os8" ) == 0) \
|
||||
@@ -544,7 +543,7 @@ static inline ACCUM_type get_matrix_add_post_op_val_ ## BLAS_SFX \
|
||||
/* default case */ \
|
||||
if( is_integerAPI_avx512(#BLAS_SFX) ) \
|
||||
{ \
|
||||
if( strcmp( #BLAS_SFX, "u8s8s32os8" ) == 0 ) \
|
||||
if( ( strcmp( #BLAS_SFX, "s8s8s32os8" ) == 0 ) || ( strcmp( #BLAS_SFX, "u8s8s32os8" ) ) == 0 ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
int8_t_to_float( *( ( int8_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
|
||||
@@ -1646,7 +1645,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "sf_stor_type" ) == 0) \
|
||||
{ \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = FALSE; \
|
||||
@@ -2230,6 +2229,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->sum )->sf_stor_type = NULLTYPE; \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
|
||||
@@ -764,15 +764,14 @@ LPGEMV_N_EQ1_KERN(uint8_t, int8_t, int32_t, u8s8s32os32)
|
||||
{
|
||||
S32_F32_ZP_BCST(zero_point0)
|
||||
}
|
||||
else if( post_ops_list_temp->zp_stor_type == U8 )
|
||||
{
|
||||
U8_F32_ZP_BCST(zero_point0)
|
||||
}
|
||||
else
|
||||
else if( post_ops_list_temp->zp_stor_type == S8 )
|
||||
{
|
||||
S8_F32_ZP_BCST(zero_point0)
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
U8_F32_ZP_BCST(zero_point0)
|
||||
}
|
||||
MULADD_RND_F32(acc_8, scale0, zero_point0 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
|
||||
Reference in New Issue
Block a user