mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Modified bench to test different types of post-ops
- Modified bench to support testing of different types of buffers for bias, mat_add and mat_mul post-ops.
- Added support for testing integer APIs with float accumulation type.

Change-Id: I72364e9ad25e6148042b93ec6d152ff82ea03e96
This commit is contained in:
@@ -36,12 +36,6 @@
|
||||
|
||||
#define POST_OPS_STR_LEN 104
|
||||
|
||||
CONVERT_TO_FLOAT(uint8_t)
|
||||
CONVERT_TO_FLOAT(int8_t)
|
||||
CONVERT_TO_FLOAT(int16_t)
|
||||
CONVERT_TO_FLOAT(float)
|
||||
CONVERT_TO_FLOAT(int32_t)
|
||||
|
||||
PRINT_MATRIX(uint8_t)
|
||||
PRINT_MATRIX(int8_t)
|
||||
PRINT_MATRIX(int16_t)
|
||||
@@ -256,11 +250,8 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
|
||||
return out_temp_accum; \
|
||||
}\
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int16_t,float,u8s8s16ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,s8s8s16os8)
|
||||
|
||||
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
|
||||
(
|
||||
@@ -347,14 +338,9 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
|
||||
@@ -640,15 +626,10 @@ static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os32)
|
||||
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -656,15 +637,10 @@ GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32os32)
|
||||
|
||||
GEN_TANH_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -672,15 +648,10 @@ GEN_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os32)
|
||||
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -688,15 +659,10 @@ GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_SWISH_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_SWISH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_SWISH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_SWISH_POSTOP_FLOAT(s8s8s32os32)
|
||||
|
||||
GEN_SWISH_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -704,15 +670,10 @@ GEN_SWISH_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os32)
|
||||
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -720,50 +681,35 @@ GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32obf16)
|
||||
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16s4f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,u8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,u8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32of32)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16s4f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,u8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,u8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16s4f32of32)
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16ou8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
|
||||
@@ -916,8 +862,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
|
||||
(temp_accum, \
|
||||
*( ( ACCUM_type* ) \
|
||||
( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) );\
|
||||
( post_op[bs_i]->eltwise + ele_i )->algo.alpha );\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
|
||||
@@ -975,9 +920,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_add )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_add )->scale_factor_len; \
|
||||
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( *( ( C_type* )( post_op[bs_i]->matrix_add )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
( ( post_op[bs_i]->matrix_add )->matrix, i, \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op[bs_i]->matrix_add)->stor_type ); \
|
||||
} \
|
||||
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_MUL ) \
|
||||
{ \
|
||||
@@ -991,9 +935,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_mul )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_mul )->scale_factor_len; \
|
||||
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
|
||||
( *( ( C_type* )( post_op[bs_i]->matrix_mul )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
(( post_op[bs_i]->matrix_mul )->matrix, i, \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op[bs_i]->matrix_mul)->stor_type ); \
|
||||
} \
|
||||
else \
|
||||
{} \
|
||||
@@ -1031,9 +974,6 @@ cleanup_acc: \
|
||||
return; \
|
||||
} \
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,float,u8s8s16ou8,u8s8s16ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
|
||||
@@ -1041,26 +981,20 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32o
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,f32f32f32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,bf16s4f32obf16,bf16bf16f32obf16)
|
||||
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,u8s8s16os16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,u8s8s32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32f32f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,s8s8s16os16)
|
||||
|
||||
// Hack to fix compiler errors.
|
||||
#define GET_B_TYPE_bf16bf16f32of32 bfloat16
|
||||
#define GET_B_TYPE_u8s8s16os16 int8_t
|
||||
#define GET_B_TYPE_u8s8s32os32 int8_t
|
||||
#define GET_B_TYPE_f32f32f32of32 float
|
||||
#define GET_B_TYPE_s8s8s32os32 int8_t
|
||||
#define GET_B_TYPE_s8s8s16os16 int8_t
|
||||
|
||||
#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX, INT4_REORDER_SFX) \
|
||||
void mat_mul_bench_main_ ## BLAS_SFX \
|
||||
|
||||
@@ -34,20 +34,12 @@
|
||||
|
||||
#include "bench_lpgemm_helpers.h"
|
||||
|
||||
CONVERT_TO_FLOAT(uint8_t)
|
||||
CONVERT_TO_FLOAT(int8_t)
|
||||
CONVERT_TO_FLOAT(int16_t)
|
||||
CONVERT_TO_FLOAT(float)
|
||||
CONVERT_TO_FLOAT(int32_t)
|
||||
|
||||
PRINT_MATRIX(uint8_t)
|
||||
PRINT_MATRIX(int8_t)
|
||||
PRINT_MATRIX(int16_t)
|
||||
PRINT_MATRIX(float)
|
||||
PRINT_MATRIX(int32_t)
|
||||
|
||||
GEN_FILL_ARRAY_FUNC(int8_t)
|
||||
GEN_FILL_ARRAY_FUNC(int16_t)
|
||||
GEN_FILL_ARRAY_FUNC(float)
|
||||
GEN_FILL_ARRAY_FUNC(int32_t)
|
||||
|
||||
@@ -106,9 +98,6 @@ void mat_mul_ ## BLAS_SFX \
|
||||
c, ldc, post_op ); \
|
||||
} \
|
||||
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
|
||||
@@ -121,8 +110,6 @@ GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
@@ -207,9 +194,6 @@ void mat_mul_bench_driver_ ## BLAS_SFX \
|
||||
print_result( XSTR(BLAS_SFX), n_repeats, transa, transb, m, n, k, lda, ldb, ldc, gflops); \
|
||||
} \
|
||||
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
|
||||
@@ -222,8 +206,6 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
@@ -257,9 +239,6 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
|
||||
return out_temp_accum; \
|
||||
}\
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int16_t,float,u8s8s16ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,s8s8s16os8)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
|
||||
@@ -357,12 +336,6 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
@@ -736,20 +709,15 @@ static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32ou8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32obf16)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32of32)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32obf16)
|
||||
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32of32)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32ou8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32of32)
|
||||
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -757,20 +725,15 @@ GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32ou8)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32obf16)
|
||||
GEN_TANH_POSTOP_INT(int32_t,u8s8s32of32)
|
||||
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_TANH_POSTOP_INT(int32_t,s8s8s32obf16)
|
||||
GEN_TANH_POSTOP_INT(int32_t,s8s8s32of32)
|
||||
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32ou8)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32os32)
|
||||
|
||||
GEN_TANH_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -778,20 +741,15 @@ GEN_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32ou8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32obf16)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32of32)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32obf16)
|
||||
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32of32)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32ou8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32of32)
|
||||
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -799,20 +757,15 @@ GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32ou8)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32obf16)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32of32)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32obf16)
|
||||
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32of32)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_SWISH_POSTOP_INT(float,u8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(float,u8s8s32ou8)
|
||||
GEN_SWISH_POSTOP_INT(float,u8s8s32os32)
|
||||
GEN_SWISH_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_SWISH_POSTOP_INT(float,s8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(float,s8s8s32os32)
|
||||
GEN_SWISH_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(s8s8s32of32)
|
||||
|
||||
GEN_SWISH_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -820,20 +773,15 @@ GEN_SWISH_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16ou8)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os16)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32ou8)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32obf16)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32of32)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32obf16)
|
||||
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32of32)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os8)
|
||||
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32ou8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32of32)
|
||||
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32f32f32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32of32)
|
||||
@@ -841,65 +789,85 @@ GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32obf16)
|
||||
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int32_t,u8s8s32ou8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(bfloat16,int32_t,u8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,int32_t,u8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(bfloat16,int32_t,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,int32_t,s8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16s4f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32ou8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32of32)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16bf16f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32ou8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16s4f32of32)
|
||||
|
||||
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(f32f32f32of32)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(bf16bf16f32of32)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(bf16bf16f32obf16)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(bf16s4f32of32)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(bf16s4f32obf16)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(u8s8s32obf16)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(u8s8s32of32)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(s8s8s32obf16)
|
||||
GEN_PRELU_POST_OP_VAL_FLOAT(s8s8s32of32)
|
||||
|
||||
GEN_PRELU_POST_OP_VAL_INT(u8s8s32os8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(u8s8s32ou8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(u8s8s32os32)
|
||||
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os32)
|
||||
|
||||
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(f32f32f32of32)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(bf16bf16f32of32)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(bf16bf16f32obf16)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(bf16s4f32of32)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(bf16s4f32obf16)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(u8s8s32obf16)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(u8s8s32of32)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(s8s8s32obf16)
|
||||
GEN_CLIP_POST_OP_VAL_FLOAT(s8s8s32of32)
|
||||
|
||||
GEN_CLIP_POST_OP_VAL_INT(u8s8s32os8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(u8s8s32ou8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(u8s8s32os32)
|
||||
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os32)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int32_t,u8s8s32ou8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(bfloat16,int32_t,u8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,int32_t,u8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(bfloat16,int32_t,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,int32_t,s8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16s4f32of32)
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32ou8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16ou8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32ou8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
|
||||
@@ -907,12 +875,14 @@ GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int32_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int16_t,int16_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int16_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int16_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,float)
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
|
||||
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,POST_ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
|
||||
void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
( \
|
||||
FILE* fout, \
|
||||
@@ -1007,6 +977,15 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
rs_a, rs_b, cs_a, cs_b, rs_c_ref, cs_c_ref, i, j, k, n, \
|
||||
a_pre_op); \
|
||||
\
|
||||
POST_ACCUM_type post_temp_accum = 0; \
|
||||
if ( is_integerAPI_avx512(#BLAS_SFX) ) \
|
||||
{ \
|
||||
CVT_FUNC_NAME(ACCUM_type,POST_ACCUM_type)(temp_accum, &post_temp_accum); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
post_temp_accum = temp_accum; \
|
||||
} \
|
||||
if ( post_op != NULL ) \
|
||||
{ \
|
||||
dim_t ele_i = 0; \
|
||||
@@ -1014,7 +993,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
{ \
|
||||
if ( post_op->seq_vector[op_id] == BIAS ) \
|
||||
{ \
|
||||
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
|
||||
post_temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
|
||||
( ( post_op->bias )->bias, j, ( post_op->bias )->stor_type ); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == ELTWISE ) \
|
||||
@@ -1022,66 +1001,56 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
PRELU ) /* PReLU*/ \
|
||||
{ \
|
||||
temp_accum = ( temp_accum > 0 ) ? \
|
||||
temp_accum : \
|
||||
( temp_accum * \
|
||||
*( ( ACCUM_type* ) ( post_op->eltwise + ele_i )->algo.alpha ) ); \
|
||||
post_temp_accum = GEN_FUNC_NAME(get_prelu_post_op_val_,BLAS_SFX) \
|
||||
(post_temp_accum, ( post_op->eltwise + ele_i )->algo.alpha ); \
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
GELU_TANH ) /* TANH GeLU*/ \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\
|
||||
post_temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (post_temp_accum);\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
GELU_ERF ) /* ERF GeLU*/ \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\
|
||||
post_temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (post_temp_accum);\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
SWISH ) /* SiLU*/ \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
|
||||
(temp_accum, \
|
||||
*( ( ACCUM_type* ) \
|
||||
( post_op->eltwise + ele_i )->algo.alpha ) );\
|
||||
post_temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
|
||||
(post_temp_accum, \
|
||||
( post_op->eltwise + ele_i )->algo.alpha );\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
RELU ) /* ReLU*/ \
|
||||
{ \
|
||||
temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \
|
||||
post_temp_accum = ( post_temp_accum > 0 ) ? post_temp_accum : 0 ; \
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
TANH ) /* TANH*/ \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(TANH_post_op_,BLAS_SFX) (temp_accum);\
|
||||
post_temp_accum = GEN_FUNC_NAME(TANH_post_op_,BLAS_SFX) (post_temp_accum);\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
SIGMOID ) /* Sigmoid*/ \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(SIGMOID_post_op_,BLAS_SFX) (temp_accum);\
|
||||
post_temp_accum = GEN_FUNC_NAME(SIGMOID_post_op_,BLAS_SFX) (post_temp_accum);\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
CLIP ) /* CLIP*/ \
|
||||
{ \
|
||||
temp_accum = \
|
||||
min \
|
||||
( \
|
||||
max \
|
||||
( \
|
||||
temp_accum, \
|
||||
*( ( ACCUM_type* ) \
|
||||
( post_op->eltwise + ele_i )->algo.alpha ) \
|
||||
), \
|
||||
*( ( ACCUM_type* ) \
|
||||
( post_op->eltwise + ele_i )->algo.beta) \
|
||||
); \
|
||||
post_temp_accum = GEN_FUNC_NAME(get_clip_post_op_val_,BLAS_SFX) \
|
||||
( post_temp_accum, \
|
||||
( post_op->eltwise + ele_i )->algo.alpha, \
|
||||
( post_op->eltwise + ele_i )->algo.beta \
|
||||
); \
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else \
|
||||
@@ -1089,8 +1058,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
|
||||
(temp_accum, post_op, j); \
|
||||
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
|
||||
(post_temp_accum, post_op, j); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
@@ -1101,31 +1070,11 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
cs_m = rs_m; \
|
||||
rs_m = 1; \
|
||||
} \
|
||||
const char* api_name = #BLAS_SFX; \
|
||||
/*The C_type is defined as f32 (32-bit floating point),while the ACCUM_type
|
||||
is defined as s32 (32-bit signed integer). This type mismatch can lead to
|
||||
unexpected behavior or random values during conversions.To address this issue,
|
||||
we implement specific handling for the following APIs. */ \
|
||||
\
|
||||
if ( ( strcmp( api_name, "u8s8s32of32" ) == 0) || ( strcmp( api_name, "s8s8s32of32" ) == 0) \
|
||||
|| ( strcmp( api_name, "u8s8s32obf16" ) == 0) || ( strcmp( api_name, "s8s8s32obf16" ) == 0) ) \
|
||||
{ \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
|
||||
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( *( ( ACCUM_type* )( post_op->matrix_add )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
|
||||
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( *( ( C_type* )( post_op->matrix_add )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
} \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
|
||||
post_temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( ( post_op->matrix_add )->matrix, i, \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_add )->stor_type ); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_MUL ) \
|
||||
{ \
|
||||
@@ -1136,40 +1085,20 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
cs_m = rs_m; \
|
||||
rs_m = 1; \
|
||||
} \
|
||||
const char* api_name = #BLAS_SFX; \
|
||||
/*The C_type is defined as f32 (32-bit floating point),while the ACCUM_type
|
||||
is defined as s32 (32-bit signed integer). This type mismatch can lead to
|
||||
unexpected behavior or random values during conversions.To address this issue,
|
||||
we implement specific handling for the following APIs. */ \
|
||||
\
|
||||
if ( ( strcmp( api_name, "u8s8s32of32" ) == 0) || ( strcmp( api_name, "s8s8s32of32" ) == 0) \
|
||||
|| ( strcmp( api_name, "u8s8s32obf16" ) == 0) || ( strcmp( api_name, "s8s8s32obf16" ) == 0) ) \
|
||||
{ \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
|
||||
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
|
||||
( *( ( ACCUM_type* )( post_op->matrix_mul )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
|
||||
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
|
||||
( *( ( C_type* )( post_op->matrix_mul )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
} \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
|
||||
post_temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
|
||||
( ( post_op->matrix_mul )->matrix, i, \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_mul )->stor_type ); \
|
||||
} \
|
||||
else \
|
||||
{} \
|
||||
} \
|
||||
} \
|
||||
/* Need to convert to downscaled type if required.*/ \
|
||||
mat_mul_get_output_type_val ## ACCUM_type ## C_type \
|
||||
mat_mul_get_output_type_val ## POST_ACCUM_type ## C_type \
|
||||
( \
|
||||
&out_temp_accum, &temp_accum \
|
||||
&out_temp_accum, &post_temp_accum \
|
||||
); \
|
||||
\
|
||||
float comp_float, ref_float; \
|
||||
@@ -1198,56 +1127,50 @@ cleanup_acc: \
|
||||
return; \
|
||||
} \
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,float,u8s8s16ou8,u8s8s16ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,float,u8s8s32os32,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,float,u8s8s32os8,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,float,float,u8s8s32ou8,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,bfloat16,int32_t,float,float,u8s8s32obf16,u8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,float,int32_t,float,float,u8s8s32of32,u8s8s32of32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,float,u8s8s32ou8,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,bfloat16,int32_t,float,u8s8s32obf16,u8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,float,int32_t,float,u8s8s32of32,u8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,float,bf16s4f32obf16,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,float,bf16bf16f32obf16,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,float,f32f32f32of32,f32f32f32of32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,bf16s4f32obf16,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32obf16,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,f32f32f32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,int32_t,float,s8s8s32os8,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,int32_t,float,s8s8s32obf16,s8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,float,int32_t,float,int32_t,s8s8s32of32,s8s8s32of32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,float,s8s8s32obf16,s8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,float,int32_t,float,s8s8s32of32,s8s8s32of32)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,u8s8s16os16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,s8s8s16os16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,u8s8s32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,int32_t,u8s8s32os8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,u8s8s32ou8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,u8s8s32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,u8s8s32obf16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,float,u8s8s32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,float,u8s8s32ou8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,float,u8s8s32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,float,u8s8s32obf16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os8)
|
||||
//GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32of32)
|
||||
//GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32obf16)
|
||||
//GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32ou8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,s8s8s32obf16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16bf16f32obf16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16s4f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16s4f32obf16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32f32f32of32)
|
||||
|
||||
|
||||
|
||||
// Hack to fix compiler errors.
|
||||
#define GET_B_TYPE_bf16bf16f32of32 bfloat16
|
||||
#define GET_B_TYPE_u8s8s16os16 int8_t
|
||||
#define GET_B_TYPE_u8s8s32os32 int8_t
|
||||
#define GET_B_TYPE_f32f32f32of32 float
|
||||
#define GET_B_TYPE_s8s8s32os32 int8_t
|
||||
#define GET_B_TYPE_s8s8s16os16 int8_t
|
||||
|
||||
#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX, INT4_REORDER_SFX) \
|
||||
void mat_mul_bench_main_ ## BLAS_SFX \
|
||||
@@ -1337,7 +1260,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
|
||||
( strcmp( post_ops_str, "none" ) != 0 ) ) || \
|
||||
( global_dscale_out == 'y' ) || ( global_pre_op == 'y' ) ) \
|
||||
{ \
|
||||
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, k, post_ops_str, stor_order ); \
|
||||
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,BLAS_SFX)( m, n, k, post_ops_str, stor_order ); \
|
||||
if ( post_op == NULL ) \
|
||||
{ \
|
||||
printf(" post op struct allocation failure, returning.\n"); \
|
||||
@@ -1403,6 +1326,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
|
||||
if ( bench_mode == 'a' ) \
|
||||
{ \
|
||||
printf(" Running accuracy check.\n"); \
|
||||
fflush(stdout); \
|
||||
GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \
|
||||
( \
|
||||
fout, stor_order, transa, transb, m, n, k, \
|
||||
@@ -1426,9 +1350,6 @@ void mat_mul_bench_main_ ## BLAS_SFX \
|
||||
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16,u8s8s16os16,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8,u8s8s16os16,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8,u8s8s16os16,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8,u8s8s32os32,u8s4s32os32)
|
||||
@@ -1439,8 +1360,6 @@ GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os3
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16,s8s8s16os16,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8,s8s8s16os16,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16,bf16bf16f32of32,bf16s4f32of32)
|
||||
|
||||
@@ -1481,11 +1400,7 @@ int main( int argc, char** argv )
|
||||
" 5. s8s8s32os32 -d s8 = s8s8s32os8.\n" \
|
||||
" 6. s8s8s32os32 -d f32 = s8s8s32of32.\n" \
|
||||
" 7. s8s8s32os32 -d bf16 = s8s8s32obf16.\n" \
|
||||
" 8. s8s8s16os16 -d s8 = s8s8s16os8.\n" \
|
||||
" 9. s8s8s16os16 -d u8 = u8s8s16ou8.\n" \
|
||||
" 10. u8s8s16os16 -d s8 = u8s8s16os8.\n" \
|
||||
" 11. u8s8s16os16 -d u8 = u8s8s16ou8.\n" \
|
||||
" 12. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
|
||||
" 8. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
|
||||
" Example: ./bench_lpgemm -m a -n 2 -o bias,relu -d bf16 -i input.txt\n" \
|
||||
);
|
||||
exit( 1 );
|
||||
@@ -1735,53 +1650,6 @@ int main( int argc, char** argv )
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
#if 0
|
||||
if ( ( strcmp( gemm_type_str, "u8s8s16os16" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = SHRT_MIN;
|
||||
DSCALE_CLIP_MAX = SHRT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "u8s8s16os8" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = -128;
|
||||
DSCALE_CLIP_MAX = +127;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "u8s8s16ou8" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = 0;
|
||||
DSCALE_CLIP_MAX = +255;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16ou8)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
#endif
|
||||
if ( ( strcmp( gemm_type_str, "bf16bf16f32of32" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
@@ -1906,38 +1774,6 @@ int main( int argc, char** argv )
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
#if 0
|
||||
if ( ( strcmp( gemm_type_str, "s8s8s16os16" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = SHRT_MIN;
|
||||
DSCALE_CLIP_MAX = SHRT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "s8s8s16os8" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = -128;
|
||||
DSCALE_CLIP_MAX = +127;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,8 +36,6 @@
|
||||
|
||||
GEN_FILL_ARRAY_FUNC(float)
|
||||
|
||||
CONVERT_TO_FLOAT(float)
|
||||
|
||||
void print_result
|
||||
(
|
||||
const char* msg,
|
||||
@@ -195,13 +193,13 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32of32)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32of32)
|
||||
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
|
||||
|
||||
@@ -316,8 +314,7 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,LP_SFX) \
|
||||
(temp_accum, \
|
||||
*( ( ACCUM_type* ) \
|
||||
( post_op->eltwise + ele_i )->algo.alpha ) );\
|
||||
( post_op->eltwise + ele_i )->algo.alpha );\
|
||||
ele_i += 1; \
|
||||
} \
|
||||
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
|
||||
@@ -375,9 +372,8 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
|
||||
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,LP_SFX) \
|
||||
( *( ( B_type* )( post_op->matrix_add )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
( ( post_op->matrix_add )->matrix, i, \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len,( post_op->matrix_add )->stor_type ); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_MUL ) \
|
||||
{ \
|
||||
@@ -391,9 +387,8 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
|
||||
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,LP_SFX) \
|
||||
( *( ( B_type* )( post_op->matrix_mul )->matrix + \
|
||||
( i * rs_m ) + ( j * cs_m ) ), \
|
||||
j, scl_fctr, scl_fctr_len ); \
|
||||
( ( post_op->matrix_mul )->matrix, i, \
|
||||
/* MATRIX_MUL must read the storage type of its own operand; the matrix */ \
/* pointer and scale factors above are taken from post_op->matrix_mul too. */ \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_mul )->stor_type ); \
|
||||
} \
|
||||
else \
|
||||
{} \
|
||||
|
||||
@@ -69,7 +69,7 @@ char global_pre_op = 'n';
|
||||
/* Forwards its argument to _XSTR. Because `str` is not an operand of # or ##
   here, it is fully macro-expanded before _XSTR receives it.
   NOTE(review): _XSTR is defined outside this view - presumably it applies
   the # stringify operator; confirm. */
#define XSTR(str) _XSTR(str)
|
||||
|
||||
/* Token-pastes `prototype` and `ctype` into a single identifier, e.g.
   GEN_FUNC_NAME(get_bias_post_op_val_,u8s8s32os32) yields
   get_bias_post_op_val_u8s8s32os32. The arguments are operands of ## and are
   therefore pasted as written, without prior macro expansion. */
#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype
|
||||
|
||||
/* Builds the identifier <stype>_to_<dtype>, e.g.
   CVT_FUNC_NAME(int32_t,float) yields int32_t_to_float. The arguments are
   operands of ## and are pasted as written, without prior macro expansion. */
#define CVT_FUNC_NAME(stype, dtype) stype ## _to_ ## dtype
|
||||
// Inplace to lower func.
|
||||
static inline void str_tolower( char* str )
|
||||
{
|
||||
@@ -83,6 +83,22 @@ static inline void GEN_FUNC_NAME(ctype,_to_float) ( ctype val, float* float_val
|
||||
*float_val = (float) val; \
|
||||
} \
|
||||
|
||||
CONVERT_TO_FLOAT(uint8_t)
|
||||
CONVERT_TO_FLOAT(int8_t)
|
||||
CONVERT_TO_FLOAT(int16_t)
|
||||
CONVERT_TO_FLOAT(float)
|
||||
CONVERT_TO_FLOAT(int32_t)
|
||||
|
||||
|
||||
/* Generates the identity converter <ctype>_to_<ctype>( src, dst ), which
   stores `src` unchanged through `dst`. It gives same-type conversions the
   same name shape and call signature as the other <stype>_to_<dtype>
   helpers, so they can be selected through CVT_FUNC_NAME. */
#define CONVERT_ITSELF(ctype) \
static inline void CVT_FUNC_NAME(ctype,ctype) ( ctype src, ctype* dst ) \
{ \
	dst[0] = src; \
}
|
||||
|
||||
CONVERT_ITSELF(int16_t)
|
||||
CONVERT_ITSELF(int32_t)
|
||||
|
||||
static inline void float_to_bf16( float* float_value, bfloat16* bf16_val )
|
||||
{
|
||||
/*Set offset 2 to copy most significant 2 bytes of float
|
||||
@@ -169,6 +185,43 @@ static inline void lpgemm_free( void* p )
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Reports whether api_name is one of the u8s8s32 LPGEMM variants:
 * u8s8s32of32, u8s8s32os8, u8s8s32obf16, u8s8s32os32 or u8s8s32ou8.
 * Only these five names are recognised; any other string (including the
 * s8s8s32 variants) yields FALSE.
 *
 * api_name must be a non-NULL, NUL-terminated string; matching is exact
 * and case-sensitive (strcmp).
 */
bool is_integerAPI_avx512( char* api_name )
{
    static const char* const u8s8s32_apis[] =
    {
        "u8s8s32of32",
        "u8s8s32os8",
        "u8s8s32obf16",
        "u8s8s32os32",
        "u8s8s32ou8"
    };
    const size_t num_apis = sizeof( u8s8s32_apis ) / sizeof( u8s8s32_apis[0] );

    for ( size_t idx = 0; idx < num_apis; ++idx )
    {
        if ( strcmp( api_name, u8s8s32_apis[idx] ) == 0 )
        {
            return TRUE;
        }
    }
    return FALSE;
}
|
||||
/*
 * Reports whether `type` spells one of the integer C type names handled by
 * the bench: "int8_t", "int16_t", "int32_t" or "uint8_t".
 *
 * type must be a non-NULL, NUL-terminated string; matching is exact and
 * case-sensitive (strcmp). Returns TRUE on a match, FALSE otherwise.
 */
bool is_integer( char* type )
{
    if ( strcmp( type, "int8_t" ) == 0 )
    {
        return TRUE;
    }
    if ( strcmp( type, "int16_t" ) == 0 )
    {
        return TRUE;
    }
    if ( strcmp( type, "int32_t" ) == 0 )
    {
        return TRUE;
    }
    if ( strcmp( type, "uint8_t" ) == 0 )
    {
        return TRUE;
    }
    return FALSE;
}
|
||||
/*
 * Reports whether api_name is one of the bf16-input LPGEMM variants:
 * bf16bf16f32of32, bf16bf16f32obf16, bf16s4f32of32 or bf16s4f32obf16.
 *
 * api_name must be a non-NULL, NUL-terminated string; matching is exact
 * and case-sensitive (strcmp). Returns TRUE on a match, FALSE otherwise.
 */
bool is_bf16API_avx512( char* api_name )
{
    /* Each strcmp() result has to be compared against 0. strcmp() returns
       non-zero on a mismatch, so a bare strcmp() used as a condition is
       true for every name except the one being tested. */
    if ( ( strcmp( api_name, "bf16bf16f32of32" ) == 0 ) ||
         ( strcmp( api_name, "bf16bf16f32obf16" ) == 0 ) ||
         ( strcmp( api_name, "bf16s4f32of32" ) == 0 ) ||
         ( strcmp( api_name, "bf16s4f32obf16" ) == 0 ) )
    {
        return TRUE;
    }

    return FALSE;
}
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
/* Matrix fill helper macros. */
|
||||
#define GEN_FILL_ARRAY_FUNC(ctype) \
|
||||
@@ -237,6 +290,61 @@ static inline void fill_array_post_ops_bfloat16( void* arr, dim_t size )
|
||||
}
|
||||
|
||||
/* POST-OPS Helper macros. */
|
||||
/* CLIP */
|
||||
/* Generates get_clip_post_op_val_<BLAS_SFX>( x, alpha_ptr, beta_ptr ) for
   post-ops whose clip bounds are stored as int32_t. Both bounds are read
   through the void pointers as int32_t, converted to float with
   int32_t_to_float, and the result is min( max( x, alpha ), beta ) - i.e.
   alpha acts as the lower bound and beta as the upper bound. */
#define GEN_CLIP_POST_OP_VAL_INT(BLAS_SFX) \
static inline float get_clip_post_op_val_ ## BLAS_SFX \
     ( \
       float post_temp_accum, \
       void* post_op_alpha_ptr, \
       void* post_op_beta_ptr \
     ) \
{ \
    float lower_bound; \
    float upper_bound; \
    int32_t_to_float( *( ( int32_t* )post_op_alpha_ptr ), &lower_bound ); \
    int32_t_to_float( *( ( int32_t* )post_op_beta_ptr ), &upper_bound ); \
    float floored = max( post_temp_accum, lower_bound ); \
    return min( floored, upper_bound ); \
}
|
||||
|
||||
/* Generates get_clip_post_op_val_<BLAS_SFX>( x, alpha_ptr, beta_ptr ) for
   post-ops whose clip bounds are stored as float. The bounds are read
   through the void pointers as float and the result is
   min( max( x, alpha ), beta ) - i.e. alpha acts as the lower bound and
   beta as the upper bound. */
#define GEN_CLIP_POST_OP_VAL_FLOAT(BLAS_SFX) \
static inline float get_clip_post_op_val_ ## BLAS_SFX \
     ( \
       float post_temp_accum, \
       void* post_op_alpha_ptr, \
       void* post_op_beta_ptr \
     ) \
{ \
    float lower_bound = *( ( float* )post_op_alpha_ptr ); \
    float upper_bound = *( ( float* )post_op_beta_ptr ); \
    float floored = max( post_temp_accum, lower_bound ); \
    return min( floored, upper_bound ); \
}
|
||||
|
||||
/* PRELU */
|
||||
/* Generates get_prelu_post_op_val_<BLAS_SFX>( x, alpha_ptr ) for post-ops
   whose PReLU alpha is stored as float: returns x unchanged when x > 0,
   otherwise x multiplied by the float read through alpha_ptr. The pointer
   is only dereferenced on the non-positive path. */
#define GEN_PRELU_POST_OP_VAL_FLOAT(BLAS_SFX) \
static inline float get_prelu_post_op_val_ ## BLAS_SFX \
     ( \
       float post_temp_accum, \
       void* post_op_alpha_ptr \
     ) \
{ \
    if ( post_temp_accum > 0 ) \
    { \
        return post_temp_accum; \
    } \
    return post_temp_accum * ( *( float* )post_op_alpha_ptr ); \
}
|
||||
|
||||
/* Generates get_prelu_post_op_val_<BLAS_SFX>( x, alpha_ptr ) for post-ops
   whose PReLU alpha is stored as int32_t. Alpha is always read through the
   pointer as int32_t and converted to float with int32_t_to_float; the
   result is x when x > 0, otherwise x multiplied by the converted alpha. */
#define GEN_PRELU_POST_OP_VAL_INT(BLAS_SFX) \
static inline float get_prelu_post_op_val_ ## BLAS_SFX \
     ( \
       float post_temp_accum, \
       void* post_op_alpha_ptr \
     ) \
{ \
    float neg_slope; \
    int32_t_to_float( *( ( int32_t* )post_op_alpha_ptr ), &neg_slope ); \
 \
    if ( post_temp_accum > 0 ) \
    { \
        return post_temp_accum; \
    } \
    return post_temp_accum * neg_slope; \
}
|
||||
|
||||
/* Bias. */
|
||||
#define GEN_GET_BIAS_POST_OP_VAL_BF16(BLAS_SFX) \
|
||||
@@ -283,39 +391,25 @@ static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
|
||||
{ \
|
||||
if(bias_stor_type == AOCL_GEMM_BF16) \
|
||||
{ \
|
||||
int32_t ret_val = 0.0; \
|
||||
bfloat16_to_int32_t( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
float ret_val = 0.0; \
|
||||
bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
return ret_val; \
|
||||
} \
|
||||
if(bias_stor_type == AOCL_GEMM_INT8) \
|
||||
{ \
|
||||
int32_t ret_val = 0.0; \
|
||||
int8_t_to_int32_t( *( ( int8_t* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
float ret_val = 0.0; \
|
||||
int8_t_to_float( *( ( int8_t* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
return ret_val; \
|
||||
} \
|
||||
if(bias_stor_type == AOCL_GEMM_F32) \
|
||||
if(bias_stor_type == AOCL_GEMM_INT32) \
|
||||
{ \
|
||||
int32_t ret_val = 0.0; \
|
||||
ret_val = (int32_t) *( ( float* )post_op_bias_ptr + j ); \
|
||||
float ret_val = 0.0; \
|
||||
int32_t_to_float( *( ( int32_t* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
return ret_val; \
|
||||
} \
|
||||
return *( ( ACCUM_type* )post_op_bias_ptr + j ); \
|
||||
} \
|
||||
|
||||
/* GELU Tanh. */
|
||||
#define GEN_GELU_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
ACCUM_type temp_accum \
|
||||
) \
|
||||
{ \
|
||||
float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \
|
||||
( 0.044715 * ((double)temp_accum * (double)temp_accum * \
|
||||
(double)temp_accum ) ) ) ) ); \
|
||||
temp_accum = round (gelu_reference); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
#define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \
|
||||
static inline float GELU_TANH_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
@@ -329,17 +423,6 @@ static inline float GELU_TANH_post_op_ ## BLAS_SFX \
|
||||
} \
|
||||
|
||||
/* GELU Erf. */
|
||||
#define GEN_GELU_ERF_POSTOP_INT(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
ACCUM_type temp_accum \
|
||||
) \
|
||||
{ \
|
||||
float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \
|
||||
temp_accum = round (gelu_reference); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
#define GEN_GELU_ERF_POSTOP_FLOAT(BLAS_SFX) \
|
||||
static inline float GELU_ERF_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
@@ -351,17 +434,6 @@ static inline float GELU_ERF_post_op_ ## BLAS_SFX \
|
||||
} \
|
||||
|
||||
/* TANH. */
|
||||
#define GEN_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type TANH_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
ACCUM_type temp_accum \
|
||||
) \
|
||||
{ \
|
||||
float tanh_reference = tanhf( ( double )temp_accum ); \
|
||||
temp_accum = round( tanh_reference ); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
#define GEN_TANH_POSTOP_FLOAT(BLAS_SFX) \
|
||||
static inline float TANH_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
@@ -373,18 +445,6 @@ static inline float TANH_post_op_ ## BLAS_SFX \
|
||||
} \
|
||||
|
||||
/* SIGMOID. */
|
||||
#define GEN_SIGMOID_POSTOP_INT(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type SIGMOID_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
ACCUM_type temp_accum \
|
||||
) \
|
||||
{ \
|
||||
float sigmoid_reference = ( 1 / ( 1 + \
|
||||
(dim_t) round( expf( temp_accum * -1 ) ) ) ); \
|
||||
temp_accum = round (sigmoid_reference); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
#define GEN_SIGMOID_POSTOP_FLOAT(BLAS_SFX) \
|
||||
static inline float SIGMOID_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
@@ -401,11 +461,13 @@ static inline float SIGMOID_post_op_ ## BLAS_SFX \
|
||||
static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
ACCUM_type temp_accum, \
|
||||
ACCUM_type alpha \
|
||||
void* alpha \
|
||||
) \
|
||||
{ \
|
||||
float alpha_val; \
|
||||
int32_t_to_float(*( ( int32_t* )alpha), &alpha_val); \
|
||||
float swish_reference = ( temp_accum / ( 1 + \
|
||||
expf( ( double )alpha * temp_accum * -1 ) ) ); \
|
||||
expf( ( double )(alpha_val) * temp_accum * -1 ) ) ); \
|
||||
temp_accum = round (swish_reference); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
@@ -414,42 +476,25 @@ static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \
|
||||
static inline float SWISH_post_op_ ## BLAS_SFX \
|
||||
( \
|
||||
float temp_accum, \
|
||||
float alpha \
|
||||
void* alpha \
|
||||
) \
|
||||
{ \
|
||||
temp_accum = ( temp_accum / ( 1 + \
|
||||
expf( ( double )alpha * temp_accum * -1 ) ) ); \
|
||||
expf( ( double )(*((float*)alpha) * temp_accum * -1 ) ) )); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
/* Matrix Add. */
|
||||
#define GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(C_type,BLAS_SFX) \
|
||||
static inline float get_matrix_add_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
C_type val, \
|
||||
dim_t j, \
|
||||
float* scl_fctr, \
|
||||
dim_t scl_fctr_len \
|
||||
) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
dim_t j_scale = j; \
|
||||
if ( scl_fctr_len == 1 ) \
|
||||
{ \
|
||||
j_scale = 0; \
|
||||
} \
|
||||
\
|
||||
bfloat16_to_float( val, &ret_val ); \
|
||||
return ( ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
|
||||
#define GEN_GET_MATRIX_ADD_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \
|
||||
#define GEN_GET_MATRIX_ADD_POST_OP_VAL(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type get_matrix_add_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
C_type val, \
|
||||
void* mat_add_ptr, \
|
||||
dim_t i, \
|
||||
dim_t j, \
|
||||
dim_t rs_m, \
|
||||
dim_t cs_m, \
|
||||
float* scl_fctr, \
|
||||
dim_t scl_fctr_len \
|
||||
dim_t scl_fctr_len, \
|
||||
AOCL_PARAMS_STORAGE_TYPES matadd_stor_type \
|
||||
) \
|
||||
{ \
|
||||
dim_t j_scale = j; \
|
||||
@@ -457,36 +502,94 @@ static inline ACCUM_type get_matrix_add_post_op_val_ ## BLAS_SFX \
|
||||
{ \
|
||||
j_scale = 0; \
|
||||
} \
|
||||
return (ACCUM_type) ( ( float )val * *( scl_fctr + j_scale ) ); \
|
||||
if( matadd_stor_type == AOCL_GEMM_BF16 ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
bfloat16 val = *( ( bfloat16* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
|
||||
bfloat16_to_float( val, &ret_val ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
if( matadd_stor_type == AOCL_GEMM_INT8 ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
int8_t_to_float( *( ( int8_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
if( matadd_stor_type == AOCL_GEMM_INT32 ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
int32_t_to_float( *( ( int32_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
if( matadd_stor_type == AOCL_GEMM_F32 ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
ret_val = *( ( float* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
/* default case */ \
|
||||
if( is_integerAPI_avx512(#BLAS_SFX) ) \
|
||||
{ \
|
||||
if( strcmp( #BLAS_SFX, "u8s8s32os8" ) == 0 ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
int8_t_to_float( *( ( int8_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
float ret_val = 0.0; \
|
||||
int32_t_to_float( *( ( int32_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
bfloat16 val = *( ( bfloat16* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
|
||||
bfloat16_to_float( val, &ret_val ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
float ret_val = 0.0; \
|
||||
ret_val = *( ( float* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
|
||||
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(C_type,BLAS_SFX) \
|
||||
#define GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(BLAS_SFX) \
|
||||
static inline float get_matrix_mul_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
C_type val, \
|
||||
void* mat_add_ptr, \
|
||||
dim_t i, \
|
||||
dim_t j, \
|
||||
dim_t rs_m, \
|
||||
dim_t cs_m, \
|
||||
float* scl_fctr, \
|
||||
dim_t scl_fctr_len \
|
||||
dim_t scl_fctr_len, \
|
||||
AOCL_PARAMS_STORAGE_TYPES matadd_stor_type \
|
||||
) \
|
||||
{ \
|
||||
return GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( \
|
||||
val, j, scl_fctr, scl_fctr_len \
|
||||
mat_add_ptr, i, j, rs_m, cs_m, scl_fctr, scl_fctr_len, matadd_stor_type \
|
||||
); \
|
||||
} \
|
||||
|
||||
#define GEN_GET_MATRIX_MUL_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \
|
||||
#define GEN_GET_MATRIX_MUL_POST_OP_VAL(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type get_matrix_mul_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
C_type val, \
|
||||
void* mat_add_ptr, \
|
||||
dim_t i, \
|
||||
dim_t j, \
|
||||
dim_t rs_m, \
|
||||
dim_t cs_m, \
|
||||
float* scl_fctr, \
|
||||
dim_t scl_fctr_len \
|
||||
dim_t scl_fctr_len, \
|
||||
AOCL_PARAMS_STORAGE_TYPES matadd_stor_type \
|
||||
) \
|
||||
{ \
|
||||
return GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( \
|
||||
val, j, scl_fctr, scl_fctr_len \
|
||||
mat_add_ptr, i, j, rs_m, cs_m, scl_fctr, scl_fctr_len, matadd_stor_type \
|
||||
); \
|
||||
} \
|
||||
|
||||
@@ -586,6 +689,7 @@ static inline void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops )
|
||||
if ( post_ops->matrix_add != NULL )
|
||||
{
|
||||
free( ( post_ops->matrix_add )->matrix );
|
||||
//free( ( post_ops->matrix_add )->scale_factor );
|
||||
free( post_ops->matrix_add );
|
||||
}
|
||||
|
||||
@@ -599,6 +703,7 @@ static inline void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops )
|
||||
if ( post_ops->matrix_mul != NULL )
|
||||
{
|
||||
free( ( post_ops->matrix_mul )->matrix );
|
||||
//free( ( post_ops->matrix_mul )->scale_factor );
|
||||
free( post_ops->matrix_mul );
|
||||
}
|
||||
|
||||
@@ -770,6 +875,10 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
dim_t activator_idx = 0; \
|
||||
dim_t clip_idx = 0; \
|
||||
char * bias_stor_type = ""; \
|
||||
bool is_matadd_stor_type = FALSE; \
|
||||
char* matadd_stor_type = ""; \
|
||||
bool is_matmul_stor_type = FALSE; \
|
||||
char* matmul_stor_type = ""; \
|
||||
bool is_group_quant = FALSE; \
|
||||
bool is_pre_op_scale_scalar = FALSE; \
|
||||
bool is_pre_op_scale_f32 = TRUE; \
|
||||
@@ -899,12 +1008,62 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matadd_stor_type = FALSE; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matadd_stor_type = TRUE; \
|
||||
matadd_stor_type = "F32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matadd_stor_type = TRUE; \
|
||||
matadd_stor_type = "BF16"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matadd_stor_type = TRUE; \
|
||||
matadd_stor_type = "S32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matadd_stor_type = TRUE; \
|
||||
matadd_stor_type = "S8"; \
|
||||
} \
|
||||
is_matrix_add = TRUE; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matmul_stor_type = FALSE; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matmul_stor_type = TRUE; \
|
||||
matmul_stor_type = "F32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matmul_stor_type = TRUE; \
|
||||
matmul_stor_type = "BF16"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matmul_stor_type = TRUE; \
|
||||
matmul_stor_type = "S32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
|
||||
{ \
|
||||
is_matmul_stor_type = TRUE; \
|
||||
matmul_stor_type = "S8"; \
|
||||
} \
|
||||
is_matrix_mul = TRUE; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
@@ -988,32 +1147,50 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
\
|
||||
if ( is_bias == TRUE ) \
|
||||
{ \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
if(is_bias_stor_type == TRUE) \
|
||||
{ \
|
||||
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( bfloat16 ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,bfloat16)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( float ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,float)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "S8" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_INT8; \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( int8_t ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,int8_t)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "S32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_INT32; \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( int32_t ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,int32_t)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else {} \
|
||||
@@ -1021,17 +1198,13 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = NULLTYPE; \
|
||||
if( global_dscale_out == 'y') \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( BIAS_type ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
if ( strcmp(#BIAS_type, "bfloat16") == 0 ) { \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,BIAS_type)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,C_type)( ( post_ops->bias )->bias, n ); \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,BIAS_type)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
@@ -1078,12 +1251,27 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
/* If output is float/bfloat16, param type will be float otherwise s32 */ \
|
||||
if( is_integer(#C_type) ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( int32_t ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( int32_t* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( int32_t )6; \
|
||||
} \
|
||||
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( float ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( float* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( float )6; \
|
||||
} \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
|
||||
} \
|
||||
@@ -1092,12 +1280,25 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
/* If output is float/bfloat16, params type will be float otherwise s32 */ \
|
||||
if( is_integer(#C_type) ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( int32_t ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( int32_t* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( int32_t )2; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( float ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( float* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( float )2; \
|
||||
} \
|
||||
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \
|
||||
} \
|
||||
@@ -1123,18 +1324,37 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
( post_ops->eltwise + clip_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + clip_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
|
||||
/* If output is float/bfloat16, params type will be float otherwise s32 */ \
|
||||
if( is_integer(#C_type) ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( int32_t ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( int32_t ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( int32_t* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( int32_t ) ( -64 ); \
|
||||
*( ( int32_t* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( int32_t ) ( 23 ); \
|
||||
} \
|
||||
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( C_type ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
|
||||
else \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( float ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( float ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( float* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( float ) ( -64 ); \
|
||||
*( ( float* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( float ) ( 23 ); \
|
||||
} \
|
||||
*( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( C_type ) ( -64 ); \
|
||||
*( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( C_type ) ( 23 ); \
|
||||
( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \
|
||||
} \
|
||||
else if ( is_tanh == TRUE ) \
|
||||
@@ -1199,28 +1419,99 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
\
|
||||
if ( is_matrix_add == TRUE ) \
|
||||
{ \
|
||||
/* Allocate add matrix buffer, return early if alloc fails.*/ \
|
||||
dim_t ele_dsize = 0; \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
if( is_matadd_stor_type == TRUE) \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_DSCALE_type ); \
|
||||
if( ( strcmp( matadd_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = AOCL_GEMM_BF16; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(bfloat16) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,bfloat16)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else if( ( strcmp( matadd_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = AOCL_GEMM_F32; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(float) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else if( ( strcmp( matadd_stor_type, "S32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = AOCL_GEMM_INT32; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int32_t) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else if( ( strcmp( matadd_stor_type, "S8" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = AOCL_GEMM_INT8; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int8_t) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_type ); \
|
||||
} \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
/* default is int32_t for integer APIs and float for others */ \
|
||||
if( is_integerAPI_avx512(#BLAS_SFX)) \
|
||||
{ \
|
||||
if( strcmp(#C_type, "int8_t") == 0 ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = NULLTYPE; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int8_t) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = NULLTYPE; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int32_t) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = NULLTYPE; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(C_DSCALE_type) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->matrix_add )->stor_type = NULLTYPE; \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(float) ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
|
||||
{ \
|
||||
@@ -1245,35 +1536,76 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
temp_dscale_ptr[i] = ( ( DSCALE_type )2 ); \
|
||||
} \
|
||||
( post_ops->matrix_add )->scale_factor_len = n_scale; \
|
||||
/* Set buffer type same as c_store type for now.
|
||||
* TODO: Update to cover more data types. */ \
|
||||
( post_ops->matrix_add )->stor_type = NULLTYPE; \
|
||||
} \
|
||||
\
|
||||
if ( is_matrix_mul == TRUE ) \
|
||||
{ \
|
||||
/* Allocate mul matrix buffer, return early if alloc fails.*/ \
|
||||
dim_t ele_dsize = 0; \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
if( is_matmul_stor_type == TRUE) \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_DSCALE_type ); \
|
||||
if( ( strcmp( matmul_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_BF16; \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(bfloat16) ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,bfloat16)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else if( ( strcmp( matmul_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_F32; \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(float) ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else if( ( strcmp( matmul_stor_type, "S32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_INT32; \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(int32_t) ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else if( ( strcmp( matmul_stor_type, "S8" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_INT8; \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(int8_t) ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_type ); \
|
||||
} \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
if( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->stor_type = NULLTYPE; \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(C_DSCALE_type) ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->stor_type = NULLTYPE; \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(float) ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
} \
|
||||
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
|
||||
{ \
|
||||
|
||||
@@ -1200,8 +1200,7 @@
|
||||
|
||||
// BF16 buffer for matrix add/mul in u8s8s32.
|
||||
#define BF16_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps( \
|
||||
_mm512_sllv_epi32 \
|
||||
scr =(__m512)(_mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
@@ -1212,8 +1211,7 @@
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
) );\
|
||||
scr = _mm512_mul_ps(scr, scl_fct ); \
|
||||
|
||||
#define BF16_MATRIX_ADD_1COL_PAR(mask,scr0,scl_fct0,m_ind) \
|
||||
|
||||
Reference in New Issue
Block a user