mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Element wise post-op APIs are upgraded with new post-ops
Description: 1. Added new output types for f32 element wise API's to support s8, u8, s32 , bf16 outputs. 2. Updated the base f32 API to support all the post-ops supported in gemm API's AMD Internal: [SWLCSG-3384] Change-Id: I1a7caac76876ddc5a121840b4e585ded37ca81e8
This commit is contained in:
committed by
Nallani Bhaskar
parent
0bae96d7ac
commit
3a7523b51b
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -182,17 +182,21 @@ AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16)
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
|
||||
BLIS_INLINE void aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
const char order,
|
||||
const char transa,
|
||||
const char transb,
|
||||
const dim_t m,
|
||||
const dim_t n,
|
||||
const float* a,
|
||||
const dim_t lda,
|
||||
float* b,
|
||||
const dim_t ldb,
|
||||
aocl_post_op* post_op_unparsed,
|
||||
AOCL_STORAGE_TYPE c_downscale
|
||||
)
|
||||
{
|
||||
AOCL_UTIL_ELTWISE_OPS_CHECK
|
||||
(
|
||||
"f32of32",
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
@@ -260,7 +264,7 @@ AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
|
||||
a, rs_a, cs_a,
|
||||
b, rs_b, cs_b,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
post_op_list, c_downscale
|
||||
);
|
||||
#else
|
||||
lpgemm_eltwise_ops_f32of32_thread_decorator
|
||||
@@ -269,7 +273,112 @@ AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
|
||||
a, rs_a, cs_a,
|
||||
b, rs_b, cs_b,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
post_op_list, c_downscale
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
AOCL_UTIL_ELTWISE_OPS(float,float,f32of32)
|
||||
{
|
||||
AOCL_UTIL_ELTWISE_OPS_CHECK
|
||||
(
|
||||
"f32of32",
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb
|
||||
);
|
||||
|
||||
aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb,
|
||||
post_op_unparsed, F32
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_UTIL_ELTWISE_OPS(float,bfloat16,f32obf16)
|
||||
{
|
||||
AOCL_UTIL_ELTWISE_OPS_CHECK
|
||||
(
|
||||
"f32obf16",
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb
|
||||
);
|
||||
|
||||
aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
(float*)b, ldb,
|
||||
post_op_unparsed, BF16
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_UTIL_ELTWISE_OPS(float,int32_t,f32os32)
|
||||
{
|
||||
AOCL_UTIL_ELTWISE_OPS_CHECK
|
||||
(
|
||||
"f32os32",
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb
|
||||
);
|
||||
|
||||
aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
(float*)b, ldb,
|
||||
post_op_unparsed, S32
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_UTIL_ELTWISE_OPS(float,int8_t,f32os8)
|
||||
{
|
||||
AOCL_UTIL_ELTWISE_OPS_CHECK
|
||||
(
|
||||
"f32os8",
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb
|
||||
);
|
||||
|
||||
aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
(float*)b, ldb,
|
||||
post_op_unparsed, S8
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_UTIL_ELTWISE_OPS(float,uint8_t,f32ou8)
|
||||
{
|
||||
AOCL_UTIL_ELTWISE_OPS_CHECK
|
||||
(
|
||||
"f32ou8",
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
b, ldb
|
||||
);
|
||||
|
||||
aocl_eltwise_ops_f32of32_base
|
||||
(
|
||||
order, transa, transb,
|
||||
m, n,
|
||||
a, lda,
|
||||
(float*)b, ldb,
|
||||
post_op_unparsed, U8
|
||||
);
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -56,5 +56,9 @@ BLIS_EXPORT_ADDON void aocl_gemm_eltwise_ops_ ## LP_SFX \
|
||||
AOCL_UTIL_ELTWISE_OPS(bfloat16,float,bf16of32);
|
||||
AOCL_UTIL_ELTWISE_OPS(bfloat16,bfloat16,bf16obf16);
|
||||
AOCL_UTIL_ELTWISE_OPS(float,float,f32of32);
|
||||
AOCL_UTIL_ELTWISE_OPS(float,bfloat16,f32obf16);
|
||||
AOCL_UTIL_ELTWISE_OPS(float,int32_t,f32os32);
|
||||
AOCL_UTIL_ELTWISE_OPS(float,int8_t,f32os8);
|
||||
AOCL_UTIL_ELTWISE_OPS(float,uint8_t,f32ou8);
|
||||
|
||||
#endif // AOCL_ELTWISE_OPS_INTERFACE_H
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -84,10 +84,18 @@ LPGEMM_ELTWISE_OPS_IFACE(float,float,f32of32)
|
||||
post_ops_attr.is_last_k = TRUE; // Should always be TRUE here.
|
||||
|
||||
// Advance the matrix to the right positions based on thread id.
|
||||
// To note that float and bfloat16 are both handled using this same
|
||||
// frame, so the strides needs to be updated on the actual b matrix
|
||||
// datatype or the c_downscale value.
|
||||
// To note that float, bfloat16, int32_t, int8_t and uint8_t are
|
||||
// handled using this same frame, so the strides needs to be
|
||||
// updated on the actual b matrix datatype or the c_downscale value.
|
||||
dim_t dsize = sizeof( float );
|
||||
if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
dsize = sizeof( bfloat16 );
|
||||
}
|
||||
if ( post_ops_attr.c_stor_type == S8 || post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
dsize = sizeof( int8_t );
|
||||
}
|
||||
int8_t* b_i = ( int8_t* )b;
|
||||
|
||||
( ( lpgemm_util_post_ops_kernel_f32 )( lcntx->eltwise_ops_kern_fun_ptr ) )
|
||||
|
||||
@@ -34,7 +34,18 @@
|
||||
|
||||
#include "bench_lpgemm_helpers.h"
|
||||
|
||||
GEN_FILL_ARRAY_FUNC(int8_t)
|
||||
GEN_FILL_ARRAY_FUNC(float)
|
||||
GEN_FILL_ARRAY_FUNC(int32_t)
|
||||
void fill_array_uint8_t ( void* arr, dim_t size )
|
||||
{
|
||||
if( size < 0 ) return;
|
||||
uint8_t* temp_arr = ( uint8_t* ) arr;
|
||||
for ( dim_t i = 0; i < size; ++i )
|
||||
{
|
||||
temp_arr[i] = ( uint8_t )( rand() % 5 );
|
||||
}
|
||||
}
|
||||
|
||||
void print_result
|
||||
(
|
||||
@@ -87,31 +98,60 @@ ACCUM_type eltwise_ops_get_temp_accum_ ## LP_SFX \
|
||||
} \
|
||||
|
||||
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32of32)
|
||||
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32os32)
|
||||
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32obf16)
|
||||
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32os8)
|
||||
GEN_ELTWISE_OPS_GET_TEMP_ACCUM_F(float,float,f32ou8)
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,bf16of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32ou8)
|
||||
|
||||
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(bf16obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32os32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(f32ou8)
|
||||
|
||||
GEN_TANH_POSTOP_FLOAT(bf16of32)
|
||||
GEN_TANH_POSTOP_FLOAT(bf16obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(f32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(f32os32)
|
||||
GEN_TANH_POSTOP_FLOAT(f32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(f32os8)
|
||||
GEN_TANH_POSTOP_FLOAT(f32ou8)
|
||||
|
||||
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(bf16obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32os32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32os8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(f32ou8)
|
||||
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16of32)
|
||||
GEN_SWISH_POSTOP_FLOAT(bf16obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(f32of32)
|
||||
GEN_SWISH_POSTOP_FLOAT(f32os32)
|
||||
GEN_SWISH_POSTOP_FLOAT(f32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(f32os8)
|
||||
GEN_SWISH_POSTOP_FLOAT(f32ou8)
|
||||
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(bf16obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32os32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32os8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(f32ou8)
|
||||
|
||||
static inline float eltwise_ops_accuracy_check_downscale_bf16of32
|
||||
(
|
||||
@@ -174,7 +214,7 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
@@ -193,15 +233,140 @@ static inline float eltwise_ops_accuracy_check_downscale_f32of32
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
static inline float eltwise_ops_accuracy_check_downscale_f32os32
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
|
||||
float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
|
||||
float out_temp_accum = ( temp_accum *
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
|
||||
zp_float );
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
static inline float eltwise_ops_accuracy_check_downscale_f32os8
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
|
||||
float out_temp_accum = \
|
||||
( float )min( \
|
||||
max( nearbyintf( ( float )( temp_accum ) * \
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
|
||||
*( ( float* )( post_op->sum )->zero_point + j_zp ), \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
static inline float eltwise_ops_accuracy_check_downscale_f32ou8
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
|
||||
float out_temp_accum = \
|
||||
( float )min( \
|
||||
max( nearbyintf( ( float )( temp_accum ) * \
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
|
||||
*( ( float* )( post_op->sum )->zero_point + j_zp ), \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
static inline float eltwise_ops_accuracy_check_downscale_f32obf16
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
|
||||
float zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
|
||||
float out_temp_accum = ( temp_accum *
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
|
||||
zp_float );
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32ou8)
|
||||
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bf16obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32ou8)
|
||||
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,float)
|
||||
|
||||
#define GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(A_type,B_type,ACCUM_type,LP_SFX) \
|
||||
void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
@@ -388,7 +553,7 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
dim_t scl_fctr_len = ( post_op->matrix_mul )->scale_factor_len; \
|
||||
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,LP_SFX) \
|
||||
( ( post_op->matrix_mul )->matrix, i, \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_add )->stor_type ); \
|
||||
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op->matrix_mul )->stor_type ); \
|
||||
} \
|
||||
else \
|
||||
{} \
|
||||
@@ -399,13 +564,14 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
( \
|
||||
&out_temp_accum, &temp_accum \
|
||||
); \
|
||||
\
|
||||
if ( ( ( *( b + ( rs_b * i ) + ( cs_b * j ) ) - out_temp_accum ) > 1.0E-5 ) || \
|
||||
( (out_temp_accum - *( b + ( rs_b * i ) + ( cs_b * j ) ) ) > 1.0E-5 ) ) \
|
||||
{ \
|
||||
float comp_float, ref_float; \
|
||||
\
|
||||
float comp_float, ref_float; \
|
||||
GEN_FUNC_NAME(B_type,_to_float)(*( b + ( rs_b * i ) + ( cs_b * j ) ), &comp_float); \
|
||||
GEN_FUNC_NAME(B_type,_to_float)(out_temp_accum, &ref_float); \
|
||||
\
|
||||
if ( ( ( comp_float - ref_float ) > 1.0E-5 ) || \
|
||||
( ( ref_float - comp_float ) > 1.0E-5 ) ) \
|
||||
{ \
|
||||
if ( fout ) \
|
||||
{ \
|
||||
fprintf( fout, "%s Failure input m: %ld, n: %ld," \
|
||||
@@ -427,6 +593,10 @@ cleanup_acc: \
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,float,float,bf16of32)
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,bf16obf16)
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,float,float,f32of32)
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,bfloat16,float,f32obf16)
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,int32_t,float,f32os32)
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,int8_t,float,f32os8)
|
||||
GEN_ELTWISE_OPS_ACC_CHK_DRV_FUNC(float,uint8_t,float,f32ou8)
|
||||
|
||||
#define GEN_ELTWISE_OPS_BENCH_DRV_FUNC(A_type,B_type,LP_SFX) \
|
||||
void eltwise_ops_bench_driver_ ## LP_SFX \
|
||||
@@ -471,544 +641,19 @@ void eltwise_ops_bench_driver_ ## LP_SFX \
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,float,bf16of32)
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(bfloat16,bfloat16,bf16obf16)
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,float,f32of32)
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,bfloat16,f32obf16)
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,int32_t,f32os32)
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,int8_t,f32os8)
|
||||
GEN_ELTWISE_OPS_BENCH_DRV_FUNC(float,uint8_t,f32ou8)
|
||||
|
||||
#define GEN_ELTWISE_OPS_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BLAS_SFX) \
|
||||
static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
char* post_ops_str, \
|
||||
char stor_order \
|
||||
) \
|
||||
{ \
|
||||
if ( ( ( post_ops_str == NULL ) || \
|
||||
( strcmp( post_ops_str, "none" ) == 0 ) ) && \
|
||||
( global_dscale_out == 'n' ) ) \
|
||||
{ \
|
||||
return NULL; \
|
||||
} \
|
||||
\
|
||||
aocl_post_op* post_ops = NULL; \
|
||||
post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \
|
||||
\
|
||||
if ( post_ops == NULL ) \
|
||||
{ \
|
||||
return NULL; \
|
||||
} \
|
||||
post_ops->eltwise = NULL; \
|
||||
\
|
||||
/* Only supporting 8 post ops at max for now.*/ \
|
||||
dim_t max_post_ops_seq_length = 8; \
|
||||
post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \
|
||||
malloc \
|
||||
( \
|
||||
max_post_ops_seq_length * \
|
||||
sizeof( AOCL_POST_OP_TYPE ) \
|
||||
); \
|
||||
\
|
||||
if ( post_ops->seq_vector == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
\
|
||||
/* Parse post ops list.*/ \
|
||||
dim_t cur_op_index = 0; \
|
||||
/* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \
|
||||
\
|
||||
/* Bench limitation: can only support 1 bias, but LPGEMM can support
|
||||
* multiple scale post-ops. */ \
|
||||
post_ops->bias = NULL; \
|
||||
post_ops->bias = malloc( sizeof( aocl_post_op_bias ) ); \
|
||||
if ( post_ops->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->bias )->bias = NULL; \
|
||||
\
|
||||
/* Bench limitation: can only support 1 scale, but LPGEMM can support
|
||||
* multiple scale post-ops. */ \
|
||||
post_ops->sum = NULL; \
|
||||
post_ops->sum = malloc( sizeof( aocl_post_op_sum ) ); \
|
||||
if ( post_ops->sum == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->sum )->scale_factor = NULL; \
|
||||
( post_ops->sum )->buff = NULL; \
|
||||
( post_ops->sum )->zero_point = NULL; \
|
||||
( post_ops->sum )->scale_factor_len = 0; \
|
||||
( post_ops->sum )->zero_point_len = 0; \
|
||||
\
|
||||
/* Bench limitation: can only support 1 matrix add, but LPGEMM can support
|
||||
* multiple matrix add post-ops. */ \
|
||||
post_ops->matrix_add = NULL; \
|
||||
post_ops->matrix_add = malloc( sizeof( aocl_post_op_matrix_add ) ); \
|
||||
if ( post_ops->matrix_add == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->matrix_add )->matrix = NULL; \
|
||||
( post_ops->matrix_add )->ldm = 0; \
|
||||
\
|
||||
/* Bench limitation: can only support 1 matrix mul, but LPGEMM can support
|
||||
* multiple matrix mul post-ops. */ \
|
||||
post_ops->matrix_mul = NULL; \
|
||||
post_ops->matrix_mul = malloc( sizeof( aocl_post_op_matrix_mul ) ); \
|
||||
if ( post_ops->matrix_mul == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->matrix_mul )->matrix = NULL; \
|
||||
( post_ops->matrix_mul )->ldm = 0; \
|
||||
\
|
||||
bool is_bias = FALSE; \
|
||||
bool is_relu = FALSE; \
|
||||
bool is_param_relu = FALSE; \
|
||||
bool is_gelu_tanh = FALSE; \
|
||||
bool is_gelu_erf = FALSE; \
|
||||
bool is_swish = FALSE; \
|
||||
bool is_clip = FALSE; \
|
||||
bool is_scalar_scale = FALSE; \
|
||||
bool is_scalar_zp = FALSE; \
|
||||
bool is_matrix_add = FALSE; \
|
||||
bool is_matrix_mul = FALSE; \
|
||||
bool is_tanh = FALSE; \
|
||||
bool is_sigmoid = FALSE; \
|
||||
bool is_bias_stor_type = FALSE; \
|
||||
dim_t activator_idx = 0; \
|
||||
dim_t clip_idx = 0; \
|
||||
char * bias_stor_type = ""; \
|
||||
\
|
||||
/* Post-Ops string parser. */ \
|
||||
dim_t num_eltwise = 0; \
|
||||
if ( strcmp( post_ops_str, "none" ) != 0 ) \
|
||||
{ \
|
||||
char* ops_tok = strtok(post_ops_str, ", =" ); \
|
||||
\
|
||||
/* Ensure only one activator is used as an eltwise post-op.*/ \
|
||||
bool is_activator_set = FALSE; \
|
||||
while ( ops_tok ) \
|
||||
{ \
|
||||
str_tolower( ops_tok ); \
|
||||
if ( strcmp( ops_tok, "bias" ) == 0 ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = BIAS; \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
|
||||
{ \
|
||||
is_bias_stor_type = FALSE; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_bias_stor_type = TRUE; \
|
||||
bias_stor_type = "F32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
|
||||
{ \
|
||||
is_bias_stor_type = TRUE; \
|
||||
bias_stor_type = "BF16"; \
|
||||
} \
|
||||
is_bias = TRUE; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "relu" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_relu = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "prelu" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_param_relu = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "swish" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_swish = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "gelu_tanh" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_gelu_tanh = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "gelu_erf" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_gelu_erf = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "clip" ) == 0 ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_clip = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
clip_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "tanh" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_tanh = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "sigmoid" ) == 0 ) && \
|
||||
( is_activator_set == FALSE ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = ELTWISE; \
|
||||
is_sigmoid = TRUE; \
|
||||
is_activator_set = TRUE; \
|
||||
num_eltwise += 1; \
|
||||
activator_idx = cur_op_index; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "scale" ) == 0 ) \
|
||||
{ \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
str_tolower( ops_tok ); \
|
||||
if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \
|
||||
( strcmp( ops_tok, "s" ) == 0 ) ) \
|
||||
{ \
|
||||
is_scalar_scale = TRUE; \
|
||||
} \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "zp" ) == 0 ) \
|
||||
{ \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
str_tolower( ops_tok ); \
|
||||
if ( ( strcmp( ops_tok, "scalar" ) == 0 ) || \
|
||||
( strcmp( ops_tok, "s" ) == 0 ) ) \
|
||||
{ \
|
||||
is_scalar_zp = TRUE; \
|
||||
} \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "matrix_add" ) == 0 ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = MATRIX_ADD; \
|
||||
is_matrix_add = TRUE; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "matrix_mul" ) == 0 ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = MATRIX_MUL; \
|
||||
is_matrix_mul = TRUE; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
\
|
||||
ops_tok = strtok( NULL, ", =" ); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( is_bias == TRUE ) \
|
||||
{ \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
( post_ops->bias )->bias = malloc( n * sizeof( C_type ) ); \
|
||||
if ( ( post_ops->bias )->bias == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
if(is_bias_stor_type == TRUE) \
|
||||
{ \
|
||||
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = NULLTYPE ; \
|
||||
} \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,C_DSCALE_type)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,C_type)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( num_eltwise > 0 ) \
|
||||
{ \
|
||||
if ( num_eltwise > 1 ) \
|
||||
{ \
|
||||
if ( activator_idx < clip_idx ) \
|
||||
{ \
|
||||
activator_idx = 0; \
|
||||
clip_idx = 1; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
activator_idx = 1; \
|
||||
clip_idx = 0; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
activator_idx = 0; \
|
||||
clip_idx = 0; \
|
||||
} \
|
||||
\
|
||||
post_ops->num_eltwise = num_eltwise; \
|
||||
post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \
|
||||
if ( post_ops->eltwise == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
\
|
||||
/* Only one of relu, prelu, swish, gelu_tanh, gelu_erf allowed as
|
||||
* an activator. */ \
|
||||
if ( is_relu == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \
|
||||
} \
|
||||
else if ( is_param_relu == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \
|
||||
} \
|
||||
if ( is_swish == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \
|
||||
if ( ( post_ops->eltwise + activator_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )2; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = SWISH; \
|
||||
} \
|
||||
else if ( is_gelu_tanh == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \
|
||||
} \
|
||||
else if ( is_gelu_erf == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \
|
||||
} \
|
||||
else if ( is_tanh == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = TANH; \
|
||||
} \
|
||||
if ( is_sigmoid == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + activator_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + activator_idx )->algo.algo_type = SIGMOID; \
|
||||
} \
|
||||
if ( is_clip == TRUE ) \
|
||||
{ \
|
||||
( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \
|
||||
( post_ops->eltwise + clip_idx )->scale_factor = NULL; \
|
||||
( post_ops->eltwise + clip_idx )->algo.alpha = NULL; \
|
||||
( post_ops->eltwise + clip_idx )->algo.beta = NULL; \
|
||||
( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( DSCALE_type ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.alpha == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( DSCALE_type ) ); \
|
||||
if ( ( post_ops->eltwise + clip_idx )->algo.beta == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
*( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( DSCALE_type ) ( -64 ); \
|
||||
*( ( DSCALE_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( DSCALE_type ) ( 23 ); \
|
||||
( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = SCALE; \
|
||||
cur_op_index++; \
|
||||
\
|
||||
( post_ops->sum )->is_power_of_2 = FALSE; \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
dim_t n_scale = n; \
|
||||
if ( is_scalar_scale == TRUE ) \
|
||||
{ \
|
||||
n_scale = 1; \
|
||||
} \
|
||||
\
|
||||
dim_t n_zp = n; \
|
||||
if ( is_scalar_zp == TRUE ) \
|
||||
{ \
|
||||
n_zp = 1; \
|
||||
} \
|
||||
\
|
||||
/* Allocate scale buffer, return early if alloc fails.*/ \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
( post_ops->sum )->zero_point = malloc( n_zp * sizeof( C_DSCALE_type ) ); \
|
||||
if ( ( post_ops->sum )->zero_point == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
\
|
||||
/* Fill scale factor and zero points.*/ \
|
||||
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
|
||||
for ( dim_t i = 0; i < n_scale; ++i ) \
|
||||
{ \
|
||||
temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \
|
||||
} \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
\
|
||||
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( temp_dzero_point_ptr, n_zp ); \
|
||||
( post_ops->sum )->zero_point_len = n_zp; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( is_matrix_add == TRUE ) \
|
||||
{ \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
dim_t ele_dsize = 0; \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_DSCALE_type ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_type ); \
|
||||
} \
|
||||
( post_ops->matrix_add )->matrix = malloc( m * n * ele_dsize ); \
|
||||
if ( ( post_ops->matrix_add )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_add )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_add )->ldm = m; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->matrix_add )->ldm = n; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( is_matrix_mul == TRUE ) \
|
||||
{ \
|
||||
/* Allocate bias buffer, return early if alloc fails.*/ \
|
||||
dim_t ele_dsize = 0; \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_DSCALE_type ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
ele_dsize = sizeof( C_type ); \
|
||||
} \
|
||||
( post_ops->matrix_mul )->matrix = malloc( m * n * ele_dsize ); \
|
||||
if ( ( post_ops->matrix_mul )->matrix == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
GEN_FUNC_NAME(fill_array_,C_type)( ( post_ops->matrix_mul )->matrix, ( m * n ) ); \
|
||||
} \
|
||||
if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->ldm = m; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->matrix_mul )->ldm = n; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
post_ops->seq_length = cur_op_index; \
|
||||
\
|
||||
post_ops->pre_ops = NULL; \
|
||||
\
|
||||
return post_ops; \
|
||||
\
|
||||
err_handler: \
|
||||
lpgemm_destroy_post_ops_struct( post_ops ); \
|
||||
return NULL; \
|
||||
} \
|
||||
|
||||
GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,float,float,bf16of32)
|
||||
GEN_ELTWISE_OPS_POST_OPS_CREATOR(bfloat16,bfloat16,float,bf16obf16)
|
||||
GEN_ELTWISE_OPS_POST_OPS_CREATOR(float,float,float,f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,float,bf16of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,float,bf16obf16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,int32_t,float,float,f32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,int8_t,float,float,f32os8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,uint8_t,float,float,f32ou8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(float,bfloat16,float,float,f32obf16)
|
||||
|
||||
#define GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(A_type, B_type, LP_SFX) \
|
||||
void eltwise_ops_bench_main_ ## LP_SFX \
|
||||
@@ -1059,7 +704,7 @@ void eltwise_ops_bench_main_ ## LP_SFX \
|
||||
( strcmp( post_ops_str, "none" ) != 0 ) ) || \
|
||||
( global_dscale_out == 'y' ) ) \
|
||||
{ \
|
||||
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, post_ops_str, stor_order ); \
|
||||
post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,LP_SFX)( m, n, 0, post_ops_str, stor_order ); \
|
||||
if ( post_op == NULL ) \
|
||||
{ \
|
||||
printf(" post op struct allocation failure, returning.\n"); \
|
||||
@@ -1098,6 +743,10 @@ void eltwise_ops_bench_main_ ## LP_SFX \
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,float,bf16of32)
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(bfloat16,bfloat16,bf16obf16)
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,float,f32of32)
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,bfloat16,f32obf16)
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,int32_t,f32os32)
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,int8_t,f32os8)
|
||||
GEN_ELTWISE_OPS_BENCH_MAIN_FUNC(float,uint8_t,f32ou8)
|
||||
|
||||
int main( int argc, char** argv )
|
||||
{
|
||||
@@ -1293,6 +942,60 @@ int main( int argc, char** argv )
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( eltwise_ops_type_str, "f32obf16" ) == 0 ) ||
|
||||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32obf16)
|
||||
(
|
||||
fout, stor_order, transa, transb,
|
||||
m, n, stride_a, stride_b,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( eltwise_ops_type_str, "f32os32" ) == 0 ) ||
|
||||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32os32)
|
||||
(
|
||||
fout, stor_order, transa, transb,
|
||||
m, n, stride_a, stride_b,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( eltwise_ops_type_str, "f32os8" ) == 0 ) ||
|
||||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
DSCALE_CLIP_MIN = -128;
|
||||
DSCALE_CLIP_MAX = +127;
|
||||
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32os8)
|
||||
(
|
||||
fout, stor_order, transa, transb,
|
||||
m, n, stride_a, stride_b,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( eltwise_ops_type_str, "f32ou8" ) == 0 ) ||
|
||||
( strcmp( eltwise_ops_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
DSCALE_CLIP_MIN = 0;
|
||||
DSCALE_CLIP_MAX = +255;
|
||||
GEN_FUNC_NAME(eltwise_ops_bench_main_, f32ou8)
|
||||
(
|
||||
fout, stor_order, transa, transb,
|
||||
m, n, stride_a, stride_b,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -83,6 +83,92 @@
|
||||
zmm2 = _mm512_mul_ps(zmm2,alpha); \
|
||||
zmm3 = _mm512_mul_ps(zmm3,alpha);
|
||||
|
||||
// BF16 bias helper macros.
|
||||
#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \
|
||||
scr = ( __m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_loadu_epi16 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// F32 bias helper macros.
|
||||
#define S32_F32_BIAS_LOAD(scr,mask,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_maskz_loadu_epi32 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// S8 bias helper macros.
|
||||
#define S8_F32_BIAS_LOAD(scr,mask,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepi8_epi32 \
|
||||
( \
|
||||
_mm_maskz_loadu_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( int8_t* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// BF16 bias helper macros.
|
||||
#define BF16_F32_BIAS_BCAST(scr,mask,m_ind) \
|
||||
scr = ( __m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_set1_epi16 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( bfloat16* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// F32 bias helper macros.
|
||||
#define S32_F32_BIAS_BCAST(scr,mask,m_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_maskz_set1_epi32 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( int32_t* ) post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// S8 bias helper macros.
|
||||
#define S8_F32_BIAS_BCAST(scr,mask,m_ind) \
|
||||
scr = _mm512_cvtepi32_ps \
|
||||
( \
|
||||
_mm512_cvtepi8_epi32 \
|
||||
( \
|
||||
_mm_maskz_set1_epi8 \
|
||||
( \
|
||||
( mask ), \
|
||||
*( ( ( int8_t* )post_ops_list_temp->op_args1 ) + \
|
||||
post_ops_attr.post_op_c_i + m_ind ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// Matrix Add post-ops helper macros
|
||||
#define F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1) \
|
||||
zmm ## r_ind0 = _mm512_add_ps( scr0, zmm ## r_ind0 ); \
|
||||
@@ -99,6 +185,146 @@
|
||||
zmm ## r_ind2 = _mm512_add_ps( scr2, zmm ## r_ind2 ); \
|
||||
zmm ## r_ind3 = _mm512_add_ps( scr3, zmm ## r_ind3 ); \
|
||||
|
||||
//BF16 matrix_add helper macros.
|
||||
#define BF16_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
scr = (__m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_loadu_epi16 \
|
||||
( \
|
||||
mask, \
|
||||
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
scr = _mm512_mul_ps( scr, scl_fct ); \
|
||||
|
||||
#define BF16_F32_MATRIX_ADD_2COL(scr0,scr1, \
|
||||
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
|
||||
|
||||
#define BF16_F32_MATRIX_ADD_3COL(scr0,scr1,scr2, \
|
||||
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
|
||||
|
||||
#define BF16_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define BF16_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
BF16_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
// S8 matrix_add helper macros.
|
||||
#define S8_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps( \
|
||||
_mm512_cvtepi8_epi32 \
|
||||
( \
|
||||
_mm_maskz_loadu_epi8 \
|
||||
( \
|
||||
mask, \
|
||||
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
) \
|
||||
); \
|
||||
scr = _mm512_mul_round_ps \
|
||||
( \
|
||||
( scr ), scl_fct, \
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
|
||||
); \
|
||||
|
||||
#define S8_F32_MATRIX_ADD_2COL(scr0,scr1, \
|
||||
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
|
||||
|
||||
#define S8_F32_MATRIX_ADD_3COL(scr0,scr1,scr2, \
|
||||
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
|
||||
|
||||
#define S8_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define S8_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S8_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
// S32 matrix_add helper macros.
|
||||
#define S32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
scr = _mm512_cvtepi32_ps ( \
|
||||
_mm512_maskz_loadu_epi32 \
|
||||
( \
|
||||
mask, \
|
||||
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
); \
|
||||
scr = _mm512_mul_round_ps \
|
||||
( \
|
||||
( scr ), scl_fct, \
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
|
||||
); \
|
||||
|
||||
#define S32_F32_MATRIX_ADD_2COL(scr0,scr1, \
|
||||
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
F32_MATRIX_ADD_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
|
||||
|
||||
#define S32_F32_MATRIX_ADD_3COL(scr0,scr1,scr2, \
|
||||
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
F32_MATRIX_ADD_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
|
||||
|
||||
#define S32_F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define S32_F32_MATRIX_ADD_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S32_F32_MATRIX_ADD_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_ADD_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
//F32 matrix_add helper macros.
|
||||
#define F32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
scr = _mm512_maskz_loadu_ps \
|
||||
( \
|
||||
@@ -153,6 +379,105 @@
|
||||
zmm ## r_ind2 = _mm512_mul_ps( scr2, zmm ## r_ind2 ); \
|
||||
zmm ## r_ind3 = _mm512_mul_ps( scr3, zmm ## r_ind3 ); \
|
||||
|
||||
#define BF16_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
BF16_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
|
||||
|
||||
#define BF16_F32_MATRIX_MUL_2COL(scr0,scr1, \
|
||||
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
|
||||
|
||||
#define BF16_F32_MATRIX_MUL_3COL(scr0,scr1,scr2, \
|
||||
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
|
||||
|
||||
#define BF16_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,scl_fct0, \
|
||||
scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define BF16_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
BF16_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
|
||||
#define S8_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
S8_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
|
||||
|
||||
#define S8_F32_MATRIX_MUL_2COL(scr0,scr1, \
|
||||
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
|
||||
|
||||
#define S8_F32_MATRIX_MUL_3COL(scr0,scr1,scr2, \
|
||||
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
|
||||
|
||||
#define S8_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,scl_fct0, \
|
||||
scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define S8_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S8_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define S32_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
S32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
|
||||
|
||||
|
||||
#define S32_F32_MATRIX_MUL_2COL(scr0,scr1, \
|
||||
scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
F32_MATRIX_MUL_2COL(scr0,scr1,m_ind,r_ind0,r_ind1); \
|
||||
|
||||
#define S32_F32_MATRIX_MUL_3COL(scr0,scr1,scr2, \
|
||||
scl_fct0,scl_fct1,scl_fct2,m_ind,r_ind0,r_ind1,r_ind2) \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
F32_MATRIX_MUL_3COL(scr0,scr1,scr2,m_ind,r_ind0,r_ind1,r_ind2); \
|
||||
|
||||
#define S32_F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,scl_fct0, \
|
||||
scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
#define S32_F32_MATRIX_MUL_4COL_MASK(k0,k1,k2,k3,r_ind0,r_ind1,r_ind2,r_ind3, \
|
||||
scl_fct0,scl_fct1,scl_fct2,scl_fct3,scr0,scr1,scr2,scr3,m_ind) \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr0,scl_fct0,m_ind,0); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr1,scl_fct1,m_ind,1); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr2,scl_fct2,m_ind,2); \
|
||||
S32_F32_MATRIX_MUL_LOAD(_cvtu32_mask16( 0xFFFF ),scr3,scl_fct3,m_ind,3); \
|
||||
F32_MATRIX_MUL_4COL(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3); \
|
||||
|
||||
|
||||
#define F32_F32_MATRIX_MUL_LOAD(mask,scr,scl_fct,m_ind,n_ind) \
|
||||
F32_F32_MATRIX_ADD_LOAD(mask,scr,scl_fct,m_ind,n_ind); \
|
||||
|
||||
@@ -190,6 +515,38 @@
|
||||
reg = _mm512_mul_ps( reg, selector ); \
|
||||
reg = _mm512_add_ps( reg, zero_point ); \
|
||||
|
||||
// Downscale store bf16 macro
|
||||
#define CVT_STORE_F32_BF16_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
|
||||
_mm256_mask_storeu_epi16 \
|
||||
( \
|
||||
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
|
||||
mask, (__m256i) _mm512_cvtneps_pbh( reg ) \
|
||||
) \
|
||||
|
||||
|
||||
// Downscale store s8 macro
|
||||
#define CVT_STORE_F32_S8_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
|
||||
_mm512_mask_cvtsepi32_storeu_epi8 \
|
||||
( \
|
||||
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
|
||||
mask, _mm512_cvtps_epi32( reg ) \
|
||||
) \
|
||||
|
||||
// Downscale store u8 macro
|
||||
#define CVT_STORE_F32_U8_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
|
||||
_mm512_mask_cvtusepi32_storeu_epi8 \
|
||||
( \
|
||||
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
|
||||
mask, _mm512_cvtps_epu32( _mm512_max_ps( reg, _mm512_set1_ps( 0 ) ) ) \
|
||||
) \
|
||||
|
||||
// Downscale store f32 macro
|
||||
#define CVT_STORE_F32_S32_POST_OPS_MASK(reg,mask,m_ind,n_ind) \
|
||||
_mm512_mask_storeu_epi32 \
|
||||
( \
|
||||
b_q + ( rs_b * ( ir + m_ind ) ) + ( cs_b * ( jr + n_ind ) ), \
|
||||
mask, _mm512_cvtps_epi32 ( reg ) \
|
||||
) \
|
||||
|
||||
/*x_tanh = tanhf(x_tanh) */ \
|
||||
#define TANH_F32S_AVX512(x_tanh, r, r2, x, z, dn, q) \
|
||||
|
||||
Reference in New Issue
Block a user