Updated all post-ops in s8s8s32 API to operate in float precision

Description:

1. Changed all post-ops in s8s8s32o<s32|s8|u8|f32|bf16> to operate
   on float data: the s32 accumulator registers are converted to
   float at the end of the k loop, and every post-op is then applied
   in f32 precision.

2. Added the s8s8s32ou8 API, which reuses the s8s8s32os32 kernels
   but stores the output in u8.

AMD-Internal - SWLCSG-3366

Change-Id: Iadfd9bfb98fc3bf21e675acb95553fe967b806a6
This commit is contained in:
Deepak Negi
2025-02-06 04:31:13 +05:30
committed by Nallani Bhaskar
parent 2ece628a4d
commit 2e687d8847
11 changed files with 11525 additions and 7857 deletions

View File

@@ -140,12 +140,13 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16);
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16);
AOCL_GEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32);
AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8);
AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
AOCL_GEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);
AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16);
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8);

View File

@@ -0,0 +1,205 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
// aocl_gemm_s8s8s32ou8: int8 x int8 GEMM with s32 accumulation,
// downscaled/stored as uint8 output. Reuses the s8s8s32os32 5-loop
// path and kernels; the U8 tag passed to the thread decorator selects
// u8 storage of the final result.
//
// Returns nothing; all failures are reported via bli_print_msg and
// exit through the err_hndl label so the logger is always stopped.
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
{
    LPGEMM_START_LOGGER();
    LPGEMM_WRITE_LOGGER \
    (
      "s8s8s32ou8", \
      order, transa, transb, \
      m, n, k, \
      ( ( float ) alpha ), \
      lda, mem_format_a, \
      ldb, mem_format_b, \
      ( ( float ) beta ), \
      ldc, post_op_unparsed \
    );

    trans_t blis_transa;
    trans_t blis_transb;

    // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
    if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
    {
        bli_print_msg( " AVX512_VNNI ISA not supported by processor, "
                       "cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
        goto err_hndl;
    }

    /* Initialize BLIS. */
    bli_init_auto();

    // Set MC, NC, KC, NR, MR.
    aocl_lpgemm_init_global_cntx();

    // Check for validity of params.
    int err_no = 0;
    AOCL_GEMM_CHECK
    (
      "s8s8s32ou8",
      order, transa, transb,
      m, n, k,
      a, lda, mem_format_a,
      b, ldb, mem_format_b,
      c, ldc,
      err_no
    );
    if ( err_no != 0 )
    {
        goto err_hndl;
    }

    /* Map BLAS chars to their corresponding BLIS enumerated type value. */
    bli_param_map_netlib_to_blis_trans( transa, &blis_transa );
    bli_param_map_netlib_to_blis_trans( transb, &blis_transb );

    // Column major support is disabled for the s8s8s32 APIs; only row
    // major storage is handled by the kernels used on this path.
    if ( ( order != 'r' ) && ( order != 'R' ) )
    {
        bli_print_msg( "Column major inputs not supported.",
                       __FILE__, __LINE__ );
        goto err_hndl;
    }

    // Row-major strides for A; transpose flips to unit row stride.
    inc_t rs_a = lda;
    inc_t cs_a = 1;
    if ( bli_is_trans( blis_transa ) )
    {
        rs_a = 1;
        cs_a = lda;
    }

    inc_t rs_b = ldb;
    inc_t cs_b = 1;
    if ( bli_is_trans( blis_transb ) )
    {
        rs_b = 1;
        cs_b = ldb;
    }

    const inc_t rs_c = ldc;
    const inc_t cs_c = 1;

    AOCL_MEMORY_TAG mtag_a;
    AOCL_MEMORY_TAG mtag_b;
    bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a );
    bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b );

    // Reorder is not supported for A matrix.
    if ( mtag_a == REORDERED )
    {
        bli_print_msg( " Reordering of A matrix is not supported "
                       "in row major case.",
                       __FILE__, __LINE__ );
        goto err_hndl;
    }

    // From 5-loop function point of view, B matrix needs to be packed in
    // a certain format in order to be loaded and used by the int8 VNNI
    // kernels. As such the mtag_b always needs to be either packed or
    // reordered. B matrix as it is (unpacked) cannot be used, and the
    // mtag_b is set to packed to enable runtime packing.
    if ( mtag_b == UNPACKED )
    {
        mtag_b = PACK;
    }

    // From 5-loop function point of view, A matrix when in column major
    // storage needs to be packed to row-major storage as kernel expects
    // A matrix to be in row-major format.
    if ( bli_is_trans( blis_transa ) )
    {
        mtag_a = PACK;
    }

    // Convert post op struct to post op linked list format.
    lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
    err_t err = lpgemm_translate_to_post_ops_list
    (
      post_op_unparsed, post_op_list,
      ( void* )c, ( void* )( &order ),
      m, n
    );
    if ( err != BLIS_SUCCESS )
    {
        goto err_hndl;
    }

    // Initialize a local runtime with global settings if necessary. Note
    // that in the case that a runtime is passed in, we make a local copy.
    rntm_t rntm_g;
    bli_rntm_init_from_global( &rntm_g );
    bli_pba_rntm_set_pba( &rntm_g );

    // u8 output reuses the s32 context; the trailing U8 argument to the
    // decorator selects uint8 storage of C.
    lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );

#ifdef BLIS_ENABLE_OPENMP
    lpgemm_s8s8s32o32_openmp_thread_decorator
    (
      m, n, k,
      a, rs_a, cs_a, mtag_a,
      b, rs_b, cs_b, mtag_b,
      ( int32_t* )c, rs_c, cs_c,
      alpha, beta,
      &rntm_g, lcntx_g,
      post_op_list, U8
    );
#else
    lpgemm_s8s8s32o32_thread_decorator
    (
      m, n, k,
      a, rs_a, cs_a, mtag_a,
      b, rs_b, cs_b, mtag_b,
      ( int32_t* )c, rs_c, cs_c,
      alpha, beta,
      &rntm_g, lcntx_g,
      post_op_list, U8
    );
#endif

err_hndl:;
    LPGEMM_STOP_LOGGER();
}

View File

@@ -108,6 +108,7 @@ GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
@@ -204,6 +205,7 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
@@ -246,6 +248,7 @@ GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,u8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,u8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
@@ -339,6 +342,7 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
@@ -715,6 +719,7 @@ GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32of32)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os8)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32ou8)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os32)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32of32)
@@ -731,6 +736,7 @@ GEN_TANH_POSTOP_FLOAT(u8s8s32os32)
GEN_TANH_POSTOP_FLOAT(u8s8s32obf16)
GEN_TANH_POSTOP_FLOAT(u8s8s32of32)
GEN_TANH_POSTOP_FLOAT(s8s8s32os8)
GEN_TANH_POSTOP_FLOAT(s8s8s32ou8)
GEN_TANH_POSTOP_FLOAT(s8s8s32obf16)
GEN_TANH_POSTOP_FLOAT(s8s8s32of32)
GEN_TANH_POSTOP_FLOAT(s8s8s32os32)
@@ -747,6 +753,7 @@ GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os32)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32of32)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os8)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32ou8)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os32)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32of32)
@@ -763,6 +770,7 @@ GEN_SWISH_POSTOP_INT(float,u8s8s32os32)
GEN_SWISH_POSTOP_FLOAT(u8s8s32obf16)
GEN_SWISH_POSTOP_FLOAT(u8s8s32of32)
GEN_SWISH_POSTOP_INT(float,s8s8s32os8)
GEN_SWISH_POSTOP_INT(float,s8s8s32ou8)
GEN_SWISH_POSTOP_INT(float,s8s8s32os32)
GEN_SWISH_POSTOP_FLOAT(s8s8s32obf16)
GEN_SWISH_POSTOP_FLOAT(s8s8s32of32)
@@ -779,6 +787,7 @@ GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os32)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32obf16)
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32of32)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os8)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32ou8)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os32)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32obf16)
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32of32)
@@ -797,10 +806,11 @@ GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32of32)
@@ -813,10 +823,11 @@ GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16s4f32of32)
@@ -836,6 +847,7 @@ GEN_PRELU_POST_OP_VAL_INT(u8s8s32os8)
GEN_PRELU_POST_OP_VAL_INT(u8s8s32ou8)
GEN_PRELU_POST_OP_VAL_INT(u8s8s32os32)
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os8)
GEN_PRELU_POST_OP_VAL_INT(s8s8s32ou8)
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os32)
@@ -853,6 +865,7 @@ GEN_CLIP_POST_OP_VAL_INT(u8s8s32os8)
GEN_CLIP_POST_OP_VAL_INT(u8s8s32ou8)
GEN_CLIP_POST_OP_VAL_INT(u8s8s32os32)
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os8)
GEN_CLIP_POST_OP_VAL_INT(s8s8s32ou8)
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os32)
@@ -864,10 +877,11 @@ GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32ou8)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32obf16)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32of32)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32obf16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32of32)
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32ou8)
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32obf16)
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
@@ -1141,6 +1155,7 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,float,f32f32f32of32,f
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,int32_t,float,s8s8s32os8,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,uint8_t,int32_t,int32_t,float,s8s8s32ou8,s8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,int32_t,float,s8s8s32obf16,s8s8s32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,float,int32_t,float,int32_t,s8s8s32of32,s8s8s32of32)
@@ -1152,9 +1167,9 @@ GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,float,u8s8s32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,int32_t,s8s8s32os8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,s8s8s32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16bf16f32obf16)
@@ -1358,6 +1373,7 @@ GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,float,int32_t,u8s8s32of32,u8s8s32os32
GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
@@ -1398,9 +1414,10 @@ int main( int argc, char** argv )
" 3. u8s8s32os32 -d f32 = u8s8s32of32.\n" \
" 4. u8s8s32os32 -d bf16 = u8s8s32obf16.\n" \
" 5. s8s8s32os32 -d s8 = s8s8s32os8.\n" \
" 6. s8s8s32os32 -d f32 = s8s8s32of32.\n" \
" 7. s8s8s32os32 -d bf16 = s8s8s32obf16.\n" \
" 8. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
" 6. s8s8s32os32 -d u8 = s8s8s32ou8.\n" \
" 7. s8s8s32os32 -d f32 = s8s8s32of32.\n" \
" 8. s8s8s32os32 -d bf16 = s8s8s32obf16.\n" \
" 9. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
" Example: ./bench_lpgemm -m a -n 2 -o bias,relu -d bf16 -i input.txt\n" \
);
exit( 1 );
@@ -1748,6 +1765,21 @@ int main( int argc, char** argv )
post_ops_str_dest
);
}
if ( ( strcmp( gemm_type_str, "s8s8s32ou8" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
global_dscale_out = 'y';
global_pre_op = 'n';
DSCALE_CLIP_MIN = 0;
DSCALE_CLIP_MAX = +255;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32ou8)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest
);
}
if ( ( strcmp( gemm_type_str, "s8s8s32obf16" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -517,7 +517,6 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
zmm8 = _mm512_maskz_sub_epi32( k2, zmm8, zmm0 );
}
//Scale accumulated output with alpha
__m512i selector1 = _mm512_set1_epi32( alpha );
__m512i selector2 = _mm512_set1_epi32( beta );
@@ -536,6 +535,11 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
S8_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0,
selector1, selector2 )
}
else if ( post_ops_attr.c_stor_type == U8 )
{
U8_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0,
selector1, selector2 )
}
else if ( post_ops_attr.c_stor_type == BF16 )
{
BF16_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0,
@@ -549,15 +553,57 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
}
else
{
int8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
if ( post_ops_attr.c_stor_type == S8 )
{
ctemp[i] = *( ( int8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
int8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( ( int8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm512_cvtepi8_epi32
( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( ( uint8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm512_cvtepu8_epi32
( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) );
}
else if ( post_ops_attr.c_stor_type == BF16 )
{
bfloat16 ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( ( bfloat16* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm512_cvtps_epi32(
( __m512 )_mm512_sllv_epi32(
_mm512_cvtepi16_epi32(
_mm256_maskz_loadu_epi16( 0xFFFF, ctemp )
), _mm512_set1_epi32( 16 ) ) );
}
else if ( post_ops_attr.c_stor_type == F32 )
{
float ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( ( float* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm512_cvtps_epi32(
_mm512_maskz_loadu_ps( 0xFFFF, ctemp ) );
}
selector1 = _mm512_cvtepi8_epi32
( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) );
S32_BETA_FMA( zmm8, selector1, selector2 );
}
}
@@ -581,6 +627,9 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
}
}
__m512 acc_8 = _mm512_setzero_ps();
acc_8 = _mm512_cvtepi32_ps( zmm8 );
// Post Ops
lpgemm_post_op *post_ops_list_temp = post_op;
@@ -589,108 +638,122 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
POST_OPS_BIAS_6x64:
{
__m512 b0 = _mm512_setzero_ps();
if ( post_ops_list_temp->stor_type == BF16 )
{
selector1 =
_mm512_cvtps_epi32
(
( __m512 )( _mm512_sllv_epi32
(
_mm512_cvtepi16_epi32
(
_mm256_maskz_loadu_epi16
(
_cvtu32_mask16( 0xFFFF ),
b0 = (__m512)_mm512_sllv_epi32(
_mm512_cvtepi16_epi32(
_mm256_maskz_loadu_epi16(
_cvtu32_mask16( 0x0001 ),
( ( bfloat16* )post_ops_list_temp->op_args1 )
)
), _mm512_set1_epi32( 16 )
)
)
);
) ), _mm512_set1_epi32( 16 ) );
}
else if ( post_ops_list_temp->stor_type == S8 )
{
selector1 =
_mm512_cvtepi8_epi32
(
_mm_maskz_loadu_epi8
(
_cvtu32_mask16( 0xFFFF ),
( ( int8_t* )post_ops_list_temp->op_args1 )
)
);
b0 = _mm512_cvtepi32_ps(
_mm512_cvtepi8_epi32(
_mm_maskz_loadu_epi8(
_cvtu32_mask16( 0x0001 ),
( ( int8_t* )post_ops_list_temp->op_args1 )
) ) );
}
else if ( post_ops_list_temp->stor_type == F32 )
else if ( post_ops_list_temp->stor_type == S32 )
{
selector1 = _mm512_cvtps_epi32
(
( __m512 )_mm512_maskz_loadu_ps
(
_cvtu32_mask16( 0xFFFF ),
( ( float* ) post_ops_list_temp->op_args1 )
)
);
b0 = _mm512_cvtepi32_ps(
_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args1) ) );
}
else
{
selector1 =
_mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args1) );
b0 = _mm512_maskz_loadu_ps(
_cvtu32_mask16( 0x0001 ),
( ( float* ) post_ops_list_temp->op_args1 ) );
}
zmm8 = _mm512_add_epi32( selector1, zmm8 );
acc_8 = _mm512_add_ps( b0, acc_8 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU_6x64:
{
selector1 = _mm512_setzero_epi32();
__m512 zero = _mm512_setzero_ps();
zmm8 = _mm512_max_epi32( selector1, zmm8 );
acc_8 = _mm512_max_ps( zero, acc_8 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU_SCALE_6x64:
{
selector1 = _mm512_setzero_epi32();
selector2 =
_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) );
__m512 zero = _mm512_setzero_ps();
__m512 scale;
if ( ( post_ops_attr.c_stor_type == S32 ) ||
( post_ops_attr.c_stor_type == U8 ) ||
( post_ops_attr.c_stor_type == S8 ) )
{
scale = _mm512_cvtepi32_ps
( _mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ) );
}
else
{
scale = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__mmask16 relu_cmp_mask;
RELU_SCALE_OP_S32_AVX512(zmm8)
RELU_SCALE_OP_F32_AVX512(acc_8)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_TANH_6x64:
{
__m512 dn, z, x, r2, r, y, x_tanh;
GELU_TANH_S32_AVX512( zmm8, y, r, r2, x,
z, dn, x_tanh, selector1 )
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
GELU_TANH_F32_AVX512_DEF( acc_8, y, r, r2, x, z, dn, tmpout );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_ERF_6x64:
{
__m512 x, r, y, x_erf;
__m512 y, r, r2;
GELU_ERF_S32_AVX512( zmm8, y, r, x, x_erf )
GELU_ERF_F32_AVX512_DEF( acc_8, y, r, r2 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_CLIP_6x64:
{
__m512i min = _mm512_set1_epi32(
*( int32_t* )post_ops_list_temp->op_args2 );
__m512i max = _mm512_set1_epi32(
*( int32_t* )post_ops_list_temp->op_args3 );
__m512 min = _mm512_setzero_ps();
__m512 max = _mm512_setzero_ps();
CLIP_S32_AVX512( zmm8, min, max )
if ( ( post_ops_attr.c_stor_type == S32 ) ||
( post_ops_attr.c_stor_type == U8 ) ||
( post_ops_attr.c_stor_type == S8 ) )
{
min = _mm512_cvtepi32_ps
(_mm512_set1_epi32( *( int32_t* )post_ops_list_temp->op_args2 ));
max = _mm512_cvtepi32_ps
(_mm512_set1_epi32( *( int32_t* )post_ops_list_temp->op_args3 ));
}
else
{
min = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
max = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args3 ) );
}
CLIP_F32_AVX512( acc_8, min, max )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_DOWNSCALE_6x64:
{
selector1 = ( __m512i )_mm512_set1_ps(
__m512 scale0 = _mm512_setzero_ps();
scale0 = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->scale_factor ) );
// Need to ensure sse not used to avoid avx512 -> sse transition.
@@ -700,7 +763,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
zero_point0 = _mm_maskz_set1_epi8( 0xFFFF,
*( ( int8_t* )post_ops_list_temp->op_args1 ) );
CVT_MULRND_CVT32(zmm8, selector1, zero_point0 );
CVT_MULRND_F32(acc_8, scale0, zero_point0 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
@@ -715,6 +778,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
bool is_f32 = ( post_ops_list_temp->stor_type == F32 );
__m512 scl_fctr1 = _mm512_setzero_ps();
__m512 t0 = _mm512_setzero_ps();
// Even though different registers are used for scalar in column and
// row major case, all those registers will contain the same value.
@@ -742,22 +806,19 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
BF16_S32_MATRIX_ADD_LOAD( k2, selector1, scl_fctr1, 0, 0 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
);
BF16_F32_MATRIX_ADD_LOAD( k2, t0, scl_fctr1, 0, 0 );
acc_8 = _mm512_add_ps( t0, acc_8 );
}
else
{
int8_t ctemp[16];
bfloat16 ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm512_sllv_epi32
t0 = (__m512)_mm512_sllv_epi32
(
_mm512_cvtepi16_epi32
(
@@ -767,11 +828,8 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
)
), _mm512_set1_epi32( 16 )
);
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
);
t0 = _mm512_mul_ps( t0, scl_fctr1 );
acc_8 = _mm512_add_ps( t0, acc_8 );
}
}
else if ( is_f32 == TRUE )
@@ -780,27 +838,21 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
F32_S32_MATRIX_ADD_LOAD( k2, selector1, scl_fctr1, 0, 0 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
);
F32_ACC_MATRIX_ADD_LOAD( k2, t0, scl_fctr1, 0, 0 );
acc_8 = _mm512_add_ps( t0, acc_8 );
}
else
{
int8_t ctemp[16];
float ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = ( __m512i )_mm512_maskz_loadu_ps( k2, ctemp );
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
);
t0 = _mm512_maskz_loadu_ps( k2, ctemp );
t0 = _mm512_mul_ps( t0, scl_fctr1 );
acc_8 = _mm512_add_ps( t0, acc_8 );
}
}
else if ( is_s8 == TRUE )
@@ -809,8 +861,8 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
S8_S32_MATRIX_ADD_LOAD( k2, selector1, scl_fctr1, 0, 0 )
zmm8 = _mm512_add_epi32( selector1, zmm8 );
S8_F32_MATRIX_ADD_LOAD( k2, t0, scl_fctr1, 0, 0 )
acc_8 = _mm512_add_ps( t0, acc_8 );
}
else
{
@@ -821,16 +873,13 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm512_cvtepi8_epi32
( _mm_maskz_loadu_epi8( k2, ctemp ) );
selector1 = _mm512_cvtps_epi32(
_mm512_mul_round_ps
(
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
)
);
zmm8 = _mm512_add_epi32( selector1, zmm8 );
t0 = _mm512_cvtepi32_ps(
_mm512_cvtepi8_epi32(
_mm_maskz_loadu_epi8( k2, ctemp ) ) );
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
);
acc_8 = _mm512_add_ps( t0, acc_8 );
}
}
else
@@ -839,8 +888,8 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
S32_S32_MATRIX_ADD_LOAD(k2, selector1, scl_fctr1, 0, 0 );
zmm8 = _mm512_add_epi32( selector1, zmm8 );
S32_F32_MATRIX_ADD_LOAD(k2, t0, scl_fctr1, 0, 0 );
acc_8 = _mm512_add_ps( t0, acc_8 );
}
else
{
@@ -851,15 +900,12 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm512_maskz_loadu_epi32( k2, ctemp );
selector1 = _mm512_cvtps_epi32(
_mm512_mul_round_ps
(
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
)
);
zmm8 = _mm512_add_epi32( selector1, zmm8 );
t0 = _mm512_cvtepi32_ps(
_mm512_maskz_loadu_epi32( k2, ctemp ) );
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
);
acc_8 = _mm512_add_ps( t0, acc_8 );
}
}
@@ -876,6 +922,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
bool is_f32 = ( post_ops_list_temp->stor_type == F32 );
__m512 scl_fctr1 = _mm512_setzero_ps();
__m512 t0 = _mm512_setzero_ps();
// Even though different registers are used for scalar in column and
// row major case, all those registers will contain the same value.
@@ -903,38 +950,32 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
BF16_S32_MATRIX_MUL_LOAD( k2, selector1, scl_fctr1, 0, 0 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
);
BF16_F32_MATRIX_MUL_LOAD( k2, t0, scl_fctr1, 0, 0 );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
else
{
int8_t ctemp[16];
bfloat16 ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm512_sllv_epi32
(
_mm512_cvtepi16_epi32
(
_mm256_maskz_loadu_epi16
(
k2 , ctemp
)
), _mm512_set1_epi32( 16 )
);
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
);
t0 = (__m512)_mm512_sllv_epi32
(
_mm512_cvtepi16_epi32
(
_mm256_maskz_loadu_epi16
(
k2 , ctemp
)
), _mm512_set1_epi32( 16 )
);
t0 = _mm512_mul_ps( t0, scl_fctr1 );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
}
else if ( is_f32 == TRUE )
@@ -943,29 +984,23 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
F32_S32_MATRIX_MUL_LOAD( k2, selector1, scl_fctr1, 0, 0 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
);
F32_MATRIX_MUL_LOAD( k2, t0, scl_fctr1, 0, 0 );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
else
{
int8_t ctemp[16];
float ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = ( __m512i )_mm512_maskz_loadu_ps( k2, ctemp );
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
zmm8 =
_mm512_cvtps_epi32(
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
);
t0 = _mm512_maskz_loadu_ps( k2, ctemp );
t0 = _mm512_mul_ps( t0, scl_fctr1 );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
}
else if ( is_s8 == TRUE )
@@ -974,14 +1009,10 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
S8_S32_MATRIX_MUL_LOAD( k2, selector1, scl_fctr1, 0, 0 )
S8_F32_MATRIX_MUL_LOAD( k2, t0, scl_fctr1, 0, 0 )
// mul_epi32 works on 64 bit lengths, with mul done for lower 32 bits.
// We only need 32 bit mul to get 32 bit output, so using mul_ps.
zmm8 = _mm512_cvtps_epi32(
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
_mm512_cvtepi32_ps( zmm8 ) )
);
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
else
{
@@ -992,19 +1023,13 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm512_cvtepi8_epi32
( _mm_maskz_loadu_epi8( k2, ctemp ) );
selector1 = _mm512_cvtps_epi32(
_mm512_mul_round_ps
(
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
)
);
zmm8 = _mm512_cvtps_epi32(
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
_mm512_cvtepi32_ps( zmm8 ) )
);
t0 = _mm512_cvtepi32_ps(
_mm512_cvtepi8_epi32
( _mm_maskz_loadu_epi8( k2, ctemp ) ) );
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
}
else
@@ -1013,11 +1038,10 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
if( ldm == 1 )
{
S32_S32_MATRIX_MUL_LOAD(k2, selector1, scl_fctr1, 0, 0 );
zmm8 = _mm512_cvtps_epi32(
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
_mm512_cvtepi32_ps( zmm8 ) )
);
S32_F32_MATRIX_MUL_LOAD(k2, t0, scl_fctr1, 0, 0 );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
else
{
@@ -1028,18 +1052,12 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm512_maskz_loadu_epi32( k2, ctemp );
selector1 = _mm512_cvtps_epi32(
_mm512_mul_round_ps
(
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
)
);
zmm8 = _mm512_cvtps_epi32(
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
_mm512_cvtepi32_ps( zmm8 ) )
);
t0 = _mm512_cvtepi32_ps(
_mm512_maskz_loadu_epi32( k2, ctemp ) );
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
acc_8 = _mm512_mul_round_ps( t0, acc_8,
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
}
}
@@ -1047,29 +1065,44 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
}
POST_OPS_SWISH_6x64:
{
selector1 =
_mm512_set1_epi32( *( (int32_t*)post_ops_list_temp->op_args2 ) );
__m512 scale;
__m512 al = _mm512_cvtepi32_ps( selector1 );
if ( ( post_ops_attr.c_stor_type == S32 ) ||
( post_ops_attr.c_stor_type == U8 ) ||
( post_ops_attr.c_stor_type == S8 ) )
{
scale = _mm512_cvtepi32_ps
(_mm512_set1_epi32(
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
}
else
{
scale = _mm512_set1_ps(
*( ( float* )post_ops_list_temp->op_args2 ) );
}
__m512 fl_reg, al_in, r, r2, z, dn;
__m512 al_in, r, r2, z, dn;
__m512i temp;
SWISH_S32_AVX512( zmm8, fl_reg, al, al_in, r, r2, z, dn, selector2 );
SWISH_F32_AVX512_DEF( acc_8, scale, al_in, r, r2, z, dn, temp );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_TANH_6x64:
{
__m512 dn, z, x, r2, r, y;
TANH_S32_AVX512( zmm8, y, r, r2, x, z, dn, selector1 );
__m512 dn, z, x, r2, r;
__m512i q;
TANHF_AVX512( acc_8, r, r2, x, z, dn, q );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SIGMOID_6x64:
{
__m512 fl_reg, al_in, r, r2, z, dn;
__m512 al_in, r, r2, z, dn;
__m512i tmpout;
SIGMOID_S32_AVX512( zmm8, fl_reg, al_in, r, r2, z, dn, selector2 );
SIGMOID_F32_AVX512_DEF( acc_8, al_in, r, r2, z, dn, tmpout );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
@@ -1083,28 +1116,79 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
{
if ( post_ops_attr.c_stor_type == S8 )
{
CVT_STORE_S32_S8_MASK( zmm8, k2, 0, 0 );
CVT_STORE_F32_S8_MASK( k2, acc_8, 0, 0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
CVT_STORE_F32_U8_MASK( k2, acc_8, 0, 0 );
}
else if ( post_ops_attr.c_stor_type == BF16 )
{
CVT_STORE_S32_BF16_MASK( zmm8, k2, 0, 0 );
CVT_STORE_F32_BF16_MASK( k2, acc_8, 0, 0 );
}
else if ( post_ops_attr.c_stor_type == F32 )
{
CVT_STORE_S32_F32_MASK( zmm8, k2, 0, 0 );
STORE_F32_MASK( k2, acc_8, 0, 0 );
}
}
else
{
int8_t ctemp[16];
_mm512_mask_cvtsepi32_storeu_epi8 ( ctemp, k2, zmm8 );
for (dim_t i = 0; i < mr0; i++)
if ( post_ops_attr.c_stor_type == S8 )
{
*( ( int8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
int8_t ctemp[16];
_mm512_mask_cvtsepi32_storeu_epi8 ( ctemp, k2,
_mm512_cvtps_epi32( acc_8 ) );
for (dim_t i = 0; i < mr0; i++)
{
*( ( int8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t ctemp[16];
_mm512_mask_cvtusepi32_storeu_epi8 ( ctemp, k2,
_mm512_cvtps_epu32(
_mm512_max_ps( acc_8, _mm512_set1_ps( 0 ) )
) );
for (dim_t i = 0; i < mr0; i++)
{
*( ( uint8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
else if ( post_ops_attr.c_stor_type == BF16 )
{
bfloat16 ctemp[16];
_mm256_mask_storeu_epi16( ctemp, k2,
(__m256i)_mm512_cvtneps_pbh( acc_8 ) );
for (dim_t i = 0; i < mr0; i++)
{
*( ( bfloat16* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
else if ( post_ops_attr.c_stor_type == F32 )
{
float ctemp[16];
_mm512_mask_storeu_ps( ctemp, k2, acc_8 );
for (dim_t i = 0; i < mr0; i++)
{
*( ( float* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
}
}
@@ -1112,14 +1196,16 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
{
if(rs_c == 1)
{
_mm512_mask_storeu_epi32(c_use, k2, zmm8);
_mm512_mask_storeu_epi32(c_use, k2,
_mm512_cvtps_epi32( acc_8 ) );
}
else
{
// Store ZMM8 into ctemp buffer and store back
// element by element into output buffer at strides
int32_t ctemp[16];
_mm512_mask_storeu_epi32(ctemp, k2, zmm8);
_mm512_mask_storeu_epi32(ctemp, k2,
_mm512_cvtps_epi32( acc_8 ) );
for (dim_t i = 0; i < mr0; i++)
{
c_use[i * rs_c] = ctemp[i];

View File

@@ -612,29 +612,76 @@ POST_OPS_BIAS_6x64:
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
// c[0,0-15]
acc_00 = _mm512_add_ps( b0, acc_00 );
// c[0,16-31]
acc_01 = _mm512_add_ps( b1, acc_01 );
// c[0,32-47]
acc_02 = _mm512_add_ps( b2, acc_02 );
// c[0,48-63]
acc_03 = _mm512_add_ps( b3, acc_03 );
// c[1,0-15]
acc_10 = _mm512_add_ps( b0, acc_10 );
// c[1,16-31]
acc_11 = _mm512_add_ps( b1, acc_11 );
// c[1,32-47]
acc_12 = _mm512_add_ps( b2, acc_12 );
// c[1,48-63]
acc_13 = _mm512_add_ps( b3, acc_13 );
// c[2,0-15]
acc_20 = _mm512_add_ps( b0, acc_20 );
// c[2,16-31]
acc_21 = _mm512_add_ps( b1, acc_21 );
// c[2,32-47]
acc_22 = _mm512_add_ps( b2, acc_22 );
// c[2,48-63]
acc_23 = _mm512_add_ps( b3, acc_23 );
// c[3,0-15]
acc_30 = _mm512_add_ps( b0, acc_30 );
// c[3,16-31]
acc_31 = _mm512_add_ps( b1, acc_31 );
// c[3,32-47]
acc_32 = _mm512_add_ps( b2, acc_32 );
// c[3,48-63]
acc_33 = _mm512_add_ps( b3, acc_33 );
// c[4,0-15]
acc_40 = _mm512_add_ps( b0, acc_40 );
// c[4,16-31]
acc_41 = _mm512_add_ps( b1, acc_41 );
// c[4,32-47]
acc_42 = _mm512_add_ps( b2, acc_42 );
// c[4,48-63]
acc_43 = _mm512_add_ps( b3, acc_43 );
// c[5,0-15]
acc_50 = _mm512_add_ps( b0, acc_50 );
// c[5,16-31]
acc_51 = _mm512_add_ps( b1, acc_51 );
// c[5,32-47]
acc_52 = _mm512_add_ps( b2, acc_52 );
// c[5,48-63]
acc_53 = _mm512_add_ps( b3, acc_53 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -643,29 +690,76 @@ POST_OPS_RELU_6x64:
{
__m512 zero = _mm512_setzero_ps();
// c[0,0-15]
acc_00 = _mm512_max_ps( zero, acc_00 );
// c[0,16-31]
acc_01 = _mm512_max_ps( zero, acc_01 );
// c[0,32-47]
acc_02 = _mm512_max_ps( zero, acc_02 );
// c[0,48-63]
acc_03 = _mm512_max_ps( zero, acc_03 );
// c[1,0-15]
acc_10 = _mm512_max_ps( zero, acc_10 );
// c[1,16-31]
acc_11 = _mm512_max_ps( zero, acc_11 );
// c[1,32-47]
acc_12 = _mm512_max_ps( zero, acc_12 );
// c[1,48-63]
acc_13 = _mm512_max_ps( zero, acc_13 );
// c[2,0-15]
acc_20 = _mm512_max_ps( zero, acc_20 );
// c[2,16-31]
acc_21 = _mm512_max_ps( zero, acc_21 );
// c[2,32-47]
acc_22 = _mm512_max_ps( zero, acc_22 );
// c[2,48-63]
acc_23 = _mm512_max_ps( zero, acc_23 );
// c[3,0-15]
acc_30 = _mm512_max_ps( zero, acc_30 );
// c[3,16-31]
acc_31 = _mm512_max_ps( zero, acc_31 );
// c[3,32-47]
acc_32 = _mm512_max_ps( zero, acc_32 );
// c[3,48-63]
acc_33 = _mm512_max_ps( zero, acc_33 );
// c[4,0-15]
acc_40 = _mm512_max_ps( zero, acc_40 );
// c[4,16-31]
acc_41 = _mm512_max_ps( zero, acc_41 );
// c[4,32-47]
acc_42 = _mm512_max_ps( zero, acc_42 );
// c[4,48-63]
acc_43 = _mm512_max_ps( zero, acc_43 );
// c[5,0-15]
acc_50 = _mm512_max_ps( zero, acc_50 );
// c[5,16-31]
acc_51 = _mm512_max_ps( zero, acc_51 );
// c[5,32-47]
acc_52 = _mm512_max_ps( zero, acc_52 );
// c[5,48-63]
acc_53 = _mm512_max_ps( zero, acc_53 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -770,29 +864,76 @@ POST_OPS_GELU_TANH_6x64:
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
// c[0, 0-15]
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
// c[0, 16-31]
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
// c[0, 32-47]
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
// c[0, 48-63]
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
// c[1, 0-15]
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
// c[1, 16-31]
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
// c[1, 32-47]
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
// c[1, 48-63]
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
// c[2, 0-15]
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
// c[2, 16-31]
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
// c[2, 32-47]
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
// c[2, 48-63]
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
// c[3, 0-15]
GELU_TANH_F32_AVX512_DEF(acc_30, y, r, r2, x, z, dn, tmpout)
// c[3, 16-31]
GELU_TANH_F32_AVX512_DEF(acc_31, y, r, r2, x, z, dn, tmpout)
// c[3, 32-47]
GELU_TANH_F32_AVX512_DEF(acc_32, y, r, r2, x, z, dn, tmpout)
// c[3, 48-63]
GELU_TANH_F32_AVX512_DEF(acc_33, y, r, r2, x, z, dn, tmpout)
// c[4, 0-15]
GELU_TANH_F32_AVX512_DEF(acc_40, y, r, r2, x, z, dn, tmpout)
// c[4, 16-31]
GELU_TANH_F32_AVX512_DEF(acc_41, y, r, r2, x, z, dn, tmpout)
// c[4, 32-47]
GELU_TANH_F32_AVX512_DEF(acc_42, y, r, r2, x, z, dn, tmpout)
// c[4, 48-63]
GELU_TANH_F32_AVX512_DEF(acc_43, y, r, r2, x, z, dn, tmpout)
// c[5, 0-15]
GELU_TANH_F32_AVX512_DEF(acc_50, y, r, r2, x, z, dn, tmpout)
// c[5, 16-31]
GELU_TANH_F32_AVX512_DEF(acc_51, y, r, r2, x, z, dn, tmpout)
// c[5, 32-47]
GELU_TANH_F32_AVX512_DEF(acc_52, y, r, r2, x, z, dn, tmpout)
// c[5, 48-63]
GELU_TANH_F32_AVX512_DEF(acc_53, y, r, r2, x, z, dn, tmpout)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -801,29 +942,76 @@ POST_OPS_GELU_ERF_6x64:
{
__m512 y, r, r2;
// c[0, 0-15]
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
// c[0, 16-31]
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
// c[0, 32-47]
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
// c[0, 48-63]
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
// c[1, 0-15]
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
// c[1, 16-31]
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
// c[1, 32-47]
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
// c[1, 48-63]
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
// c[2, 0-15]
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
// c[2, 16-31]
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
// c[2, 32-47]
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
// c[2, 48-63]
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
// c[3, 0-15]
GELU_ERF_F32_AVX512_DEF(acc_30, y, r, r2)
// c[3, 16-31]
GELU_ERF_F32_AVX512_DEF(acc_31, y, r, r2)
// c[3, 32-47]
GELU_ERF_F32_AVX512_DEF(acc_32, y, r, r2)
// c[3, 48-63]
GELU_ERF_F32_AVX512_DEF(acc_33, y, r, r2)
// c[4, 0-15]
GELU_ERF_F32_AVX512_DEF(acc_40, y, r, r2)
// c[4, 16-31]
GELU_ERF_F32_AVX512_DEF(acc_41, y, r, r2)
// c[4, 32-47]
GELU_ERF_F32_AVX512_DEF(acc_42, y, r, r2)
// c[4, 48-63]
GELU_ERF_F32_AVX512_DEF(acc_43, y, r, r2)
// c[5, 0-15]
GELU_ERF_F32_AVX512_DEF(acc_50, y, r, r2)
// c[5, 16-31]
GELU_ERF_F32_AVX512_DEF(acc_51, y, r, r2)
// c[5, 32-47]
GELU_ERF_F32_AVX512_DEF(acc_52, y, r, r2)
// c[5, 48-63]
GELU_ERF_F32_AVX512_DEF(acc_53, y, r, r2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR

View File

@@ -433,25 +433,64 @@ POST_OPS_BIAS_5x64:
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
// c[0,0-15]
acc_00 = _mm512_add_ps( b0, acc_00 );
// c[0,16-31]
acc_01 = _mm512_add_ps( b1, acc_01 );
// c[0,32-47]
acc_02 = _mm512_add_ps( b2, acc_02 );
// c[0,48-63]
acc_03 = _mm512_add_ps( b3, acc_03 );
// c[1,0-15]
acc_10 = _mm512_add_ps( b0, acc_10 );
// c[1,16-31]
acc_11 = _mm512_add_ps( b1, acc_11 );
// c[1,32-47]
acc_12 = _mm512_add_ps( b2, acc_12 );
// c[1,48-63]
acc_13 = _mm512_add_ps( b3, acc_13 );
// c[2,0-15]
acc_20 = _mm512_add_ps( b0, acc_20 );
// c[2,16-31]
acc_21 = _mm512_add_ps( b1, acc_21 );
// c[2,32-47]
acc_22 = _mm512_add_ps( b2, acc_22 );
// c[2,48-63]
acc_23 = _mm512_add_ps( b3, acc_23 );
// c[3,0-15]
acc_30 = _mm512_add_ps( b0, acc_30 );
// c[3,16-31]
acc_31 = _mm512_add_ps( b1, acc_31 );
// c[3,32-47]
acc_32 = _mm512_add_ps( b2, acc_32 );
// c[3,48-63]
acc_33 = _mm512_add_ps( b3, acc_33 );
// c[4,0-15]
acc_40 = _mm512_add_ps( b0, acc_40 );
// c[4,16-31]
acc_41 = _mm512_add_ps( b1, acc_41 );
// c[4,32-47]
acc_42 = _mm512_add_ps( b2, acc_42 );
// c[4,48-63]
acc_43 = _mm512_add_ps( b3, acc_43 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -460,25 +499,64 @@ POST_OPS_RELU_5x64:
{
__m512 zero = _mm512_setzero_ps();
// c[0,0-15]
acc_00 = _mm512_max_ps( zero, acc_00 );
// c[0,16-31]
acc_01 = _mm512_max_ps( zero, acc_01 );
// c[0,32-47]
acc_02 = _mm512_max_ps( zero, acc_02 );
// c[0,48-63]
acc_03 = _mm512_max_ps( zero, acc_03 );
// c[1,0-15]
acc_10 = _mm512_max_ps( zero, acc_10 );
// c[1,16-31]
acc_11 = _mm512_max_ps( zero, acc_11 );
// c[1,32-47]
acc_12 = _mm512_max_ps( zero, acc_12 );
// c[1,48-63]
acc_13 = _mm512_max_ps( zero, acc_13 );
// c[2,0-15]
acc_20 = _mm512_max_ps( zero, acc_20 );
// c[2,16-31]
acc_21 = _mm512_max_ps( zero, acc_21 );
// c[2,32-47]
acc_22 = _mm512_max_ps( zero, acc_22 );
// c[2,48-63]
acc_23 = _mm512_max_ps( zero, acc_23 );
// c[3,0-15]
acc_30 = _mm512_max_ps( zero, acc_30 );
// c[3,16-31]
acc_31 = _mm512_max_ps( zero, acc_31 );
// c[3,32-47]
acc_32 = _mm512_max_ps( zero, acc_32 );
// c[3,48-63]
acc_33 = _mm512_max_ps( zero, acc_33 );
// c[4,0-15]
acc_40 = _mm512_max_ps( zero, acc_40 );
// c[4,16-31]
acc_41 = _mm512_max_ps( zero, acc_41 );
// c[4,32-47]
acc_42 = _mm512_max_ps( zero, acc_42 );
// c[4,48-63]
acc_43 = _mm512_max_ps( zero, acc_43 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -569,25 +647,64 @@ POST_OPS_GELU_TANH_5x64:
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
// c[0,0-15]
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
// c[0,16-31]
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
// c[0,32-47]
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
// c[0,48-63]
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
// c[1,0-15]
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
// c[1,16-31]
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
// c[1,32-47]
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
// c[1,48-63]
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
// c[2,0-15]
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
// c[2,16-31]
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
// c[2,32-47]
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
// c[2,48-63]
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
// c[3,0-15]
GELU_TANH_F32_AVX512_DEF(acc_30, y, r, r2, x, z, dn, tmpout)
// c[3,16-31]
GELU_TANH_F32_AVX512_DEF(acc_31, y, r, r2, x, z, dn, tmpout)
// c[3,32-47]
GELU_TANH_F32_AVX512_DEF(acc_32, y, r, r2, x, z, dn, tmpout)
// c[3,48-63]
GELU_TANH_F32_AVX512_DEF(acc_33, y, r, r2, x, z, dn, tmpout)
// c[4,0-15]
GELU_TANH_F32_AVX512_DEF(acc_40, y, r, r2, x, z, dn, tmpout)
// c[4,16-31]
GELU_TANH_F32_AVX512_DEF(acc_41, y, r, r2, x, z, dn, tmpout)
// c[4,32-47]
GELU_TANH_F32_AVX512_DEF(acc_42, y, r, r2, x, z, dn, tmpout)
// c[4,48-63]
GELU_TANH_F32_AVX512_DEF(acc_43, y, r, r2, x, z, dn, tmpout)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -596,25 +713,64 @@ POST_OPS_GELU_ERF_5x64:
{
__m512 y, r, r2;
// c[0,0-15]
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
// c[0,16-31]
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
// c[0,32-47]
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
// c[0,48-63]
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
// c[1,0-15]
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
// c[1,16-31]
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
// c[1,32-47]
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
// c[1,48-63]
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
// c[2,0-15]
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
// c[2,16-31]
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
// c[2,32-47]
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
// c[2,48-63]
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
// c[3,0-15]
GELU_ERF_F32_AVX512_DEF(acc_30, y, r, r2)
// c[3,16-31]
GELU_ERF_F32_AVX512_DEF(acc_31, y, r, r2)
// c[3,32-47]
GELU_ERF_F32_AVX512_DEF(acc_32, y, r, r2)
// c[3,48-63]
GELU_ERF_F32_AVX512_DEF(acc_33, y, r, r2)
// c[4,0-15]
GELU_ERF_F32_AVX512_DEF(acc_40, y, r, r2)
// c[4,16-31]
GELU_ERF_F32_AVX512_DEF(acc_41, y, r, r2)
// c[4,32-47]
GELU_ERF_F32_AVX512_DEF(acc_42, y, r, r2)
// c[4,48-63]
GELU_ERF_F32_AVX512_DEF(acc_43, y, r, r2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -2270,21 +2426,52 @@ POST_OPS_BIAS_4x64:
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
// c[0,0-15]
acc_00 = _mm512_add_ps( b0, acc_00 );
// c[0,16-31]
acc_01 = _mm512_add_ps( b1, acc_01 );
// c[0,32-47]
acc_02 = _mm512_add_ps( b2, acc_02 );
// c[0,48-63]
acc_03 = _mm512_add_ps( b3, acc_03 );
// c[1,0-15]
acc_10 = _mm512_add_ps( b0, acc_10 );
// c[1,16-31]
acc_11 = _mm512_add_ps( b1, acc_11 );
// c[1,32-47]
acc_12 = _mm512_add_ps( b2, acc_12 );
// c[1,48-63]
acc_13 = _mm512_add_ps( b3, acc_13 );
// c[2,0-15]
acc_20 = _mm512_add_ps( b0, acc_20 );
// c[2,16-31]
acc_21 = _mm512_add_ps( b1, acc_21 );
// c[2,32-47]
acc_22 = _mm512_add_ps( b2, acc_22 );
// c[2,48-63]
acc_23 = _mm512_add_ps( b3, acc_23 );
// c[3,0-15]
acc_30 = _mm512_add_ps( b0, acc_30 );
// c[3,16-31]
acc_31 = _mm512_add_ps( b1, acc_31 );
// c[3,32-47]
acc_32 = _mm512_add_ps( b2, acc_32 );
// c[3,48-63]
acc_33 = _mm512_add_ps( b3, acc_33 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -2293,21 +2480,52 @@ POST_OPS_RELU_4x64:
{
__m512 zero = _mm512_setzero_ps();
// c[0,0-15]
acc_00 = _mm512_max_ps( zero, acc_00 );
// c[0,16-31]
acc_01 = _mm512_max_ps( zero, acc_01 );
// c[0,32-47]
acc_02 = _mm512_max_ps( zero, acc_02 );
// c[0,48-63]
acc_03 = _mm512_max_ps( zero, acc_03 );
// c[1,0-15]
acc_10 = _mm512_max_ps( zero, acc_10 );
// c[1,16-31]
acc_11 = _mm512_max_ps( zero, acc_11 );
// c[1,32-47]
acc_12 = _mm512_max_ps( zero, acc_12 );
// c[1,48-63]
acc_13 = _mm512_max_ps( zero, acc_13 );
// c[2,0-15]
acc_20 = _mm512_max_ps( zero, acc_20 );
// c[2,16-31]
acc_21 = _mm512_max_ps( zero, acc_21 );
// c[2,32-47]
acc_22 = _mm512_max_ps( zero, acc_22 );
// c[2,48-63]
acc_23 = _mm512_max_ps( zero, acc_23 );
// c[3,0-15]
acc_30 = _mm512_max_ps( zero, acc_30 );
// c[3,16-31]
acc_31 = _mm512_max_ps( zero, acc_31 );
// c[3,32-47]
acc_32 = _mm512_max_ps( zero, acc_32 );
// c[3,48-63]
acc_33 = _mm512_max_ps( zero, acc_33 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -2386,21 +2604,52 @@ POST_OPS_GELU_TANH_4x64:
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
// c[0,0-15]
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
// c[0,16-31]
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
// c[0,32-47]
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
// c[0,48-63]
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
// c[1,0-15]
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
// c[1,16-31]
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
// c[1,32-47]
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
// c[1,48-63]
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
// c[2,0-15]
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
// c[2,16-31]
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
// c[2,32-47]
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
// c[2,48-63]
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
// c[3,0-15]
GELU_TANH_F32_AVX512_DEF(acc_30, y, r, r2, x, z, dn, tmpout)
// c[3,16-31]
GELU_TANH_F32_AVX512_DEF(acc_31, y, r, r2, x, z, dn, tmpout)
// c[3,32-47]
GELU_TANH_F32_AVX512_DEF(acc_32, y, r, r2, x, z, dn, tmpout)
// c[3,48-63]
GELU_TANH_F32_AVX512_DEF(acc_33, y, r, r2, x, z, dn, tmpout)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -2409,21 +2658,52 @@ POST_OPS_GELU_ERF_4x64:
{
__m512 y, r, r2;
// c[0,0-15]
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
// c[0,16-31]
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
// c[0,32-47]
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
// c[0,48-63]
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
// c[1,0-15]
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
// c[1,16-31]
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
// c[1,32-47]
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
// c[1,48-63]
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
// c[2,0-15]
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
// c[2,16-31]
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
// c[2,32-47]
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
// c[2,48-63]
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
// c[3,0-15]
GELU_ERF_F32_AVX512_DEF(acc_30, y, r, r2)
// c[3,16-31]
GELU_ERF_F32_AVX512_DEF(acc_31, y, r, r2)
// c[3,32-47]
GELU_ERF_F32_AVX512_DEF(acc_32, y, r, r2)
// c[3,48-63]
GELU_ERF_F32_AVX512_DEF(acc_33, y, r, r2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3829,17 +4109,40 @@ POST_OPS_BIAS_3x64:
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
// c[0,0-15]
acc_00 = _mm512_add_ps( b0, acc_00 );
// c[0,16-31]
acc_01 = _mm512_add_ps( b1, acc_01 );
// c[0,32-47]
acc_02 = _mm512_add_ps( b2, acc_02 );
// c[0,48-63]
acc_03 = _mm512_add_ps( b3, acc_03 );
// c[1,0-15]
acc_10 = _mm512_add_ps( b0, acc_10 );
// c[1,16-31]
acc_11 = _mm512_add_ps( b1, acc_11 );
// c[1,32-47]
acc_12 = _mm512_add_ps( b2, acc_12 );
// c[1,48-63]
acc_13 = _mm512_add_ps( b3, acc_13 );
// c[2,0-15]
acc_20 = _mm512_add_ps( b0, acc_20 );
// c[2,16-31]
acc_21 = _mm512_add_ps( b1, acc_21 );
// c[2,32-47]
acc_22 = _mm512_add_ps( b2, acc_22 );
// c[2,48-63]
acc_23 = _mm512_add_ps( b3, acc_23 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3848,17 +4151,40 @@ POST_OPS_RELU_3x64:
{
__m512 zero = _mm512_setzero_ps();
// c[0,0-15]
acc_00 = _mm512_max_ps( zero, acc_00 );
// c[0,16-31]
acc_01 = _mm512_max_ps( zero, acc_01 );
// c[0,32-47]
acc_02 = _mm512_max_ps( zero, acc_02 );
// c[0,48-63]
acc_03 = _mm512_max_ps( zero, acc_03 );
// c[1,0-15]
acc_10 = _mm512_max_ps( zero, acc_10 );
// c[1,16-31]
acc_11 = _mm512_max_ps( zero, acc_11 );
// c[1,32-47]
acc_12 = _mm512_max_ps( zero, acc_12 );
// c[1,48-63]
acc_13 = _mm512_max_ps( zero, acc_13 );
// c[2,0-15]
acc_20 = _mm512_max_ps( zero, acc_20 );
// c[2,16-31]
acc_21 = _mm512_max_ps( zero, acc_21 );
// c[2,32-47]
acc_22 = _mm512_max_ps( zero, acc_22 );
// c[2,48-63]
acc_23 = _mm512_max_ps( zero, acc_23 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3925,17 +4251,40 @@ POST_OPS_GELU_TANH_3x64:
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
// c[0,0-15]
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
// c[0,16-31]
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
// c[0,32-47]
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
// c[0,48-63]
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
// c[1,0-15]
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
// c[1,16-31]
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
// c[1,32-47]
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
// c[1,48-63]
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
// c[2,0-15]
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
// c[2,16-31]
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
// c[2,32-47]
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
// c[2,48-63]
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3944,17 +4293,40 @@ POST_OPS_GELU_ERF_3x64:
{
__m512 y, r, r2;
// c[0,0-15]
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
// c[0,16-31]
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
// c[0,32-47]
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
// c[0,48-63]
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
// c[1,0-15]
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
// c[1,16-31]
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
// c[1,32-47]
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
// c[1,48-63]
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
// c[2,0-15]
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
// c[2,16-31]
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
// c[2,32-47]
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
// c[2,48-63]
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -5117,13 +5489,28 @@ POST_OPS_BIAS_2x64:
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
// c[0,0-15]
acc_00 = _mm512_add_ps( b0, acc_00 );
// c[0,16-31]
acc_01 = _mm512_add_ps( b1, acc_01 );
// c[0,32-47]
acc_02 = _mm512_add_ps( b2, acc_02 );
// c[0,48-63]
acc_03 = _mm512_add_ps( b3, acc_03 );
// c[1,0-15]
acc_10 = _mm512_add_ps( b0, acc_10 );
// c[1,16-31]
acc_11 = _mm512_add_ps( b1, acc_11 );
// c[1,32-47]
acc_12 = _mm512_add_ps( b2, acc_12 );
// c[1,48-63]
acc_13 = _mm512_add_ps( b3, acc_13 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -5132,13 +5519,28 @@ POST_OPS_RELU_2x64:
{
__m512 zero = _mm512_setzero_ps();
// c[0,0-15]
acc_00 = _mm512_max_ps( zero, acc_00 );
// c[0,16-31]
acc_01 = _mm512_max_ps( zero, acc_01 );
// c[0,32-47]
acc_02 = _mm512_max_ps( zero, acc_02 );
// c[0,48-63]
acc_03 = _mm512_max_ps( zero, acc_03 );
// c[1,0-15]
acc_10 = _mm512_max_ps( zero, acc_10 );
// c[1,16-31]
acc_11 = _mm512_max_ps( zero, acc_11 );
// c[1,32-47]
acc_12 = _mm512_max_ps( zero, acc_12 );
// c[1,48-63]
acc_13 = _mm512_max_ps( zero, acc_13 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -5193,13 +5595,28 @@ POST_OPS_GELU_TANH_2x64:
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
// c[0,0-15]
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
// c[0,16-31]
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
// c[0,32-47]
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
// c[0,48-63]
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
// c[1,0-15]
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
// c[1,16-31]
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
// c[1,32-47]
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
// c[1,48-63]
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -5208,13 +5625,28 @@ POST_OPS_GELU_ERF_2x64:
{
__m512 y, r, r2;
// c[0,0-15]
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
// c[0,16-31]
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
// c[0,32-47]
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
// c[0,48-63]
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
// c[1,0-15]
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
// c[1,16-31]
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
// c[1,32-47]
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
// c[1,48-63]
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -6121,9 +6553,16 @@ POST_OPS_BIAS_1x64:
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
}
// c[0,0-15]
acc_00 = _mm512_add_ps( b0, acc_00 );
// c[0,16-31]
acc_01 = _mm512_add_ps( b1, acc_01 );
// c[0,32-47]
acc_02 = _mm512_add_ps( b2, acc_02 );
// c[0,48-63]
acc_03 = _mm512_add_ps( b3, acc_03 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -6132,9 +6571,16 @@ POST_OPS_RELU_1x64:
{
__m512 zero = _mm512_setzero_ps();
// c[0,0-15]
acc_00 = _mm512_max_ps( zero, acc_00 );
// c[0,16-31]
acc_01 = _mm512_max_ps( zero, acc_01 );
// c[0,32-47]
acc_02 = _mm512_max_ps( zero, acc_02 );
// c[0,48-63]
acc_03 = _mm512_max_ps( zero, acc_03 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -6177,9 +6623,16 @@ POST_OPS_GELU_TANH_1x64:
__m512 dn, z, x, r2, r, y;
__m512i tmpout;
// c[0,0-15]
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
// c[0,16-31]
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
// c[0,32-47]
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
// c[0,48-63]
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -6188,9 +6641,16 @@ POST_OPS_GELU_ERF_1x64:
{
__m512 y, r, r2;
// c[0,0-15]
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
// c[0,16-31]
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
// c[0,32-47]
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
// c[0,48-63]
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR