Modified bench to test different types of post-ops

- Modified bench to support testing of different types of buffers
  for bias, mat_add and mat_mul postops.
- Added support for testing integer APIs with float accumulation
  type.

Change-Id: I72364e9ad25e6148042b93ec6d152ff82ea03e96
This commit is contained in:
Meghana Vankadari
2025-02-03 05:52:11 +05:30
parent 0701a4388a
commit 13e7ada3f2
5 changed files with 740 additions and 645 deletions

View File

@@ -36,12 +36,6 @@
#define POST_OPS_STR_LEN 104
CONVERT_TO_FLOAT(uint8_t)
CONVERT_TO_FLOAT(int8_t)
CONVERT_TO_FLOAT(int16_t)
CONVERT_TO_FLOAT(float)
CONVERT_TO_FLOAT(int32_t)
PRINT_MATRIX(uint8_t)
PRINT_MATRIX(int8_t)
PRINT_MATRIX(int16_t)
@@ -256,11 +250,8 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
return out_temp_accum; \
}\
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int16_t,float,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,s8s8s16os8)
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
(
@@ -347,14 +338,9 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
return temp_accum; \
} \
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
#define GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
@@ -640,15 +626,10 @@ static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
return temp_accum;
}
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os8)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os8)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os32)
GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32)
@@ -656,15 +637,10 @@ GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32obf16)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_TANH_POSTOP_FLOAT(u8s8s32os8)
GEN_TANH_POSTOP_FLOAT(u8s8s32os32)
GEN_TANH_POSTOP_FLOAT(s8s8s32os8)
GEN_TANH_POSTOP_FLOAT(s8s8s32os32)
GEN_TANH_POSTOP_FLOAT(f32f32f32of32)
GEN_TANH_POSTOP_FLOAT(bf16bf16f32of32)
@@ -672,15 +648,10 @@ GEN_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_TANH_POSTOP_FLOAT(bf16s4f32of32)
GEN_TANH_POSTOP_FLOAT(bf16s4f32obf16)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os8)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32)
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8)
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os8)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os32)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os8)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os32)
GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32)
@@ -688,15 +659,10 @@ GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32obf16)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_SWISH_POSTOP_FLOAT(u8s8s32os8)
GEN_SWISH_POSTOP_FLOAT(u8s8s32os32)
GEN_SWISH_POSTOP_FLOAT(s8s8s32os8)
GEN_SWISH_POSTOP_FLOAT(s8s8s32os32)
GEN_SWISH_POSTOP_FLOAT(f32f32f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16bf16f32of32)
@@ -704,15 +670,10 @@ GEN_SWISH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_SWISH_POSTOP_FLOAT(bf16s4f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16s4f32obf16)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os8)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os16)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os8)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os32)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os8)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os32)
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os8)
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os16)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os8)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os32)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os8)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os32)
GEN_SIGMOID_POSTOP_FLOAT(f32f32f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32of32)
@@ -720,50 +681,35 @@ GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32f32f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16bf16f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16s4f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,u8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,u8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16bf16f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16s4f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32f32f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16bf16f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16s4f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,u8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,u8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16s4f32of32)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16ou8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
@@ -916,8 +862,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
{ \
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
(temp_accum, \
*( ( ACCUM_type* ) \
( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) );\
( post_op[bs_i]->eltwise + ele_i )->algo.alpha );\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
@@ -975,9 +920,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_add )->scale_factor_len; \
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( *( ( C_type* )( post_op[bs_i]->matrix_add )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
( ( post_op[bs_i]->matrix_add )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op[bs_i]->matrix_add)->stor_type ); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_MUL ) \
{ \
@@ -991,9 +935,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
( *( ( C_type* )( post_op[bs_i]->matrix_mul )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
(( post_op[bs_i]->matrix_mul )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op[bs_i]->matrix_mul)->stor_type ); \
} \
else \
{} \
@@ -1031,9 +974,6 @@ cleanup_acc: \
return; \
} \
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,float,u8s8s16ou8,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
@@ -1041,26 +981,20 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32o
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,bf16s4f32obf16,bf16bf16f32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,u8s8s16os16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,u8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,s8s8s16os16)
// Hack to fix compiler errors.
#define GET_B_TYPE_bf16bf16f32of32 bfloat16
#define GET_B_TYPE_u8s8s16os16 int8_t
#define GET_B_TYPE_u8s8s32os32 int8_t
#define GET_B_TYPE_f32f32f32of32 float
#define GET_B_TYPE_s8s8s32os32 int8_t
#define GET_B_TYPE_s8s8s16os16 int8_t
#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX, INT4_REORDER_SFX) \
void mat_mul_bench_main_ ## BLAS_SFX \

View File

@@ -34,20 +34,12 @@
#include "bench_lpgemm_helpers.h"
CONVERT_TO_FLOAT(uint8_t)
CONVERT_TO_FLOAT(int8_t)
CONVERT_TO_FLOAT(int16_t)
CONVERT_TO_FLOAT(float)
CONVERT_TO_FLOAT(int32_t)
PRINT_MATRIX(uint8_t)
PRINT_MATRIX(int8_t)
PRINT_MATRIX(int16_t)
PRINT_MATRIX(float)
PRINT_MATRIX(int32_t)
GEN_FILL_ARRAY_FUNC(int8_t)
GEN_FILL_ARRAY_FUNC(int16_t)
GEN_FILL_ARRAY_FUNC(float)
GEN_FILL_ARRAY_FUNC(int32_t)
@@ -106,9 +98,6 @@ void mat_mul_ ## BLAS_SFX \
c, ldc, post_op ); \
} \
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
@@ -121,8 +110,6 @@ GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
@@ -207,9 +194,6 @@ void mat_mul_bench_driver_ ## BLAS_SFX \
print_result( XSTR(BLAS_SFX), n_repeats, transa, transb, m, n, k, lda, ldb, ldc, gflops); \
} \
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
@@ -222,8 +206,6 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
@@ -257,9 +239,6 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
return out_temp_accum; \
}\
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int16_t,float,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
@@ -357,12 +336,6 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
return temp_accum; \
} \
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
@@ -736,20 +709,15 @@ static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
return temp_accum;
}
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32ou8)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32obf16)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32of32)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32obf16)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32of32)
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os8)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32ou8)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32of32)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os8)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os32)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32of32)
GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32)
@@ -757,20 +725,15 @@ GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32obf16)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32ou8)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32obf16)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32of32)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32obf16)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32of32)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_TANH_POSTOP_FLOAT(u8s8s32os8)
GEN_TANH_POSTOP_FLOAT(u8s8s32ou8)
GEN_TANH_POSTOP_FLOAT(u8s8s32os32)
GEN_TANH_POSTOP_FLOAT(u8s8s32obf16)
GEN_TANH_POSTOP_FLOAT(u8s8s32of32)
GEN_TANH_POSTOP_FLOAT(s8s8s32os8)
GEN_TANH_POSTOP_FLOAT(s8s8s32obf16)
GEN_TANH_POSTOP_FLOAT(s8s8s32of32)
GEN_TANH_POSTOP_FLOAT(s8s8s32os32)
GEN_TANH_POSTOP_FLOAT(f32f32f32of32)
GEN_TANH_POSTOP_FLOAT(bf16bf16f32of32)
@@ -778,20 +741,15 @@ GEN_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_TANH_POSTOP_FLOAT(bf16s4f32of32)
GEN_TANH_POSTOP_FLOAT(bf16s4f32obf16)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32ou8)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32obf16)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32of32)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os8)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32obf16)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32of32)
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8)
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os8)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32ou8)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os32)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32of32)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os8)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os32)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32of32)
GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32)
@@ -799,20 +757,15 @@ GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32obf16)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32ou8)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32obf16)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32of32)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32obf16)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32of32)
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_SWISH_POSTOP_INT(float,u8s8s32os8)
GEN_SWISH_POSTOP_INT(float,u8s8s32ou8)
GEN_SWISH_POSTOP_INT(float,u8s8s32os32)
GEN_SWISH_POSTOP_FLOAT(u8s8s32obf16)
GEN_SWISH_POSTOP_FLOAT(u8s8s32of32)
GEN_SWISH_POSTOP_INT(float,s8s8s32os8)
GEN_SWISH_POSTOP_INT(float,s8s8s32os32)
GEN_SWISH_POSTOP_FLOAT(s8s8s32obf16)
GEN_SWISH_POSTOP_FLOAT(s8s8s32of32)
GEN_SWISH_POSTOP_FLOAT(f32f32f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16bf16f32of32)
@@ -820,20 +773,15 @@ GEN_SWISH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_SWISH_POSTOP_FLOAT(bf16s4f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16s4f32obf16)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os8)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os16)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os8)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32ou8)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os32)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32obf16)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32of32)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os8)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os32)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32obf16)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32of32)
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os8)
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os16)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os8)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32ou8)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os32)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32obf16)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32of32)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os8)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os32)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32obf16)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32of32)
GEN_SIGMOID_POSTOP_FLOAT(f32f32f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32of32)
@@ -841,65 +789,85 @@ GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int32_t,u8s8s32ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(bfloat16,int32_t,u8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,int32_t,u8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(bfloat16,int32_t,s8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,int32_t,s8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32f32f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16bf16f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16s4f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16bf16f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16s4f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16s4f32of32)
GEN_PRELU_POST_OP_VAL_FLOAT(f32f32f32of32)
GEN_PRELU_POST_OP_VAL_FLOAT(bf16bf16f32of32)
GEN_PRELU_POST_OP_VAL_FLOAT(bf16bf16f32obf16)
GEN_PRELU_POST_OP_VAL_FLOAT(bf16s4f32of32)
GEN_PRELU_POST_OP_VAL_FLOAT(bf16s4f32obf16)
GEN_PRELU_POST_OP_VAL_FLOAT(u8s8s32obf16)
GEN_PRELU_POST_OP_VAL_FLOAT(u8s8s32of32)
GEN_PRELU_POST_OP_VAL_FLOAT(s8s8s32obf16)
GEN_PRELU_POST_OP_VAL_FLOAT(s8s8s32of32)
GEN_PRELU_POST_OP_VAL_INT(u8s8s32os8)
GEN_PRELU_POST_OP_VAL_INT(u8s8s32ou8)
GEN_PRELU_POST_OP_VAL_INT(u8s8s32os32)
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os8)
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os32)
GEN_CLIP_POST_OP_VAL_FLOAT(f32f32f32of32)
GEN_CLIP_POST_OP_VAL_FLOAT(bf16bf16f32of32)
GEN_CLIP_POST_OP_VAL_FLOAT(bf16bf16f32obf16)
GEN_CLIP_POST_OP_VAL_FLOAT(bf16s4f32of32)
GEN_CLIP_POST_OP_VAL_FLOAT(bf16s4f32obf16)
GEN_CLIP_POST_OP_VAL_FLOAT(u8s8s32obf16)
GEN_CLIP_POST_OP_VAL_FLOAT(u8s8s32of32)
GEN_CLIP_POST_OP_VAL_FLOAT(s8s8s32obf16)
GEN_CLIP_POST_OP_VAL_FLOAT(s8s8s32of32)
GEN_CLIP_POST_OP_VAL_INT(u8s8s32os8)
GEN_CLIP_POST_OP_VAL_INT(u8s8s32ou8)
GEN_CLIP_POST_OP_VAL_INT(u8s8s32os32)
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os8)
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int32_t,u8s8s32ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(bfloat16,int32_t,u8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,int32_t,u8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(bfloat16,int32_t,s8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,int32_t,s8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32f32f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16bf16f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16s4f32of32)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32ou8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32obf16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32of32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16ou8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32ou8)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32obf16)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32of32)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32obf16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32of32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
@@ -907,12 +875,14 @@ GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int16_t,int16_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int16_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int16_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,float)
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,POST_ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
( \
FILE* fout, \
@@ -1007,6 +977,15 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
rs_a, rs_b, cs_a, cs_b, rs_c_ref, cs_c_ref, i, j, k, n, \
a_pre_op); \
\
POST_ACCUM_type post_temp_accum = 0; \
if ( is_integerAPI_avx512(#BLAS_SFX) ) \
{ \
CVT_FUNC_NAME(ACCUM_type,POST_ACCUM_type)(temp_accum, &post_temp_accum); \
} \
else \
{ \
post_temp_accum = temp_accum; \
} \
if ( post_op != NULL ) \
{ \
dim_t ele_i = 0; \
@@ -1014,7 +993,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
{ \
if ( post_op->seq_vector[op_id] == BIAS ) \
{ \
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
post_temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
( ( post_op->bias )->bias, j, ( post_op->bias )->stor_type ); \
} \
else if ( post_op->seq_vector[op_id] == ELTWISE ) \
@@ -1022,66 +1001,56 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
PRELU ) /* PReLU*/ \
{ \
temp_accum = ( temp_accum > 0 ) ? \
temp_accum : \
( temp_accum * \
*( ( ACCUM_type* ) ( post_op->eltwise + ele_i )->algo.alpha ) ); \
post_temp_accum = GEN_FUNC_NAME(get_prelu_post_op_val_,BLAS_SFX) \
(post_temp_accum, ( post_op->eltwise + ele_i )->algo.alpha ); \
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
GELU_TANH ) /* TANH GeLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\
post_temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (post_temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
GELU_ERF ) /* ERF GeLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\
post_temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (post_temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
SWISH ) /* SiLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
(temp_accum, \
*( ( ACCUM_type* ) \
( post_op->eltwise + ele_i )->algo.alpha ) );\
post_temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
(post_temp_accum, \
( post_op->eltwise + ele_i )->algo.alpha );\
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
RELU ) /* ReLU*/ \
{ \
temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \
post_temp_accum = ( post_temp_accum > 0 ) ? post_temp_accum : 0 ; \
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
TANH ) /* TANH*/ \
{ \
temp_accum = GEN_FUNC_NAME(TANH_post_op_,BLAS_SFX) (temp_accum);\
post_temp_accum = GEN_FUNC_NAME(TANH_post_op_,BLAS_SFX) (post_temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
SIGMOID ) /* Sigmoid*/ \
{ \
temp_accum = GEN_FUNC_NAME(SIGMOID_post_op_,BLAS_SFX) (temp_accum);\
post_temp_accum = GEN_FUNC_NAME(SIGMOID_post_op_,BLAS_SFX) (post_temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
CLIP ) /* CLIP*/ \
{ \
temp_accum = \
min \
( \
max \
( \
temp_accum, \
*( ( ACCUM_type* ) \
( post_op->eltwise + ele_i )->algo.alpha ) \
), \
*( ( ACCUM_type* ) \
( post_op->eltwise + ele_i )->algo.beta) \
); \
post_temp_accum = GEN_FUNC_NAME(get_clip_post_op_val_,BLAS_SFX) \
( post_temp_accum, \
( post_op->eltwise + ele_i )->algo.alpha, \
( post_op->eltwise + ele_i )->algo.beta \
); \
ele_i += 1; \
} \
else \
@@ -1089,8 +1058,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
} \
else if ( post_op->seq_vector[op_id] == SCALE ) \
{ \
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(temp_accum, post_op, j); \
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(post_temp_accum, post_op, j); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
{ \
@@ -1101,31 +1070,11 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
cs_m = rs_m; \
rs_m = 1; \
} \
const char* api_name = #BLAS_SFX; \
/*The C_type is defined as f32 (32-bit floating point),while the ACCUM_type
is defined as s32 (32-bit signed integer). This type mismatch can lead to
unexpected behavior or random values during conversions.To address this issue,
we implement specific handling for the following APIs. */ \
\
if ( ( strcmp( api_name, "u8s8s32of32" ) == 0) || ( strcmp( api_name, "s8s8s32of32" ) == 0) \
|| ( strcmp( api_name, "u8s8s32obf16" ) == 0) || ( strcmp( api_name, "s8s8s32obf16" ) == 0) ) \
{ \
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( *( ( ACCUM_type* )( post_op->matrix_add )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
} \
else \
{ \
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( *( ( C_type* )( post_op->matrix_add )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
} \
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
post_temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( ( post_op->matrix_add )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_add )->stor_type ); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_MUL ) \
{ \
@@ -1136,40 +1085,20 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
cs_m = rs_m; \
rs_m = 1; \
} \
const char* api_name = #BLAS_SFX; \
/*The C_type is defined as f32 (32-bit floating point),while the ACCUM_type
is defined as s32 (32-bit signed integer). This type mismatch can lead to
unexpected behavior or random values during conversions.To address this issue,
we implement specific handling for the following APIs. */ \
\
if ( ( strcmp( api_name, "u8s8s32of32" ) == 0) || ( strcmp( api_name, "s8s8s32of32" ) == 0) \
|| ( strcmp( api_name, "u8s8s32obf16" ) == 0) || ( strcmp( api_name, "s8s8s32obf16" ) == 0) ) \
{ \
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
( *( ( ACCUM_type* )( post_op->matrix_mul )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
} \
else \
{ \
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
( *( ( C_type* )( post_op->matrix_mul )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
} \
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
post_temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
( ( post_op->matrix_mul )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_mul )->stor_type ); \
} \
else \
{} \
} \
} \
/* Need to convert to downscaled type if required.*/ \
mat_mul_get_output_type_val ## ACCUM_type ## C_type \
mat_mul_get_output_type_val ## POST_ACCUM_type ## C_type \
( \
&out_temp_accum, &temp_accum \
&out_temp_accum, &post_temp_accum \
); \
\
float comp_float, ref_float; \
@@ -1198,56 +1127,50 @@ cleanup_acc: \
return; \
} \
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,float,u8s8s16ou8,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,float,u8s8s32os32,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,float,u8s8s32os8,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,float,float,u8s8s32ou8,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,bfloat16,int32_t,float,float,u8s8s32obf16,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,float,int32_t,float,float,u8s8s32of32,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,float,u8s8s32ou8,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,bfloat16,int32_t,float,u8s8s32obf16,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,float,int32_t,float,u8s8s32of32,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,float,bf16s4f32obf16,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,float,bf16bf16f32obf16,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,float,f32f32f32of32,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,bf16s4f32obf16,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32obf16,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,int32_t,float,s8s8s32os8,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,int32_t,float,s8s8s32obf16,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,float,int32_t,float,int32_t,s8s8s32of32,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,float,s8s8s32obf16,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,float,int32_t,float,s8s8s32of32,s8s8s32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,u8s8s16os16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,s8s8s16os16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,u8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,int32_t,u8s8s32os8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,u8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,u8s8s32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,u8s8s32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,float,u8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,float,u8s8s32os8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,float,u8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,float,u8s8s32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,float,u8s8s32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os8)
//GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32of32)
//GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32obf16)
//GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,int32_t,s8s8s32os8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,s8s8s32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16bf16f32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16s4f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16s4f32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32f32f32of32)
// Hack to fix compiler errors.
#define GET_B_TYPE_bf16bf16f32of32 bfloat16
#define GET_B_TYPE_u8s8s16os16 int8_t
#define GET_B_TYPE_u8s8s32os32 int8_t
#define GET_B_TYPE_f32f32f32of32 float
#define GET_B_TYPE_s8s8s32os32 int8_t
#define GET_B_TYPE_s8s8s16os16 int8_t
#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX, INT4_REORDER_SFX) \
void mat_mul_bench_main_ ## BLAS_SFX \
@@ -1337,7 +1260,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
( strcmp( post_ops_str, "none" ) != 0 ) ) || \
( global_dscale_out == 'y' ) || ( global_pre_op == 'y' ) ) \
{ \
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, k, post_ops_str, stor_order ); \
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,BLAS_SFX)( m, n, k, post_ops_str, stor_order ); \
if ( post_op == NULL ) \
{ \
printf(" post op struct allocation failure, returning.\n"); \
@@ -1403,6 +1326,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
if ( bench_mode == 'a' ) \
{ \
printf(" Running accuracy check.\n"); \
fflush(stdout); \
GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \
( \
fout, stor_order, transa, transb, m, n, k, \
@@ -1426,9 +1350,6 @@ void mat_mul_bench_main_ ## BLAS_SFX \
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16,u8s8s16os16,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8,u8s8s16os16,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8,u8s8s16os16,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8,u8s8s32os32,u8s4s32os32)
@@ -1439,8 +1360,6 @@ GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os3
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16,s8s8s16os16,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8,s8s8s16os16,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16,bf16bf16f32of32,bf16s4f32of32)
@@ -1481,11 +1400,7 @@ int main( int argc, char** argv )
" 5. s8s8s32os32 -d s8 = s8s8s32os8.\n" \
" 6. s8s8s32os32 -d f32 = s8s8s32of32.\n" \
" 7. s8s8s32os32 -d bf16 = s8s8s32obf16.\n" \
" 8. s8s8s16os16 -d s8 = s8s8s16os8.\n" \
" 9. s8s8s16os16 -d u8 = u8s8s16ou8.\n" \
" 10. u8s8s16os16 -d s8 = u8s8s16os8.\n" \
" 11. u8s8s16os16 -d u8 = u8s8s16ou8.\n" \
" 12. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
" 8. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
" Example: ./bench_lpgemm -m a -n 2 -o bias,relu -d bf16 -i input.txt\n" \
);
exit( 1 );
@@ -1735,53 +1650,6 @@ int main( int argc, char** argv )
post_ops_str_dest
);
}
#if 0
if ( ( strcmp( gemm_type_str, "u8s8s16os16" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = SHRT_MIN;
DSCALE_CLIP_MAX = SHRT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest
);
}
if ( ( strcmp( gemm_type_str, "u8s8s16os8" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
global_pre_op = 'n';
DSCALE_CLIP_MIN = -128;
DSCALE_CLIP_MAX = +127;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest
);
}
if ( ( strcmp( gemm_type_str, "u8s8s16ou8" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
global_pre_op = 'n';
DSCALE_CLIP_MIN = 0;
DSCALE_CLIP_MAX = +255;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16ou8)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest
);
}
#endif
if ( ( strcmp( gemm_type_str, "bf16bf16f32of32" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
@@ -1906,38 +1774,6 @@ int main( int argc, char** argv )
post_ops_str_dest
);
}
#if 0
if ( ( strcmp( gemm_type_str, "s8s8s16os16" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = SHRT_MIN;
DSCALE_CLIP_MAX = SHRT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest
);
}
if ( ( strcmp( gemm_type_str, "s8s8s16os8" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
global_pre_op = 'n';
DSCALE_CLIP_MIN = -128;
DSCALE_CLIP_MAX = +127;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest
);
}
#endif
}
}

View File

@@ -36,8 +36,6 @@
GEN_FILL_ARRAY_FUNC(float)
CONVERT_TO_FLOAT(float)
void print_result
(
const char* msg,
@@ -195,13 +193,13 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
return out_temp_accum;
}
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32of32)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
@@ -316,8 +314,7 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
{ \
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,LP_SFX) \
(temp_accum, \
*( ( ACCUM_type* ) \
( post_op->eltwise + ele_i )->algo.alpha ) );\
( post_op->eltwise + ele_i )->algo.alpha );\
ele_i += 1; \
} \
else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \
@@ -375,9 +372,8 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
float* scl_fctr = ( float* )( ( post_op->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_add )->scale_factor_len; \
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,LP_SFX) \
( *( ( B_type* )( post_op->matrix_add )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
( ( post_op->matrix_add )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len,( post_op->matrix_add )->stor_type ); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_MUL ) \
{ \
@@ -391,9 +387,8 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
float* scl_fctr = ( float* )( ( post_op->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,LP_SFX) \
( *( ( B_type* )( post_op->matrix_mul )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
( ( post_op->matrix_mul )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_add )->stor_type ); \
} \
else \
{} \

View File

@@ -69,7 +69,7 @@ char global_pre_op = 'n';
#define XSTR(str) _XSTR(str)
#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype
#define CVT_FUNC_NAME(stype, dtype) stype ## _to_ ## dtype
// Inplace to lower func.
static inline void str_tolower( char* str )
{
@@ -83,6 +83,22 @@ static inline void GEN_FUNC_NAME(ctype,_to_float) ( ctype val, float* float_val
*float_val = (float) val; \
} \
CONVERT_TO_FLOAT(uint8_t)
CONVERT_TO_FLOAT(int8_t)
CONVERT_TO_FLOAT(int16_t)
CONVERT_TO_FLOAT(float)
CONVERT_TO_FLOAT(int32_t)
#define CONVERT_ITSELF(ctype) \
static inline void GEN_FUNC_NAME(ctype,_to_ ## ctype) ( ctype val, ctype* ctype_val ) \
{ \
*ctype_val = val; \
}
CONVERT_ITSELF(int16_t)
CONVERT_ITSELF(int32_t)
static inline void float_to_bf16( float* float_value, bfloat16* bf16_val )
{
/*Set offset 2 to copy most significant 2 bytes of float
@@ -169,6 +185,43 @@ static inline void lpgemm_free( void* p )
}
}
bool is_integerAPI_avx512( char* api_name )
{
if ( ( strcmp( api_name, "u8s8s32of32" ) == 0) || ( strcmp( api_name, "u8s8s32os8" ) == 0) \
|| ( strcmp( api_name, "u8s8s32obf16" ) == 0) || ( strcmp( api_name, "u8s8s32os32" ) == 0) \
|| ( strcmp( api_name, "u8s8s32ou8" ) == 0) ) \
{ \
return TRUE; \
} \
else \
{ \
return FALSE; \
}
}
bool is_integer( char* type )
{
if ( ( strcmp( type, "int8_t" ) == 0 ) || ( strcmp( type, "int16_t" ) == 0 ) \
|| ( strcmp( type, "int32_t" ) == 0 ) || ( strcmp( type, "uint8_t" ) == 0 ) ) \
{ \
return TRUE; \
} \
else \
{ \
return FALSE; \
}
}
bool is_bf16API_avx512( char* api_name )
{
if ( ( strcmp( api_name, "bf16bf16f32of32" ) == 0) || ( strcmp( api_name, "bf16bf16f32obf16" ) == 0) \
|| ( strcmp( api_name, "bf16s4f32of32" ) == 0) || strcmp( api_name, "bf16s4f32obf16")) \
{ \
return TRUE; \
} \
else \
{ \
return FALSE; \
}
}
#ifdef BLIS_ENABLE_OPENMP
/* Matrix fill helper macros. */
#define GEN_FILL_ARRAY_FUNC(ctype) \
@@ -237,6 +290,61 @@ static inline void fill_array_post_ops_bfloat16( void* arr, dim_t size )
}
/* POST-OPS Helper macros. */
/* CLIP */
#define GEN_CLIP_POST_OP_VAL_INT(BLAS_SFX) \
static inline float get_clip_post_op_val_ ## BLAS_SFX \
( \
float post_temp_accum, \
void* post_op_alpha_ptr, \
void* post_op_beta_ptr \
) \
{ \
float alpha, beta; \
int32_t_to_float(*( ( int32_t* )post_op_alpha_ptr), &alpha); \
int32_t_to_float(*( ( int32_t* )post_op_beta_ptr), &beta); \
return min( max( post_temp_accum, alpha),beta); \
}
#define GEN_CLIP_POST_OP_VAL_FLOAT(BLAS_SFX) \
static inline float get_clip_post_op_val_ ## BLAS_SFX \
( \
float post_temp_accum, \
void* post_op_alpha_ptr, \
void* post_op_beta_ptr \
) \
{ \
return min( max( post_temp_accum, *( ( float* )post_op_alpha_ptr ) ), \
*( ( float* )post_op_beta_ptr ) ); \
}
/* PRELU */
#define GEN_PRELU_POST_OP_VAL_FLOAT(BLAS_SFX) \
static inline float get_prelu_post_op_val_ ## BLAS_SFX \
( \
float post_temp_accum, \
void* post_op_alpha_ptr \
) \
{ \
return (( post_temp_accum > 0 ) ? \
post_temp_accum : \
( post_temp_accum * \
(*( float* )post_op_alpha_ptr) )); \
}
#define GEN_PRELU_POST_OP_VAL_INT(BLAS_SFX) \
static inline float get_prelu_post_op_val_ ## BLAS_SFX \
( \
float post_temp_accum, \
void* post_op_alpha_ptr \
) \
{ \
float ret_val; \
int32_t_to_float(*( ( int32_t* )post_op_alpha_ptr), &ret_val); \
\
return ( post_temp_accum > 0 ) ? \
post_temp_accum : \
( post_temp_accum * ret_val ); \
}
/* Bias. */
#define GEN_GET_BIAS_POST_OP_VAL_BF16(BLAS_SFX) \
@@ -283,39 +391,25 @@ static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
{ \
if(bias_stor_type == AOCL_GEMM_BF16) \
{ \
int32_t ret_val = 0.0; \
bfloat16_to_int32_t( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
float ret_val = 0.0; \
bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
return ret_val; \
} \
if(bias_stor_type == AOCL_GEMM_INT8) \
{ \
int32_t ret_val = 0.0; \
int8_t_to_int32_t( *( ( int8_t* )post_op_bias_ptr + j ), &ret_val ); \
float ret_val = 0.0; \
int8_t_to_float( *( ( int8_t* )post_op_bias_ptr + j ), &ret_val ); \
return ret_val; \
} \
if(bias_stor_type == AOCL_GEMM_F32) \
if(bias_stor_type == AOCL_GEMM_INT32) \
{ \
int32_t ret_val = 0.0; \
ret_val = (int32_t) *( ( float* )post_op_bias_ptr + j ); \
float ret_val = 0.0; \
int32_t_to_float( *( ( int32_t* )post_op_bias_ptr + j ), &ret_val ); \
return ret_val; \
} \
return *( ( ACCUM_type* )post_op_bias_ptr + j ); \
} \
/* GELU Tanh. */
#define GEN_GELU_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \
( \
ACCUM_type temp_accum \
) \
{ \
float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \
( 0.044715 * ((double)temp_accum * (double)temp_accum * \
(double)temp_accum ) ) ) ) ); \
temp_accum = round (gelu_reference); \
return temp_accum; \
} \
#define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \
static inline float GELU_TANH_post_op_ ## BLAS_SFX \
( \
@@ -329,17 +423,6 @@ static inline float GELU_TANH_post_op_ ## BLAS_SFX \
} \
/* GELU Erf. */
#define GEN_GELU_ERF_POSTOP_INT(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \
( \
ACCUM_type temp_accum \
) \
{ \
float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \
temp_accum = round (gelu_reference); \
return temp_accum; \
} \
#define GEN_GELU_ERF_POSTOP_FLOAT(BLAS_SFX) \
static inline float GELU_ERF_post_op_ ## BLAS_SFX \
( \
@@ -351,17 +434,6 @@ static inline float GELU_ERF_post_op_ ## BLAS_SFX \
} \
/* TANH. */
#define GEN_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type TANH_post_op_ ## BLAS_SFX \
( \
ACCUM_type temp_accum \
) \
{ \
float tanh_reference = tanhf( ( double )temp_accum ); \
temp_accum = round( tanh_reference ); \
return temp_accum; \
} \
#define GEN_TANH_POSTOP_FLOAT(BLAS_SFX) \
static inline float TANH_post_op_ ## BLAS_SFX \
( \
@@ -373,18 +445,6 @@ static inline float TANH_post_op_ ## BLAS_SFX \
} \
/* SIGMOID. */
#define GEN_SIGMOID_POSTOP_INT(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type SIGMOID_post_op_ ## BLAS_SFX \
( \
ACCUM_type temp_accum \
) \
{ \
float sigmoid_reference = ( 1 / ( 1 + \
(dim_t) round( expf( temp_accum * -1 ) ) ) ); \
temp_accum = round (sigmoid_reference); \
return temp_accum; \
} \
#define GEN_SIGMOID_POSTOP_FLOAT(BLAS_SFX) \
static inline float SIGMOID_post_op_ ## BLAS_SFX \
( \
@@ -401,11 +461,13 @@ static inline float SIGMOID_post_op_ ## BLAS_SFX \
static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \
( \
ACCUM_type temp_accum, \
ACCUM_type alpha \
void* alpha \
) \
{ \
float alpha_val; \
int32_t_to_float(*( ( int32_t* )alpha), &alpha_val); \
float swish_reference = ( temp_accum / ( 1 + \
expf( ( double )alpha * temp_accum * -1 ) ) ); \
expf( ( double )(alpha_val) * temp_accum * -1 ) ) ); \
temp_accum = round (swish_reference); \
return temp_accum; \
} \
@@ -414,42 +476,25 @@ static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \
static inline float SWISH_post_op_ ## BLAS_SFX \
( \
float temp_accum, \
float alpha \
void* alpha \
) \
{ \
temp_accum = ( temp_accum / ( 1 + \
expf( ( double )alpha * temp_accum * -1 ) ) ); \
expf( ( double )(*((float*)alpha) * temp_accum * -1 ) ) )); \
return temp_accum; \
} \
/* Matrix Add. */
#define GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(C_type,BLAS_SFX) \
static inline float get_matrix_add_post_op_val_ ## BLAS_SFX \
( \
C_type val, \
dim_t j, \
float* scl_fctr, \
dim_t scl_fctr_len \
) \
{ \
float ret_val = 0.0; \
dim_t j_scale = j; \
if ( scl_fctr_len == 1 ) \
{ \
j_scale = 0; \
} \
\
bfloat16_to_float( val, &ret_val ); \
return ( ret_val * *( scl_fctr + j_scale ) ); \
} \
#define GEN_GET_MATRIX_ADD_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \
#define GEN_GET_MATRIX_ADD_POST_OP_VAL(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type get_matrix_add_post_op_val_ ## BLAS_SFX \
( \
C_type val, \
void* mat_add_ptr, \
dim_t i, \
dim_t j, \
dim_t rs_m, \
dim_t cs_m, \
float* scl_fctr, \
dim_t scl_fctr_len \
dim_t scl_fctr_len, \
AOCL_PARAMS_STORAGE_TYPES matadd_stor_type \
) \
{ \
dim_t j_scale = j; \
@@ -457,36 +502,94 @@ static inline ACCUM_type get_matrix_add_post_op_val_ ## BLAS_SFX \
{ \
j_scale = 0; \
} \
return (ACCUM_type) ( ( float )val * *( scl_fctr + j_scale ) ); \
if( matadd_stor_type == AOCL_GEMM_BF16 ) \
{ \
float ret_val = 0.0; \
bfloat16 val = *( ( bfloat16* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
bfloat16_to_float( val, &ret_val ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
if( matadd_stor_type == AOCL_GEMM_INT8 ) \
{ \
float ret_val = 0.0; \
int8_t_to_float( *( ( int8_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
if( matadd_stor_type == AOCL_GEMM_INT32 ) \
{ \
float ret_val = 0.0; \
int32_t_to_float( *( ( int32_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
if( matadd_stor_type == AOCL_GEMM_F32 ) \
{ \
float ret_val = 0.0; \
ret_val = *( ( float* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
/* default case */ \
if( is_integerAPI_avx512(#BLAS_SFX) ) \
{ \
if( strcmp( #BLAS_SFX, "u8s8s32os8" ) == 0 ) \
{ \
float ret_val = 0.0; \
int8_t_to_float( *( ( int8_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
float ret_val = 0.0; \
int32_t_to_float( *( ( int32_t* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ), &ret_val ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
else \
{ \
if( global_dscale_out == 'y' ) \
{ \
float ret_val = 0.0; \
bfloat16 val = *( ( bfloat16* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
bfloat16_to_float( val, &ret_val ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
float ret_val = 0.0; \
ret_val = *( ( float* )mat_add_ptr + ( i * rs_m ) + ( j * cs_m ) ); \
return ( ( float )ret_val * *( scl_fctr + j_scale ) ); \
} \
} \
#define GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(C_type,BLAS_SFX) \
#define GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(BLAS_SFX) \
static inline float get_matrix_mul_post_op_val_ ## BLAS_SFX \
( \
C_type val, \
void* mat_add_ptr, \
dim_t i, \
dim_t j, \
dim_t rs_m, \
dim_t cs_m, \
float* scl_fctr, \
dim_t scl_fctr_len \
dim_t scl_fctr_len, \
AOCL_PARAMS_STORAGE_TYPES matadd_stor_type \
) \
{ \
return GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( \
val, j, scl_fctr, scl_fctr_len \
mat_add_ptr, i, j, rs_m, cs_m, scl_fctr, scl_fctr_len, matadd_stor_type \
); \
} \
#define GEN_GET_MATRIX_MUL_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \
#define GEN_GET_MATRIX_MUL_POST_OP_VAL(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type get_matrix_mul_post_op_val_ ## BLAS_SFX \
( \
C_type val, \
void* mat_add_ptr, \
dim_t i, \
dim_t j, \
dim_t rs_m, \
dim_t cs_m, \
float* scl_fctr, \
dim_t scl_fctr_len \
dim_t scl_fctr_len, \
AOCL_PARAMS_STORAGE_TYPES matadd_stor_type \
) \
{ \
return GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( \
val, j, scl_fctr, scl_fctr_len \
mat_add_ptr, i, j, rs_m, cs_m, scl_fctr, scl_fctr_len, matadd_stor_type \
); \
} \
@@ -586,6 +689,7 @@ static inline void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops )
if ( post_ops->matrix_add != NULL )
{
free( ( post_ops->matrix_add )->matrix );
//free( ( post_ops->matrix_add )->scale_factor );
free( post_ops->matrix_add );
}
@@ -599,6 +703,7 @@ static inline void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops )
if ( post_ops->matrix_mul != NULL )
{
free( ( post_ops->matrix_mul )->matrix );
//free( ( post_ops->matrix_mul )->scale_factor );
free( post_ops->matrix_mul );
}
@@ -770,6 +875,10 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
dim_t activator_idx = 0; \
dim_t clip_idx = 0; \
char * bias_stor_type = ""; \
bool is_matadd_stor_type = FALSE; \
char* matadd_stor_type = ""; \
bool is_matmul_stor_type = FALSE; \
char* matmul_stor_type = ""; \
bool is_group_quant = FALSE; \
bool is_pre_op_scale_scalar = FALSE; \
bool is_pre_op_scale_f32 = TRUE; \
@@ -899,12 +1008,62 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \
ops_tok = strtok( NULL, ", " ); \
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
{ \
is_matadd_stor_type = FALSE; \
} \
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
{ \
is_matadd_stor_type = TRUE; \
matadd_stor_type = "F32"; \
} \
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
{ \
is_matadd_stor_type = TRUE; \
matadd_stor_type = "BF16"; \
} \
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
{ \
is_matadd_stor_type = TRUE; \
matadd_stor_type = "S32"; \
} \
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
{ \
is_matadd_stor_type = TRUE; \
matadd_stor_type = "S8"; \
} \
is_matrix_add = TRUE; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \
ops_tok = strtok( NULL, ", " ); \
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
{ \
is_matmul_stor_type = FALSE; \
} \
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
{ \
is_matmul_stor_type = TRUE; \
matmul_stor_type = "F32"; \
} \
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
{ \
is_matmul_stor_type = TRUE; \
matmul_stor_type = "BF16"; \
} \
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
{ \
is_matmul_stor_type = TRUE; \
matmul_stor_type = "S32"; \
} \
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
{ \
is_matmul_stor_type = TRUE; \
matmul_stor_type = "S8"; \
} \
is_matrix_mul = TRUE; \
cur_op_index++; \
} \
@@ -988,32 +1147,50 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
\
if ( is_bias == TRUE ) \
{ \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
if(is_bias_stor_type == TRUE) \
{ \
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( bfloat16 ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_post_ops_,bfloat16)( ( post_ops->bias )->bias, n ); \
} \
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( float ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_post_ops_,float)( ( post_ops->bias )->bias, n ); \
} \
else if( ( strcmp( bias_stor_type, "S8" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_INT8; \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( int8_t ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_post_ops_,int8_t)( ( post_ops->bias )->bias, n ); \
} \
else if( ( strcmp( bias_stor_type, "S32" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_INT32; \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( int32_t ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_post_ops_,int32_t)( ( post_ops->bias )->bias, n ); \
} \
else {} \
@@ -1021,17 +1198,13 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
else \
{ \
( post_ops->bias )->stor_type = NULLTYPE; \
if( global_dscale_out == 'y') \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( BIAS_type ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
if ( strcmp(#BIAS_type, "bfloat16") == 0 ) { \
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
} \
GEN_FUNC_NAME(fill_array_post_ops_,BIAS_type)( ( post_ops->bias )->bias, n ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_post_ops_,C_type)( ( post_ops->bias )->bias, n ); \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_post_ops_,BIAS_type)( ( post_ops->bias )->bias, n ); \
} \
} \
\
@@ -1078,12 +1251,27 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
/* If output is float/bfloat16, param type will be float otherwise s32 */ \
if( is_integer(#C_type) ) \
{ \
goto err_handler; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( int32_t ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( int32_t* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( int32_t )6; \
} \
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \
else \
{ \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( float ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( float* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( float )6; \
} \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
} \
@@ -1092,12 +1280,25 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
/* If output is float/bfloat16, params type will be float otherwise s32 */ \
if( is_integer(#C_type) ) \
{ \
goto err_handler; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( int32_t ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( int32_t* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( int32_t )2; \
} \
else \
{ \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( float ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( float* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( float )2; \
} \
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \
} \
@@ -1123,18 +1324,37 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
( post_ops->eltwise + clip_idx )->scale_factor = NULL; \
( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \
( post_ops->eltwise + clip_idx )->algo.beta = NULL; \
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
/* If output is float/bfloat16, params type will be float otherwise s32 */ \
if( is_integer(#C_type) ) \
{ \
goto err_handler; \
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( int32_t ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( int32_t ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
{ \
goto err_handler; \
} \
*( ( int32_t* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( int32_t ) ( -64 ); \
*( ( int32_t* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( int32_t ) ( 23 ); \
} \
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
else \
{ \
goto err_handler; \
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( float ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( float ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
{ \
goto err_handler; \
} \
*( ( float* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( float ) ( -64 ); \
*( ( float* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( float ) ( 23 ); \
} \
*( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( C_type ) ( -64 ); \
*( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( C_type ) ( 23 ); \
( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \
} \
else if ( is_tanh == TRUE ) \
@@ -1199,28 +1419,99 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
\
if ( is_matrix_add == TRUE ) \
{ \
/* Allocate add matrix buffer, return early if alloc fails.*/ \
dim_t ele_dsize = 0; \
if ( global_dscale_out == 'y' ) \
if( is_matadd_stor_type == TRUE) \
{ \
ele_dsize = sizeof( C_DSCALE_type ); \
if( ( strcmp( matadd_stor_type, "BF16" ) == 0 ) ) \
{ \
( post_ops->matrix_add )->stor_type = AOCL_GEMM_BF16; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(bfloat16) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,bfloat16)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else if( ( strcmp( matadd_stor_type, "F32" ) == 0 ) ) \
{ \
( post_ops->matrix_add )->stor_type = AOCL_GEMM_F32; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(float) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else if( ( strcmp( matadd_stor_type, "S32" ) == 0 ) ) \
{ \
( post_ops->matrix_add )->stor_type = AOCL_GEMM_INT32; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int32_t) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else if( ( strcmp( matadd_stor_type, "S8" ) == 0 ) ) \
{ \
( post_ops->matrix_add )->stor_type = AOCL_GEMM_INT8; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int8_t) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else {} \
} \
else \
{ \
ele_dsize = sizeof( C_type ); \
} \
( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
/* default is int32_t for integer APIs and float for others */ \
if( is_integerAPI_avx512(#BLAS_SFX)) \
{ \
if( strcmp(#C_type, "int8_t") == 0 ) \
{ \
( post_ops->matrix_add )->stor_type = NULLTYPE; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int8_t) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else \
{ \
( post_ops->matrix_add )->stor_type = NULLTYPE; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(int32_t) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
} \
else \
{ \
if( global_dscale_out == 'y' ) \
{ \
( post_ops->matrix_add )->stor_type = NULLTYPE; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(C_DSCALE_type) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else \
{ \
( post_ops->matrix_add )->stor_type = NULLTYPE; \
( post_ops->matrix_add )->matrix = malloc( m * n * sizeof(float) ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
} \
} \
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
{ \
@@ -1245,35 +1536,76 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
temp_dscale_ptr[i] = ( ( DSCALE_type )2 ); \
} \
( post_ops->matrix_add )->scale_factor_len = n_scale; \
/* Set buffer type same as c_store type for now.
* TODO: Update to cover more data types. */ \
( post_ops->matrix_add )->stor_type = NULLTYPE; \
} \
\
if ( is_matrix_mul == TRUE ) \
{ \
/* Allocate mul matrix buffer, return early if alloc fails.*/ \
dim_t ele_dsize = 0; \
if ( global_dscale_out == 'y' ) \
if( is_matmul_stor_type == TRUE) \
{ \
ele_dsize = sizeof( C_DSCALE_type ); \
if( ( strcmp( matmul_stor_type, "BF16" ) == 0 ) ) \
{ \
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_BF16; \
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(bfloat16) ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,bfloat16)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else if( ( strcmp( matmul_stor_type, "F32" ) == 0 ) ) \
{ \
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_F32; \
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(float) ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else if( ( strcmp( matmul_stor_type, "S32" ) == 0 ) ) \
{ \
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_INT32; \
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(int32_t) ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else if( ( strcmp( matmul_stor_type, "S8" ) == 0 ) ) \
{ \
( post_ops->matrix_mul )->stor_type = AOCL_GEMM_INT8; \
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(int8_t) ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else {} \
} \
else \
{ \
ele_dsize = sizeof( C_type ); \
} \
( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
if( global_dscale_out == 'y' ) \
{ \
( post_ops->matrix_mul )->stor_type = NULLTYPE; \
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(C_DSCALE_type) ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else \
{ \
( post_ops->matrix_mul )->stor_type = NULLTYPE; \
( post_ops->matrix_mul )->matrix = malloc( m * n * sizeof(float) ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
} \
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
{ \

View File

@@ -1200,8 +1200,7 @@
// BF16 buffer for matrix add/mul in u8s8s32.
#define BF16_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
scr = _mm512_cvtepi32_ps( \
_mm512_sllv_epi32 \
scr =(__m512)(_mm512_sllv_epi32 \
( \
_mm512_cvtepi16_epi32 \
( \
@@ -1212,8 +1211,7 @@
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
), _mm512_set1_epi32( 16 ) \
) \
); \
) );\
scr = _mm512_mul_ps(scr, scl_fct ); \
#define BF16_MATRIX_ADD_1COL_PAR(mask,scr0,scl_fct0,m_ind) \