Element-wise post-op APIs are upgraded with new output types and post-ops

Description:

1. Added new output types for the f32 element-wise APIs to support
   s8, u8, s32, and bf16 outputs.

2. Updated the base f32 API to support all the post-ops supported in
   the GEMM APIs.
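
For context, a minimal caller sketch; the signature is assumed to mirror the
f32of32 base shown in the diff below, and buffer allocation plus post-op
population are elided:

/* Hypothetical wrapper showing the new f32 -> s8 element-wise call.
 * 'r' = row-major storage, 'n' = no transpose for A and B; post_ops
 * may carry bias, eltwise, scale/zero-point, matrix_add/matrix_mul. */
void downscale_f32_to_s8
     (
       dim_t m, dim_t n,
       const float* a, dim_t lda,
       int8_t* b, dim_t ldb,
       aocl_post_op* post_ops
     )
{
    aocl_gemm_eltwise_ops_f32os8( 'r', 'n', 'n', m, n,
                                  a, lda, b, ldb, post_ops );
}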

AMD Internal: [SWLCSG-3384]

Change-Id: I1a7caac76876ddc5a121840b4e585ded37ca81e8
Author: Deepak Negi
Date: 2025-02-07 02:25:14 +05:30
Committed by: Nallani Bhaskar
Parent: 0bae96d7ac
Commit: 3a7523b51b
7 changed files with 5847 additions and 1818 deletions


@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -182,17 +182,21 @@ AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16)
);
}
AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
BLIS_INLINE void aocl_eltwise_ops_f32of32_base
(
const char order,
const char transa,
const char transb,
const dim_t m,
const dim_t n,
const float* a,
const dim_t lda,
float* b,
const dim_t ldb,
aocl_post_op* post_op_unparsed,
AOCL_STORAGE_TYPE c_downscale
)
{
AOCL_UTIL_ELTWISE_OPS_CHECK
(
"f32of32",
order, transa, transb,
m, n,
a, lda,
b, ldb
);
trans_t blis_transa;
trans_t blis_transb;
@@ -260,7 +264,7 @@ AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
a, rs_a, cs_a,
b, rs_b, cs_b,
&rntm_g, lcntx_g,
post_op_list, F32
post_op_list, c_downscale
);
#else
lpgemm_eltwise_ops_f32of32_thread_decorator
@@ -269,7 +273,112 @@ AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
a, rs_a, cs_a,
b, rs_b, cs_b,
&rntm_g, lcntx_g,
post_op_list, F32
post_op_list, c_downscale
);
#endif
}
AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
{
AOCL_UTIL_ELTWISE_OPS_CHECK
(
"f32of32",
order, transa, transb,
m, n,
a, lda,
b, ldb
);
aocl_eltwise_ops_f32of32_base
(
order, transa, transb,
m, n,
a, lda,
b, ldb,
post_op_unparsed, F32
);
}
AOCL_UTIL_ELTWISE_OPS(float,bfloat16,f32obf16)
{
AOCL_UTIL_ELTWISE_OPS_CHECK
(
"f32obf16",
order, transa, transb,
m, n,
a, lda,
b, ldb
);
aocl_eltwise_ops_f32of32_base
(
order, transa, transb,
m, n,
a, lda,
(float*)b, ldb,
post_op_unparsed, BF16
);
}
AOCL_UTIL_ELTWISE_OPS(float,int32_t,f32os32)
{
AOCL_UTIL_ELTWISE_OPS_CHECK
(
"f32os32",
order, transa, transb,
m, n,
a, lda,
b, ldb
);
aocl_eltwise_ops_f32of32_base
(
order, transa, transb,
m, n,
a, lda,
(float*)b, ldb,
post_op_unparsed, S32
);
}
AOCL_UTIL_ELTWISE_OPS(float,int8_t,f32os8)
{
AOCL_UTIL_ELTWISE_OPS_CHECK
(
"f32os8",
order, transa, transb,
m, n,
a, lda,
b, ldb
);
aocl_eltwise_ops_f32of32_base
(
order, transa, transb,
m, n,
a, lda,
(float*)b, ldb,
post_op_unparsed, S8
);
}
AOCL_UTIL_ELTWISE_OPS(float,uint8_t,f32ou8)
{
AOCL_UTIL_ELTWISE_OPS_CHECK
(
"f32ou8",
order, transa, transb,
m, n,
a, lda,
b, ldb
);
aocl_eltwise_ops_f32of32_base
(
order, transa, transb,
m, n,
a, lda,
(float*)b, ldb,
post_op_unparsed, U8
);
}


@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -56,5 +56,9 @@ BLIS_EXPORT_ADDON void aocl_gemm_eltwise_ops_ ## LP_SFX \
AOCL_UTIL_ELTWISE_OPS(bfloat16,float,bf16of32);
AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16);
AOCL_UTIL_ELTWISE_OPS(float,float,f32of32);
AOCL_UTIL_ELTWISE_OPS(float,bfloat16,f32obf16);
AOCL_UTIL_ELTWISE_OPS(float,int32_t,f32os32);
AOCL_UTIL_ELTWISE_OPS(float,int8_t,f32os8);
AOCL_UTIL_ELTWISE_OPS(float,uint8_t,f32ou8);
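For reference, each instantiation above expands, via the AOCL_UTIL_ELTWISE_OPS
macro at the top of this header, to an exported declaration of the form below;
parameter names are assumed from the f32of32 base definition and are
illustrative:

BLIS_EXPORT_ADDON void aocl_gemm_eltwise_ops_f32os8
     (
       const char order,
       const char transa,
       const char transb,
       const dim_t m,
       const dim_t n,
       const float* a,
       const dim_t lda,
       int8_t* b,
       const dim_t ldb,
       aocl_post_op* post_op_unparsed
     );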
#endif // AOCL_ELTWISE_OPS_INTERFACE_H


@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -84,10 +84,18 @@ LPGEMM_ELTWISE_OPS_IFACE(float,float,f32of32)
post_ops_attr.is_last_k = TRUE; // Should always be TRUE here.
// Advance the matrix to the right positions based on thread id.
// To note that float and bfloat16 are both handled using this same
// frame, so the strides needs to be updated on the actual b matrix
// datatype or the c_downscale value.
// To note that float, bfloat16, int32_t, int8_t and uint8_t are
// handled using this same frame, so the strides need to be
// updated based on the actual b matrix datatype, i.e. the c_downscale value.
dim_t dsize = sizeof( float );
if ( post_ops_attr.c_stor_type == BF16 )
{
dsize = sizeof( bfloat16 );
}
if ( post_ops_attr.c_stor_type == S8 || post_ops_attr.c_stor_type == U8 )
{
dsize = sizeof( int8_t );
}
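/* S32 output needs no dsize change: sizeof( int32_t ) == sizeof( float ). */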
int8_t* b_i = ( int8_t* )b;
( ( lpgemm_util_post_ops_kernel_f32 )( lcntx->eltwise_ops_kern_fun_ptr ) )


@@ -34,7 +34,18 @@
#include "bench_lpgemm_helpers.h"
GEN_FILL_ARRAY_FUNC(int8_t)
GEN_FILL_ARRAY_FUNC(float)
GEN_FILL_ARRAY_FUNC(int32_t)
void fill_array_uint8_t ( void* arr, dim_t size )
{
if( size < 0 ) return;
uint8_t* temp_arr = ( uint8_t* ) arr;
for ( dim_t i = 0; i < size; ++i )
{
temp_arr[i] = ( uint8_t )( rand() % 5 );
}
}
void print_result
(
@@ -87,31 +98,60 @@ ACCUM_type eltwise_ops_get_temp_accum_ ## LP_SFX \
} \
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32of32)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32os32)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32obf16)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32os8)
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32ou8)
GEN_GET_BIAS_POST_OP_VAL(float,bf16of32)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16obf16)
GEN_GET_BIAS_POST_OP_VAL(float,f32of32)
GEN_GET_BIAS_POST_OP_VAL(float,f32os32)
GEN_GET_BIAS_POST_OP_VAL(float,f32obf16)
GEN_GET_BIAS_POST_OP_VAL(float,f32os8)
GEN_GET_BIAS_POST_OP_VAL(float,f32ou8)
GEN_GELU_TANH_POSTOP_FLOAT(bf16of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16obf16)
GEN_GELU_TANH_POSTOP_FLOAT(f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(f32os32)
GEN_GELU_TANH_POSTOP_FLOAT(f32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(f32os8)
GEN_GELU_TANH_POSTOP_FLOAT(f32ou8)
GEN_TANH_POSTOP_FLOAT(bf16of32)
GEN_TANH_POSTOP_FLOAT(bf16obf16)
GEN_TANH_POSTOP_FLOAT(f32of32)
GEN_TANH_POSTOP_FLOAT(f32os32)
GEN_TANH_POSTOP_FLOAT(f32obf16)
GEN_TANH_POSTOP_FLOAT(f32os8)
GEN_TANH_POSTOP_FLOAT(f32ou8)
GEN_GELU_ERF_POSTOP_FLOAT(bf16of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16obf16)
GEN_GELU_ERF_POSTOP_FLOAT(f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(f32os32)
GEN_GELU_ERF_POSTOP_FLOAT(f32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(f32os8)
GEN_GELU_ERF_POSTOP_FLOAT(f32ou8)
GEN_SWISH_POSTOP_FLOAT(bf16of32)
GEN_SWISH_POSTOP_FLOAT(bf16obf16)
GEN_SWISH_POSTOP_FLOAT(f32of32)
GEN_SWISH_POSTOP_FLOAT(f32os32)
GEN_SWISH_POSTOP_FLOAT(f32obf16)
GEN_SWISH_POSTOP_FLOAT(f32os8)
GEN_SWISH_POSTOP_FLOAT(f32ou8)
GEN_SIGMOID_POSTOP_FLOAT(bf16of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16obf16)
GEN_SIGMOID_POSTOP_FLOAT(f32of32)
GEN_SIGMOID_POSTOP_FLOAT(f32os32)
GEN_SIGMOID_POSTOP_FLOAT(f32obf16)
GEN_SIGMOID_POSTOP_FLOAT(f32os8)
GEN_SIGMOID_POSTOP_FLOAT(f32ou8)
static inline float eltwise_ops_accuracy_check_downscale_bf16of32
(
@@ -174,7 +214,7 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
dim_t j
)
{
dim_t j_scale = j;
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
@@ -193,15 +233,140 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32os32
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
float out_temp_accum = ( temp_accum *
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
zp_float );
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32os8
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float out_temp_accum =
( float )min(
max( nearbyintf( ( float )( temp_accum ) *
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) ) +
*( ( float* )( post_op->sum )->zero_point + j_zp ),
DSCALE_CLIP_MIN ),
DSCALE_CLIP_MAX );
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32ou8
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float out_temp_accum =
( float )min(
max( nearbyintf( ( float )( temp_accum ) *
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) ) +
*( ( float* )( post_op->sum )->zero_point + j_zp ),
DSCALE_CLIP_MIN ),
DSCALE_CLIP_MAX );
return out_temp_accum;
}
static inline float eltwise_ops_accuracy_check_downscale_f32obf16
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
j_scale = 0;
}
dim_t j_zp = j;
if ( ( post_op->sum )->zero_point_len == 1 )
{
j_zp = 0;
}
float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
float out_temp_accum = ( temp_accum *
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
zp_float );
return out_temp_accum;
}
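All five downscale helpers above share one per-element recipe; a condensed
reference sketch, assuming the min/max macros and the DSCALE_CLIP_MIN /
DSCALE_CLIP_MAX globals used in this bench file:

/* Condensed form of the helpers above: the s8/u8 paths round and clip
 * to the destination range, while the f32/bf16/s32 paths apply only
 * the scale factor and zero point. */
static inline float downscale_ref
     (
       float acc, float scale, float zp, bool round_and_clip
     )
{
    if ( round_and_clip )
    {
        return ( float )min( max( nearbyintf( acc * scale ) + zp,
                                  DSCALE_CLIP_MIN ), DSCALE_CLIP_MAX );
    }
    return ( acc * scale ) + zp;
}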
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32ou8)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,float)
#define GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(A_type,B_type,ACCUM_type,LP_SFX) \
void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
@@ -388,7 +553,7 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,LP_SFX) \
( ( post_op->matrix_mul )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_add )->stor_type ); \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_mul )->stor_type ); \
} \
else \
{} \
@@ -399,13 +564,14 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
( \
&out_temp_accum, &temp_accum \
); \
\
if ( ( ( *( b + ( rs_b * i ) + ( cs_b * j ) ) - out_temp_accum ) > 1.0E-5 ) || \
( (out_temp_accum - *( b + ( rs_b * i ) + ( cs_b * j ) ) ) > 1.0E-5 ) ) \
{ \
float comp_float, ref_float; \
\
float comp_float, ref_float; \
GEN_FUNC_NAME(B_type,_to_float)(*( b + ( rs_b * i ) + ( cs_b * j ) ), &comp_float); \
GEN_FUNC_NAME(B_type,_to_float)(out_temp_accum, &ref_float); \
\
if ( ( ( comp_float - ref_float ) > 1.0E-5 ) || \
( ( ref_float - comp_float ) > 1.0E-5 ) ) \
{ \
if ( fout ) \
{ \
fprintf( fout, "%s Failure input m: %ld, n: %ld," \
@@ -427,6 +593,10 @@ cleanup_acc: \
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,float,float,bf16of32)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,bf16obf16)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,float,float,f32of32)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,bfloat16,float,f32obf16)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,int32_t,float,f32os32)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,int8_t,float,f32os8)
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,uint8_t,float,f32ou8)
#define GEN_ELTWISE_OPS_BENCH_DRV_FUNC(A_type,B_type,LP_SFX) \
void eltwise_ops_bench_driver_ ## LP_SFX \
@@ -471,544 +641,19 @@ void eltwise_ops_bench_driver_ ## LP_SFX \
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,float,bf16of32)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,bfloat16,bf16obf16)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,float,f32of32)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,bfloat16,f32obf16)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,int32_t,f32os32)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,int8_t,f32os8)
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,uint8_t,f32ou8)
#define GEN_ELTWISE_OPS_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BLAS_SFX) \
static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
( \
dim_t m, \
dim_t n, \
char* post_ops_str, \
char stor_order \
) \
{ \
if ( ( ( post_ops_str == NULL ) || \
( strcmp( post_ops_str, "none" ) == 0 ) ) && \
( global_dscale_out == 'n' ) ) \
{ \
return NULL; \
} \
\
aocl_post_op* post_ops = NULL; \
post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \
\
if ( post_ops == NULL ) \
{ \
return NULL; \
} \
post_ops->eltwise = NULL; \
\
/* Only supporting 8 post ops at max for now.*/ \
dim_t max_post_ops_seq_length = 8; \
post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \
malloc \
( \
max_post_ops_seq_length * \
sizeof( AOCL_POST_OP_TYPE ) \
); \
\
if ( post_ops->seq_vector == NULL ) \
{ \
goto err_handler; \
} \
\
/* Parse post ops list.*/ \
dim_t cur_op_index = 0; \
/* Ensure the buffers that are NULL-checked in the deinit code are properly set to NULL.*/ \
\
/* Bench limitation: can only support 1 bias, but LPGEMM can support
* multiple scale post-ops. */ \
post_ops->bias = NULL; \
post_ops->bias = malloc( sizeof( aocl_post_op_bias ) ); \
if ( post_ops->bias == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->bias )->bias = NULL; \
\
/* Bench limitation: can only support 1 scale, but LPGEMM can support
* multiple scale post-ops. */ \
post_ops->sum = NULL; \
post_ops->sum = malloc( sizeof( aocl_post_op_sum ) ); \
if ( post_ops->sum == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->sum )->scale_factor = NULL; \
( post_ops->sum )->buff = NULL; \
( post_ops->sum )->zero_point = NULL; \
( post_ops->sum )->scale_factor_len = 0; \
( post_ops->sum )->zero_point_len = 0; \
\
/* Bench limitation: can only support 1 matrix add, but LPGEMM can support
* multiple matrix add post-ops. */ \
post_ops->matrix_add = NULL; \
post_ops->matrix_add = malloc( sizeof( aocl_post_op_matrix_add ) ); \
if ( post_ops->matrix_add == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->matrix_add )->matrix = NULL; \
( post_ops->matrix_add )->ldm = 0; \
\
/* Bench limitation: can only support 1 matrix mul, but LPGEMM can support
* multiple matrix mul post-ops. */ \
post_ops->matrix_mul = NULL; \
post_ops->matrix_mul = malloc( sizeof( aocl_post_op_matrix_mul ) ); \
if ( post_ops->matrix_mul == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->matrix_mul )->matrix = NULL; \
( post_ops->matrix_mul )->ldm = 0; \
\
bool is_bias = FALSE; \
bool is_relu = FALSE; \
bool is_param_relu = FALSE; \
bool is_gelu_tanh = FALSE; \
bool is_gelu_erf = FALSE; \
bool is_swish = FALSE; \
bool is_clip = FALSE; \
bool is_scalar_scale = FALSE; \
bool is_scalar_zp = FALSE; \
bool is_matrix_add = FALSE; \
bool is_matrix_mul = FALSE; \
bool is_tanh = FALSE; \
bool is_sigmoid = FALSE; \
bool is_bias_stor_type = FALSE; \
dim_t activator_idx = 0; \
dim_t clip_idx = 0; \
char * bias_stor_type = ""; \
\
/* Post-Ops string parser. */ \
dim_t num_eltwise = 0; \
if ( strcmp( post_ops_str, "none" ) != 0 ) \
{ \
char* ops_tok = strtok(post_ops_str, ", =" ); \
\
/* Ensure only one activator is used as an eltwise post-op.*/ \
bool is_activator_set = FALSE; \
while ( ops_tok ) \
{ \
str_tolower( ops_tok ); \
if ( strcmp( ops_tok, "bias" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = BIAS; \
ops_tok = strtok( NULL, ", " ); \
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
{ \
is_bias_stor_type = FALSE; \
} \
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
{ \
is_bias_stor_type = TRUE; \
bias_stor_type = "F32"; \
} \
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
{ \
is_bias_stor_type = TRUE; \
bias_stor_type = "BF16"; \
} \
is_bias = TRUE; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "relu" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_relu = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "prelu" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_param_relu = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "swish" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_swish = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "gelu_tanh" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_gelu_tanh = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "gelu_erf" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_gelu_erf = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "clip" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_clip = TRUE; \
num_eltwise += 1; \
clip_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "tanh" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_tanh = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( ( strcmp( ops_tok, "sigmoid" ) == 0 ) && \
( is_activator_set == FALSE ) ) \
{ \
post_ops->seq_vector[cur_op_index] = ELTWISE; \
is_sigmoid = TRUE; \
is_activator_set = TRUE; \
num_eltwise += 1; \
activator_idx = cur_op_index; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "scale" ) == 0 ) \
{ \
ops_tok = strtok( NULL, ", " ); \
str_tolower( ops_tok ); \
if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \
( strcmp( ops_tok, "s" ) == 0 ) ) \
{ \
is_scalar_scale = TRUE; \
} \
} \
else if ( strcmp( ops_tok, "zp" ) == 0 ) \
{ \
ops_tok = strtok( NULL, ", " ); \
str_tolower( ops_tok ); \
if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \
( strcmp( ops_tok, "s" ) == 0 ) ) \
{ \
is_scalar_zp = TRUE; \
} \
} \
else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \
is_matrix_add = TRUE; \
cur_op_index++; \
} \
else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \
{ \
post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \
is_matrix_mul = TRUE; \
cur_op_index++; \
} \
\
ops_tok = strtok( NULL, ", =" ); \
} \
} \
\
if ( is_bias == TRUE ) \
{ \
/* Allocate bias buffer, return early if alloc fails.*/ \
( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \
if ( ( post_ops->bias )->bias == NULL ) \
{ \
goto err_handler; \
} \
if(is_bias_stor_type == TRUE) \
{ \
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
} \
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
} \
else {} \
} \
else \
{ \
( post_ops->bias )->stor_type = NULLTYPE ; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_post_ops_,C_DSCALE_type)( ( post_ops->bias )->bias, n ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_post_ops_,C_type)( ( post_ops->bias )->bias, n ); \
} \
} \
\
if ( num_eltwise > 0 ) \
{ \
if ( num_eltwise > 1 ) \
{ \
if ( activator_idx < clip_idx ) \
{ \
activator_idx = 0; \
clip_idx = 1; \
} \
else \
{ \
activator_idx = 1; \
clip_idx = 0; \
} \
} \
else \
{ \
activator_idx = 0; \
clip_idx = 0; \
} \
\
post_ops->num_eltwise = num_eltwise; \
post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \
if ( post_ops->eltwise == NULL ) \
{ \
goto err_handler; \
} \
\
/* Only one of relu, prelu, swish, gelu_tanh, gelu_erf allowed as
* an activator. */ \
if ( is_relu == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \
} \
else if ( is_param_relu == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
} \
if ( is_swish == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \
} \
else if ( is_gelu_tanh == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \
} \
else if ( is_gelu_erf == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \
} \
else if ( is_tanh == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = TANH; \
} \
if ( is_sigmoid == TRUE ) \
{ \
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
( post_ops->eltwise + activator_idx )->algo.algo_type = SIGMOID; \
} \
if ( is_clip == TRUE ) \
{ \
( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \
( post_ops->eltwise + clip_idx )->scale_factor = NULL; \
( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \
( post_ops->eltwise + clip_idx )->algo.beta = NULL; \
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( DSCALE_type ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( DSCALE_type ) ); \
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
{ \
goto err_handler; \
} \
*( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( DSCALE_type ) ( -64 ); \
*( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( DSCALE_type ) ( 23 ); \
( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \
} \
} \
\
if ( global_dscale_out == 'y' ) \
{ \
post_ops->seq_vector[cur_op_index] = SCALE; \
cur_op_index++; \
\
( post_ops->sum )->is_power_of_2 = FALSE; \
if ( global_dscale_out == 'y' ) \
{ \
dim_t n_scale = n; \
if ( is_scalar_scale == TRUE ) \
{ \
n_scale = 1; \
} \
\
dim_t n_zp = n; \
if ( is_scalar_zp == TRUE ) \
{ \
n_zp = 1; \
} \
\
/* Allocate scale buffer, return early if alloc fails.*/ \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
( post_ops->sum )->zero_point = malloc( n_zp * sizeof( C_DSCALE_type ) ); \
if ( ( post_ops->sum )->zero_point == NULL ) \
{ \
goto err_handler; \
} \
\
/* Fill scale factor and zero points.*/ \
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
for ( dim_t i = 0; i < n_scale; ++i ) \
{ \
temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \
} \
( post_ops->sum )->scale_factor_len = n_scale; \
\
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( temp_dzero_point_ptr, n_zp ); \
( post_ops->sum )->zero_point_len = n_zp; \
} \
} \
\
if ( is_matrix_add == TRUE ) \
{ \
/* Allocate matrix add buffer, return early if alloc fails.*/ \
dim_t ele_dsize = 0; \
if ( global_dscale_out == 'y' ) \
{ \
ele_dsize = sizeof( C_DSCALE_type ); \
} \
else \
{ \
ele_dsize = sizeof( C_type ); \
} \
( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \
if ( ( post_ops->matrix_add )->matrix == NULL ) \
{ \
goto err_handler; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
} \
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
{ \
( post_ops->matrix_add )->ldm = m; \
} \
else \
{ \
( post_ops->matrix_add )->ldm = n; \
} \
} \
\
if ( is_matrix_mul == TRUE ) \
{ \
/* Allocate matrix mul buffer, return early if alloc fails.*/ \
dim_t ele_dsize = 0; \
if ( global_dscale_out == 'y' ) \
{ \
ele_dsize = sizeof( C_DSCALE_type ); \
} \
else \
{ \
ele_dsize = sizeof( C_type ); \
} \
( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
{ \
goto err_handler; \
} \
if ( global_dscale_out == 'y' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
} \
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
{ \
( post_ops->matrix_mul )->ldm = m; \
} \
else \
{ \
( post_ops->matrix_mul )->ldm = n; \
} \
} \
\
post_ops->seq_length = cur_op_index; \
\
post_ops->pre_ops = NULL; \
\
return post_ops; \
\
err_handler: \
lpgemm_destroy_post_ops_struct( post_ops ); \
return NULL; \
} \
GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,float,float,bf16of32)
GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,bfloat16,float,bf16obf16)
GEN_ELTWISE_OPS_POST_OPS_CREATOR(float,float,float,f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,float,bf16of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,float,bf16obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(float,int32_t,float,float,f32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(float,int8_t,float,float,f32os8)
GEN_MAT_MUL_POST_OPS_CREATOR(float,uint8_t,float,float,f32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(float,bfloat16,float,float,f32obf16)
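For reference, the string parser in the creator macro above tokenizes
post_ops_str on commas and equals signs. A plausible string covering most
branches (token names are taken from the parser itself; the exact combination
is illustrative):

char post_ops_str[] = "bias=f32,relu,clip,scale=scalar,zp=scalar,matrix_add,matrix_mul";

With both relu and clip present, num_eltwise becomes 2 and the
activator_idx/clip_idx ordering logic above is exercised.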
#define GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(A_type, B_type, LP_SFX) \
void eltwise_ops_bench_main_ ## LP_SFX \
@@ -1059,7 +704,7 @@ void eltwise_ops_bench_main_ ## LP_SFX \
( strcmp( post_ops_str, "none" ) != 0 ) ) || \
( global_dscale_out == 'y' ) ) \
{ \
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, post_ops_str, stor_order ); \
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, 0, post_ops_str, stor_order ); \
if ( post_op == NULL ) \
{ \
printf(" post op struct allocation failure, returning.\n"); \
@@ -1098,6 +743,10 @@ void eltwise_ops_bench_main_ ## LP_SFX \
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,float,bf16of32)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,bfloat16,bf16obf16)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,float,f32of32)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,bfloat16,f32obf16)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,int32_t,f32os32)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,int8_t,f32os8)
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,uint8_t,f32ou8)
int main( int argc, char** argv )
{
@@ -1293,6 +942,60 @@ int main( int argc, char** argv )
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32obf16" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32obf16)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32os32" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32os32)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32os8" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
DSCALE_CLIP_MIN = -128;
DSCALE_CLIP_MAX = +127;
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32os8)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
if ( ( strcmp( eltwise_ops_type_str, "f32ou8" ) == 0 ) ||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
DSCALE_CLIP_MIN = 0;
DSCALE_CLIP_MAX = +255;
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32ou8)
(
fout, stor_order, transa, transb,
m, n, stride_a, stride_b,
post_ops_str_dest
);
}
}
}


@@ -83,6 +83,92 @@
zmm2 = _mm512_mul_ps(zmm2,alpha); \
zmm3 = _mm512_mul_ps(zmm3,alpha);
// BF16 bias helper macros. bf16 occupies the upper 16 bits of an f32,
// so the load widens to 32-bit lanes and shifts left by 16.
#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \
scr = ( __m512)( _mm512_sllv_epi32 \
( \
_mm512_cvtepi16_epi32 \
( \
_mm256_maskz_loadu_epi16 \
( \
( mask ), \
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
), _mm512_set1_epi32( 16 ) \
) \
); \
// S32 bias helper macros.
#define S32_F32_BIAS_LOAD(scr,mask,n_ind) \
scr = _mm512_cvtepi32_ps \
( \
_mm512_maskz_loadu_epi32 \
( \
( mask ), \
( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
); \
// S8 bias helper macros.
#define S8_F32_BIAS_LOAD(scr,mask,n_ind) \
scr = _mm512_cvtepi32_ps \
( \
_mm512_cvtepi8_epi32 \
( \
_mm_maskz_loadu_epi8 \
( \
( mask ), \
( ( int8_t* )post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
) \
); \
// BF16 bias helper macros.
#define BF16_F32_BIAS_BCAST(scr,mask,m_ind) \
scr = ( __m512)( _mm512_sllv_epi32 \
( \
_mm512_cvtepi16_epi32 \
( \
_mm256_maskz_set1_epi16 \
( \
( mask ), \
*( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_i + m_ind ) \
) \
), _mm512_set1_epi32( 16 ) \
) \
); \
// S32 bias helper macros.
#define S32_F32_BIAS_BCAST(scr,mask,m_ind) \
scr = _mm512_cvtepi32_ps \
( \
_mm512_maskz_set1_epi32 \
( \
( mask ), \
*( ( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_i + m_ind ) \
) \
); \
// S8 bias helper macros.
#define S8_F32_BIAS_BCAST(scr,mask,m_ind) \
scr = _mm512_cvtepi32_ps \
( \
_mm512_cvtepi8_epi32 \
( \
_mm_maskz_set1_epi8 \
( \
( mask ), \
*( ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
post_ops_attr.post_op_c_i + m_ind ) \
) \
) \
); \
// Matrix Add post-ops helper macros
#define F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \
zmm ## r_ind0 = _mm512_add_ps( scr0, zmm ## r_ind0 ); \
@@ -99,6 +185,146 @@
zmm ## r_ind2 = _mm512_add_ps( scr2, zmm ## r_ind2 ); \
zmm ## r_ind3 = _mm512_add_ps( scr3, zmm ## r_ind3 ); \
//BF16 matrix_add helper macros.
#define BF16_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
scr = (__m512)( _mm512_sllv_epi32 \
( \
_mm512_cvtepi16_epi32 \
( \
_mm256_maskz_loadu_epi16 \
( \
mask, \
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
), _mm512_set1_epi32( 16 ) \
) \
); \
scr = _mm512_mul_ps( scr, scl_fct ); \
#define BF16_F32_MATRIX_ADD_2COL(scr0,scr1, \
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define BF16_F32_MATRIX_ADD_3COL(scr0,scr1,scr2, \
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
#define BF16_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
// S8 matrix_add helper macros.
#define S8_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
scr = _mm512_cvtepi32_ps( \
_mm512_cvtepi8_epi32 \
( \
_mm_maskz_loadu_epi8 \
( \
mask, \
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
) \
); \
scr = _mm512_mul_round_ps \
( \
( scr ), scl_fct, \
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
); \
#define S8_F32_MATRIX_ADD_2COL(scr0,scr1, \
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define S8_F32_MATRIX_ADD_3COL(scr0,scr1,scr2, \
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
#define S8_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define S8_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
// S32 matrix_add helper macros.
#define S32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
scr = _mm512_cvtepi32_ps ( \
_mm512_maskz_loadu_epi32 \
( \
mask, \
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
) \
); \
scr = _mm512_mul_round_ps \
( \
( scr ), scl_fct, \
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
); \
#define S32_F32_MATRIX_ADD_2COL(scr0,scr1, \
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define S32_F32_MATRIX_ADD_3COL(scr0,scr1,scr2, \
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
#define S32_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define S32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
//F32 matrix_add helper macros.
#define F32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
scr = _mm512_maskz_loadu_ps \
( \
@@ -153,6 +379,105 @@
zmm ## r_ind2 = _mm512_mul_ps( scr2, zmm ## r_ind2 ); \
zmm ## r_ind3 = _mm512_mul_ps( scr3, zmm ## r_ind3 ); \
#define BF16_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
BF16_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
#define BF16_F32_MATRIX_MUL_2COL(scr0,scr1, \
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define BF16_F32_MATRIX_MUL_3COL(scr0,scr1,scr2, \
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
#define BF16_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,scl_fct0, \
scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define S8_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
S8_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
#define S8_F32_MATRIX_MUL_2COL(scr0,scr1, \
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define S8_F32_MATRIX_MUL_3COL(scr0,scr1,scr2, \
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
#define S8_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,scl_fct0, \
scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define S8_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define S32_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
S32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
#define S32_F32_MATRIX_MUL_2COL(scr0,scr1, \
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define S32_F32_MATRIX_MUL_3COL(scr0,scr1,scr2, \
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
#define S32_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,scl_fct0, \
scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define S32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
#define F32_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
F32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
@@ -190,6 +515,38 @@
reg = _mm512_mul_ps( reg, selector ); \
reg = _mm512_add_ps( reg, zero_point ); \
// Downscale store bf16 macro
#define CVT_STORE_F32_BF16_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
_mm256_mask_storeu_epi16 \
( \
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
mask, (__m256i) _mm512_cvtneps_pbh( reg ) \
) \
// Downscale store s8 macro
#define CVT_STORE_F32_S8_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
_mm512_mask_cvtsepi32_storeu_epi8 \
( \
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
mask, _mm512_cvtps_epi32( reg ) \
) \
// Downscale store u8 macro
#define CVT_STORE_F32_U8_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
_mm512_mask_cvtusepi32_storeu_epi8 \
( \
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
mask, _mm512_cvtps_epu32( _mm512_max_ps( reg, _mm512_set1_ps( 0 ) ) ) \
) \
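// Note: the u8 store clamps negatives to zero ( max with 0 ) before the
// unsigned saturating convert, while the s8 store relies on the signed
// saturating convert ( cvtsepi32 ) alone.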
// Downscale store s32 macro
#define CVT_STORE_F32_S32_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
_mm512_mask_storeu_epi32 \
( \
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
mask, _mm512_cvtps_epi32 ( reg ) \
) \
/*x_tanh = tanhf(x_tanh) */ \
#define TANH_F32S_AVX512(x_tanh, r, r2, x, z, dn, q) \