Fix zero point datatype issue.

Description
 The accuracy check read the zero point with a different datatype than
 the one used when the post-op was created, which led to an accuracy
 issue for the u8s8s32/s8s8s32 APIs with f32/bf16 output types.

 AMD-Internal: [CPUPL-6456]

Change-Id: If8925988841af87cb5687c84aade607967c744fe
This commit is contained in:
Deepak Negi
2025-02-19 11:02:47 +00:00
committed by Nallani Bhaskar
parent a0005c60ce
commit c813bfa609
3 changed files with 26 additions and 18 deletions

View File

@@ -232,15 +232,15 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,uint8_t,int32_t,float,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,uint8_t,int32_t,float,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)

View File

@@ -216,15 +216,15 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,uint8_t,int32_t,float,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,uint8_t,int32_t,float,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
@@ -1162,6 +1162,8 @@ int main( int argc, char** argv )
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32of32)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
@@ -1175,6 +1177,8 @@ int main( int argc, char** argv )
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32obf16)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
@@ -1315,6 +1319,8 @@ int main( int argc, char** argv )
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32obf16)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
@@ -1328,6 +1334,8 @@ int main( int argc, char** argv )
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32of32)
(
fin, fout, stor_order, transa, transb, op_a, op_b,

View File

@@ -771,7 +771,7 @@ void print_matrix_bfloat16
}
}
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ZP_type,C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
(\
ACCUM_type temp_accum,\
@@ -795,7 +795,7 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
( ACCUM_type )min( \
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
*( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
*( ( ZP_type* )( post_op->sum )->zero_point + j_zp ), \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
return out_temp_accum; \