Element wise post-op APIs are upgraded with new post-ops

Description:

1. Added new output types for f32 element wise API's to support
   s8, u8, s32 , bf16 outputs.

2. Updated the base f32 API to support all the post-ops supported in
   gemm API's

AMD Internal: [SWLCSG-3384]

Change-Id: I1a7caac76876ddc5a121840b4e585ded37ca81e8
This commit is contained in:
Deepak Negi
2025-02-07 02:25:14 +05:30
committed by Nallani Bhaskar
parent 0bae96d7ac
commit 3a7523b51b
7 changed files with 5847 additions and 1818 deletions

View File

@@ -34,7 +34,18 @@
#include "bench_lpgemm_helpers.h"
GEN_FILL_ARRAY_FUNC(int8_t)
GEN_FILL_ARRAY_FUNC(float)
GEN_FILL_ARRAY_FUNC(int32_t)
void fill_array_uint8_t ( void* arr, dim_t size )
{
if( size < 0 ) return;
uint8_t* temp_arr = ( uint8_t* ) arr;
for ( dim_t i = 0; i < size; ++i )
{
temp_arr[i] = ( uint8_t )( rand() % 5 );
}
}
void print_result
(
@@ -87,31 +98,60 @@ ACCUM_type eltwise_ops_get_temp_accum_ ## LP_SFX \
} \
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32of32)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32os32)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32obf16)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32os8)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32ou8)
GEN_GET_BIAS_POST_OP_VAL(float,bf16of32)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16obf16)
GEN_GET_BIAS_POST_OP_VAL(float,f32of32)
GEN_GET_BIAS_POST_OP_VAL(float,f32os32)
GEN_GET_BIAS_POST_OP_VAL(float,f32obf16)
GEN_GET_BIAS_POST_OP_VAL(float,f32os8)
GEN_GET_BIAS_POST_OP_VAL(float,f32ou8)
GEN_GELU_TANH_POSTOP_FLOAT(bf16of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16obf16)
GEN_GELU_TANH_POSTOP_FLOAT(f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(f32os32)
GEN_GELU_TANH_POSTOP_FLOAT(f32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(f32os8)
GEN_GELU_TANH_POSTOP_FLOAT(f32ou8)
GEN_TANH_POSTOP_FLOAT(bf16of32)
GEN_TANH_POSTOP_FLOAT(bf16obf16)
GEN_TANH_POSTOP_FLOAT(f32of32)
GEN_TANH_POSTOP_FLOAT(f32os32)
GEN_TANH_POSTOP_FLOAT(f32obf16)
GEN_TANH_POSTOP_FLOAT(f32os8)
GEN_TANH_POSTOP_FLOAT(f32ou8)
GEN_GELU_ERF_POSTOP_FLOAT(bf16of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16obf16)
GEN_GELU_ERF_POSTOP_FLOAT(f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(f32os32)
GEN_GELU_ERF_POSTOP_FLOAT(f32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(f32os8)
GEN_GELU_ERF_POSTOP_FLOAT(f32ou8)
GEN_SWISH_POSTOP_FLOAT(bf16of32)
GEN_SWISH_POSTOP_FLOAT(bf16obf16)
GEN_SWISH_POSTOP_FLOAT(f32of32)
GEN_SWISH_POSTOP_FLOAT(f32os32)
GEN_SWISH_POSTOP_FLOAT(f32obf16)
GEN_SWISH_POSTOP_FLOAT(f32os8)
GEN_SWISH_POSTOP_FLOAT(f32ou8)
GEN_SIGMOID_POSTOP_FLOAT(bf16of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16obf16)
GEN_SIGMOID_POSTOP_FLOAT(f32of32)
GEN_SIGMOID_POSTOP_FLOAT(f32os32)
GEN_SIGMOID_POSTOP_FLOAT(f32obf16)
GEN_SIGMOID_POSTOP_FLOAT(f32os8)
GEN_SIGMOID_POSTOP_FLOAT(f32ou8)
static inline float eltwise_ops_accuracy_check_downscale_bf16of32
(
@@ -174,7 +214,7 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
dim_t j
)
{
dim_t j_scale = j;
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
@@ -193,15 +233,140 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32os32
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
float out_temp_accum = ( temp_accum *
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
zp_float );
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32os8
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float out_temp_accum = \
( float )min( \
max( nearbyintf( ( float )( temp_accum ) * \
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
*( ( float* )( post_op->sum )->zero_point + j_zp ), \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32ou8
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float out_temp_accum = \
( float )min( \
max( nearbyintf( ( float )( temp_accum ) * \
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
*( ( float* )( post_op->sum )->zero_point + j_zp ), \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32obf16
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
float out_temp_accum = ( temp_accum *
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
zp_float );
return out_temp_accum;
}
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32ou8)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,float)
#define GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(A_type,B_type,ACCUM_type,LP_SFX) \
void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
@@ -388,7 +553,7 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,LP_SFX) \
( ( post_op->matrix_mul )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_add )->stor_type ); \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_mul )->stor_type ); \
} \
else \
{} \
@@ -399,13 +564,14 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
( \
&out_temp_accum, &temp_accum \
); \
\
if ( ( ( *( b + ( rs_b * i ) + ( cs_b * j ) ) - out_temp_accum ) > 1.0E-5 ) || \
( (out_temp_accum - *( b + ( rs_b * i ) + ( cs_b * j ) ) ) > 1.0E-5 ) ) \
{ \
float comp_float, ref_float; \
\
float comp_float, ref_float; \
GEN_FUNC_NAME(B_type,_to_float)(*( b + ( rs_b * i ) + ( cs_b * j ) ), &comp_float); \
GEN_FUNC_NAME(B_type,_to_float)(out_temp_accum, &ref_float); \
\
if ( ( ( comp_float - ref_float ) > 1.0E-5 ) || \
( ( ref_float - comp_float ) > 1.0E-5 ) ) \
{ \
if ( fout ) \
{ \
fprintf( fout, "%s Failure input m: %ld, n: %ld," \
@@ -427,6 +593,10 @@ cleanup_acc: \
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,float,float,bf16of32)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,bf16obf16)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,float,float,f32of32)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,bfloat16,float,f32obf16)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,int32_t,float,f32os32)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,int8_t,float,f32os8)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,uint8_t,float,f32ou8)
#define GEN_ELTWISE_OPS_BENCH_DRV_FUNC(A_type,B_type,LP_SFX) \
void eltwise_ops_bench_driver_ ## LP_SFX \
@@ -471,544 +641,19 @@ void eltwise_ops_bench_driver_ ## LP_SFX \
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,float,bf16of32)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,bfloat16,bf16obf16)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,float,f32of32)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,bfloat16,f32obf16)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,int32_t,f32os32)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,int8_t,f32os8)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,uint8_t,f32ou8)
#define GEN_ELTWISE_OPS_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BLAS_SFX) \
static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
( \
dim_t m, \
dim_t n, \
char* post_ops_str, \
char stor_order \
) \
{ \
if ( ( ( post_ops_str == NULL ) || \
( strcmp( post_ops_str, "none" ) == 0 ) ) && \
( global_dscale_out == 'n' ) ) \
{ \
return NULL; \
} \
\
aocl_post_op* post_ops = NULL; \
post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \
\
if ( post_ops == NULL ) \
{ \
return NULL; \
} \
post_ops->eltwise = NULL; \
\
/* Only supporting 8 post ops at max for now.*/ \
dim_t max_post_ops_seq_length = 8; \
post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \
malloc \
( \
max_post_ops_seq_length * \
sizeof( AOCL_POST_OP_TYPE ) \
); \
\
if ( post_ops->seq_vector == NULL ) \
{ \
goto err_handler; \
} \
\
/* Parse post ops list.*/ \
dim_t cur_op_index = 0; \
/* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \
\
/* Bench limitation: can only support 1 bias, but LPGEMM can support
* multiple scale post-ops. */ \
post_ops->bias = NULL; \
post_ops->bias = malloc( sizeof( aocl_post_op_bias ) ); \
if ( post_ops->bias == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->bias )->bias = NULL; \
\
/* Bench limitation: can only support 1 scale, but LPGEMM can support
* multiple scale post-ops. */ \
post_ops->sum = NULL; \
post_ops->sum = malloc( sizeof( aocl_post_op_sum ) ); \
if ( post_ops->sum == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->sum )->scale_factor = NULL; \
( post_ops->sum )->buff = NULL; \
( post_ops->sum )->zero_point = NULL; \
( post_ops->sum )->scale_factor_len = 0; \
( post_ops->sum )->zero_point_len = 0; \
\
/* Bench limitation: can only support 1 matrix add, but LPGEMM can support
* multiple matrix add post-ops. */ \
post_ops->matrix_add = NULL; \
post_ops->matrix_add = malloc( sizeof( aocl_post_op_matrix_add ) ); \
if ( post_ops->matrix_add == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->matrix_add )->matrix = NULL; \
( post_ops->matrix_add )->ldm = 0; \
\
/* Bench limitation: can only support 1 matrix mul, but LPGEMM can support
* multiple matrix mul post-ops. */ \
post_ops->matrix_mul = NULL; \
post_ops->matrix_mul = malloc( sizeof( aocl_post_op_matrix_mul ) ); \
if ( post_ops->matrix_mul == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->matrix_mul )->matrix = NULL; \
( post_ops->matrix_mul )->ldm = 0; \
\
bool is_bias = FALSE; \
bool is_relu = FALSE; \
bool is_param_relu = FALSE; \
bool is_gelu_tanh = FALSE; \
bool is_gelu_erf = FALSE; \
bool is_swish = FALSE; \
bool is_clip = FALSE; \
bool is_scalar_scale = FALSE; \
bool is_scalar_zp = FALSE; \
bool is_matrix_add = FALSE; \
bool is_matrix_mul = FALSE; \
bool is_tanh = FALSE; \
bool is_sigmoid = FALSE; \
bool is_bias_stor_type = FALSE; \
dim_t activator_idx = 0; \
dim_t clip_idx = 0; \
char * bias_stor_type = ""; \
\
/* Post-Ops string parser. */ \
dim_t num_eltwise = 0; \
if ( strcmp( post_ops_str, "none" ) != 0 ) \
{ \
char* ops_tok = strtok(post_ops_str, ", =" ); \
\
/* Ensure only one activator is used as an eltwise post-op.*/ \
bool is_activator_set = FALSE; \
while ( ops_tok ) \
{ \
str_tolower( ops_tok ); \
if ( strcmp( ops_tok, "bias" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = BIAS; \
ops_tok = strtok( NULL, ", " ); \
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
{ \
is_bias_stor_type = FALSE; \
} \
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
{ \
is_bias_stor_type = TRUE; \
bias_stor_type = "F32"; \
} \
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
{ \
is_bias_stor_type = TRUE; \
bias_stor_type = "BF16"; \
} \
is_bias = TRUE; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "relu" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_relu = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "prelu" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_param_relu = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "swish" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_swish = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "gelu_tanh" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_gelu_tanh = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "gelu_erf" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_gelu_erf = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "clip" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_clip = TRUE; \
num_eltwise += 1; \
clip_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "tanh" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_tanh = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "sigmoid" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_sigmoid = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "scale" ) == 0 ) \
{ \
ops_tok = strtok( NULL, ", " ); \
str_tolower( ops_tok ); \
if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \
( strcmp( ops_tok, "s" ) == 0 ) ) \
{ \
is_scalar_scale = TRUE; \
} \
} \
else if ( strcmp( ops_tok, "zp" ) == 0 ) \
{ \
ops_tok = strtok( NULL, ", " ); \
str_tolower( ops_tok ); \
if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \
( strcmp( ops_tok, "s" ) == 0 ) ) \
{ \
is_scalar_zp = TRUE; \
} \
} \
else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \
is_matrix_add = TRUE; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \
is_matrix_mul = TRUE; \
cur_op_index++; \
} \
\
ops_tok = strtok( NULL, ", =" ); \
} \
} \
\
if ( is_bias == TRUE ) \
{ \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
if(is_bias_stor_type == TRUE) \
{ \
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
} \
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
} \
else {} \
} \
else \
{ \
( post_ops->bias )->stor_type = NULLTYPE ; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_post_ops_,C_DSCALE_type)( ( post_ops->bias )->bias, n ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_post_ops_,C_type)( ( post_ops->bias )->bias, n ); \
} \
} \
\
if ( num_eltwise > 0 ) \
{ \
if ( num_eltwise > 1 ) \
{ \
if ( activator_idx < clip_idx ) \
{ \
activator_idx = 0; \
clip_idx = 1; \
} \
else \
{ \
activator_idx = 1; \
clip_idx = 0; \
} \
} \
else \
{ \
activator_idx = 0; \
clip_idx = 0; \
} \
\
post_ops->num_eltwise = num_eltwise; \
post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \
if ( post_ops->eltwise == NULL ) \
{ \
goto err_handler; \
} \
\
/* Only one of relu, prelu, swish, gelu_tanh, gelu_erf allowed as
* an activator. */ \
if ( is_relu == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \
} \
else if ( is_param_relu == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
} \
if ( is_swish == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \
} \
else if ( is_gelu_tanh == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \
} \
else if ( is_gelu_erf == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \
} \
else if ( is_tanh == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = TANH; \
} \
if ( is_sigmoid == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = SIGMOID; \
} \
if ( is_clip == TRUE ) \
{ \
( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + clip_idx )->scale_factor = NULL; \
( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \
( post_ops->eltwise + clip_idx )->algo.beta = NULL; \
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( DSCALE_type ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( DSCALE_type ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
{ \
goto err_handler; \
} \
*( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( DSCALE_type ) ( -64 ); \
*( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( DSCALE_type ) ( 23 ); \
( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \
} \
} \
\
if ( global_dscale_out == 'y' ) \
{ \
post_ops->seq_vector[cur_op_index] = SCALE; \
cur_op_index++; \
\
( post_ops->sum )->is_power_of_2 = FALSE; \
if ( global_dscale_out == 'y' ) \
{ \
dim_t n_scale = n; \
if ( is_scalar_scale == TRUE ) \
{ \
n_scale = 1; \
} \
\
dim_t n_zp = n; \
if ( is_scalar_zp == TRUE ) \
{ \
n_zp = 1; \
} \
\
/* Allocate scale buffer, return early if alloc fails.*/ \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->sum )->zero_point = malloc( n_zp * sizeof( C_DSCALE_type ) ); \
if ( ( post_ops->sum )->zero_point == NULL ) \
{ \
goto err_handler; \
} \
\
/* Fill scale factor and zero points.*/ \
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
for ( dim_t i = 0; i < n_scale; ++i ) \
{ \
temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \
} \
( post_ops->sum )->scale_factor_len = n_scale; \
\
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( temp_dzero_point_ptr, n_zp ); \
( post_ops->sum )->zero_point_len = n_zp; \
} \
} \
\
if ( is_matrix_add == TRUE ) \
{ \
/* Allocate bias buffer, return early if alloc fails.*/ \
dim_t ele_dsize = 0; \
if ( global_dscale_out == 'y' ) \
{ \
ele_dsize = sizeof( C_DSCALE_type ); \
} \
else \
{ \
ele_dsize = sizeof( C_type ); \
} \
( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
{ \
( post_ops->matrix_add )->ldm = m; \
} \
else \
{ \
( post_ops->matrix_add )->ldm = n; \
} \
} \
\
if ( is_matrix_mul == TRUE ) \
{ \
/* Allocate bias buffer, return early if alloc fails.*/ \
dim_t ele_dsize = 0; \
if ( global_dscale_out == 'y' ) \
{ \
ele_dsize = sizeof( C_DSCALE_type ); \
} \
else \
{ \
ele_dsize = sizeof( C_type ); \
} \
( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
{ \
( post_ops->matrix_mul )->ldm = m; \
} \
else \
{ \
( post_ops->matrix_mul )->ldm = n; \
} \
} \
\
post_ops->seq_length = cur_op_index; \
\
post_ops->pre_ops = NULL; \
\
return post_ops; \
\
err_handler: \
lpgemm_destroy_post_ops_struct( post_ops ); \
return NULL; \
} \
GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,float,float,bf16of32)
GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,bfloat16,float,bf16obf16)
GEN_ELTWISE_OPS_POST_OPS_CREATOR(float,float,float,f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,float,bf16of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,float,bf16obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(float,int32_t,float,float,f32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(float,int8_t,float,float,f32os8)
GEN_MAT_MUL_POST_OPS_CREATOR(float,uint8_t,float,float,f32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(float,bfloat16,float,float,f32obf16)
#define GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(A_type, B_type, LP_SFX) \
void eltwise_ops_bench_main_ ## LP_SFX \
@@ -1059,7 +704,7 @@ void eltwise_ops_bench_main_ ## LP_SFX \
( strcmp( post_ops_str, "none" ) != 0 ) ) || \
( global_dscale_out == 'y' ) ) \
{ \
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, post_ops_str, stor_order ); \
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, 0, post_ops_str, stor_order ); \
if ( post_op == NULL ) \
{ \
printf(" post op struct allocation failure, returning.\n"); \
@@ -1098,6 +743,10 @@ void eltwise_ops_bench_main_ ## LP_SFX \
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,float,bf16of32)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,bfloat16,bf16obf16)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,float,f32of32)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,bfloat16,f32obf16)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,int32_t,f32os32)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,int8_t,f32os8)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,uint8_t,f32ou8)
int main( int argc, char** argv )
{
@@ -1293,6 +942,60 @@ int main( int argc, char** argv )
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32obf16" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32obf16)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32os32" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32os32)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32os8" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
DSCALE_CLIP_MIN = -128;
DSCALE_CLIP_MAX = +127;
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32os8)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32ou8" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
DSCALE_CLIP_MIN = 0;
DSCALE_CLIP_MAX = +255;
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32ou8)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
}
}