mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Fix zero point datatype issue.
Description Due to different datatype for zero point during post-op creation and accuracy check we see an accuracy issue for u8/s8s8s32 apis with output type f32/bf16. AMD-Internal: [CPUPL-6456] Change-Id: If8925988841af87cb5687c84aade607967c744fe
This commit is contained in:
committed by
Nallani Bhaskar
parent
a0005c60ce
commit
c813bfa609
@@ -232,15 +232,15 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,u8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,u8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,uint8_t,int32_t,float,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,u8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,u8s8s32obf16)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,uint8_t,int32_t,float,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,s8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,s8s8s32obf16)
|
||||
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
|
||||
|
||||
@@ -216,15 +216,15 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,u8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,u8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,uint8_t,int32_t,float,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,u8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,u8s8s32obf16)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int8_t,int32_t,float,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,uint8_t,int32_t,float,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,float,int32_t,float,s8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,bfloat16,int32_t,float,s8s8s32obf16)
|
||||
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
|
||||
@@ -1162,6 +1162,8 @@ int main( int argc, char** argv )
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32of32)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
@@ -1175,6 +1177,8 @@ int main( int argc, char** argv )
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32obf16)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
@@ -1315,6 +1319,8 @@ int main( int argc, char** argv )
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32obf16)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
@@ -1328,6 +1334,8 @@ int main( int argc, char** argv )
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32of32)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
|
||||
@@ -771,7 +771,7 @@ void print_matrix_bfloat16
|
||||
}
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ZP_type,C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
|
||||
(\
|
||||
ACCUM_type temp_accum,\
|
||||
@@ -795,7 +795,7 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
|
||||
( ACCUM_type )min( \
|
||||
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
|
||||
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
|
||||
*( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
|
||||
*( ( ZP_type* )( post_op->sum )->zero_point + j_zp ), \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
return out_temp_accum; \
|
||||
|
||||
Reference in New Issue
Block a user