mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Updated all post-ops in s8s8s32 API to operate in float precision
Description: 1. Changed all post-ops in s8s8s32o<s32|s8|u8|f32|bf16> to operate on float data. The s32 accumulator registers are converted to float at the end of the k loop, and every post-op then operates on f32. 2. Added the s8s8s32ou8 API, which uses the s8s8s32os32 kernels but stores the output in u8. AMD-Internal - SWLCSG-3366 Change-Id: Iadfd9bfb98fc3bf21e675acb95553fe967b806a6
This commit is contained in:
committed by
Nallani Bhaskar
parent
2ece628a4d
commit
2e687d8847
@@ -140,12 +140,13 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32);
|
||||
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8);
|
||||
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);
|
||||
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8);
|
||||
|
||||
205
addon/aocl_gemm/aocl_gemm_s8s8s32ou8.c
Normal file
205
addon/aocl_gemm/aocl_gemm_s8s8s32ou8.c
Normal file
@@ -0,0 +1,205 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s32ou8", \
|
||||
order, transa, transb, \
|
||||
m, n, k, \
|
||||
( ( float ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if (bli_cpuid_is_avx512vnni_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
AOCL_GEMM_CHECK
|
||||
(
|
||||
"s8s8s32ou8",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc,
|
||||
err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
|
||||
// Column major support disabled for u8s8s32 APIs as we cannot
|
||||
// swap matrices as both A and B are of different types.
|
||||
if ( ( order != 'r' ) && ( order != 'R' ) )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
|
||||
if (bli_is_trans(blis_transa))
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
}
|
||||
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
|
||||
if (bli_is_trans(blis_transb))
|
||||
{
|
||||
rs_b = 1;
|
||||
cs_b = ldb;
|
||||
}
|
||||
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if (mtag_a == REORDERED)
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported "
|
||||
"in row major case.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if (mtag_b == UNPACKED)
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if (bli_is_trans(blis_transa))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed, post_op_list,
|
||||
( void* )c, ( void* )( &order ),
|
||||
m, n
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
lpgemm_s8s8s32o32_openmp_thread_decorator(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t *)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, U8);
|
||||
#else
|
||||
lpgemm_s8s8s32o32_thread_decorator(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t *)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, U8);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
@@ -108,6 +108,7 @@ GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
@@ -204,6 +205,7 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
@@ -246,6 +248,7 @@ GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,u8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,u8s8s32obf16)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
|
||||
|
||||
@@ -339,6 +342,7 @@ static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
@@ -715,6 +719,7 @@ GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32ou8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32os32)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(s8s8s32of32)
|
||||
@@ -731,6 +736,7 @@ GEN_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32ou8)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32of32)
|
||||
GEN_TANH_POSTOP_FLOAT(s8s8s32os32)
|
||||
@@ -747,6 +753,7 @@ GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32ou8)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32os32)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_GELU_ERF_POSTOP_FLOAT(s8s8s32of32)
|
||||
@@ -763,6 +770,7 @@ GEN_SWISH_POSTOP_INT(float,u8s8s32os32)
|
||||
GEN_SWISH_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_SWISH_POSTOP_INT(float,s8s8s32os8)
|
||||
GEN_SWISH_POSTOP_INT(float,s8s8s32ou8)
|
||||
GEN_SWISH_POSTOP_INT(float,s8s8s32os32)
|
||||
GEN_SWISH_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_SWISH_POSTOP_FLOAT(s8s8s32of32)
|
||||
@@ -779,6 +787,7 @@ GEN_SIGMOID_POSTOP_FLOAT(u8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(u8s8s32of32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32ou8)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32os32)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32obf16)
|
||||
GEN_SIGMOID_POSTOP_FLOAT(s8s8s32of32)
|
||||
@@ -797,10 +806,11 @@ GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32ou8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,u8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,s8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32os8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32ou8)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32os32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,s8s8s32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,bf16s4f32of32)
|
||||
@@ -813,10 +823,11 @@ GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32ou8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,u8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,s8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32os8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32ou8)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32os32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32obf16)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,s8s8s32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,bf16s4f32of32)
|
||||
@@ -836,6 +847,7 @@ GEN_PRELU_POST_OP_VAL_INT(u8s8s32os8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(u8s8s32ou8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(u8s8s32os32)
|
||||
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(s8s8s32ou8)
|
||||
GEN_PRELU_POST_OP_VAL_INT(s8s8s32os32)
|
||||
|
||||
|
||||
@@ -853,6 +865,7 @@ GEN_CLIP_POST_OP_VAL_INT(u8s8s32os8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(u8s8s32ou8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(u8s8s32os32)
|
||||
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(s8s8s32ou8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os32)
|
||||
|
||||
|
||||
@@ -864,10 +877,11 @@ GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32ou8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32ou8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,s8s8s32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
|
||||
@@ -1141,6 +1155,7 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,float,f32f32f32of32,f
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,int32_t,float,s8s8s32os8,s8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,uint8_t,int32_t,int32_t,float,s8s8s32ou8,s8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,bfloat16,int32_t,int32_t,float,s8s8s32obf16,s8s8s32obf16)
|
||||
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,float,int32_t,float,int32_t,s8s8s32of32,s8s8s32of32)
|
||||
|
||||
@@ -1152,9 +1167,9 @@ GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,float,u8s8s32obf16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int8_t,float,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,s8s8s32obf16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16bf16f32obf16)
|
||||
@@ -1358,6 +1373,7 @@ GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,float,int32_t,u8s8s32of32,u8s8s32os32
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
|
||||
@@ -1398,9 +1414,10 @@ int main( int argc, char** argv )
|
||||
" 3. u8s8s32os32 -d f32 = u8s8s32of32.\n" \
|
||||
" 4. u8s8s32os32 -d bf16 = u8s8s32obf16.\n" \
|
||||
" 5. s8s8s32os32 -d s8 = s8s8s32os8.\n" \
|
||||
" 6. s8s8s32os32 -d f32 = s8s8s32of32.\n" \
|
||||
" 7. s8s8s32os32 -d bf16 = s8s8s32obf16.\n" \
|
||||
" 8. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
|
||||
" 6. s8s8s32os32 -d u8 = s8s8s32ou8.\n" \
|
||||
" 7. s8s8s32os32 -d f32 = s8s8s32of32.\n" \
|
||||
" 8. s8s8s32os32 -d bf16 = s8s8s32obf16.\n" \
|
||||
" 9. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
|
||||
" Example: ./bench_lpgemm -m a -n 2 -o bias,relu -d bf16 -i input.txt\n" \
|
||||
);
|
||||
exit( 1 );
|
||||
@@ -1748,6 +1765,21 @@ int main( int argc, char** argv )
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "s8s8s32ou8" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
strncpy( post_ops_str_dest, post_ops_str, POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = 0;
|
||||
DSCALE_CLIP_MAX = +255;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32ou8)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "s8s8s32obf16" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -517,7 +517,6 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
zmm8 = _mm512_maskz_sub_epi32( k2, zmm8, zmm0 );
|
||||
|
||||
}
|
||||
|
||||
//Scale accumulated output with alpha
|
||||
__m512i selector1 = _mm512_set1_epi32( alpha );
|
||||
__m512i selector2 = _mm512_set1_epi32( beta );
|
||||
@@ -536,6 +535,11 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
S8_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0,
|
||||
selector1, selector2 )
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
U8_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0,
|
||||
selector1, selector2 )
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
BF16_S32_BETA_OP_NLT16F_MASK( k2, zmm8, 0, 0,
|
||||
@@ -549,15 +553,57 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
ctemp[i] = *( ( int8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
int8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( ( int8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm512_cvtepi8_epi32
|
||||
( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( ( uint8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm512_cvtepu8_epi32
|
||||
( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
bfloat16 ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( ( bfloat16* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm512_cvtps_epi32(
|
||||
( __m512 )_mm512_sllv_epi32(
|
||||
_mm512_cvtepi16_epi32(
|
||||
_mm256_maskz_loadu_epi16( 0xFFFF, ctemp )
|
||||
), _mm512_set1_epi32( 16 ) ) );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == F32 )
|
||||
{
|
||||
float ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( ( float* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm512_cvtps_epi32(
|
||||
_mm512_maskz_loadu_ps( 0xFFFF, ctemp ) );
|
||||
}
|
||||
selector1 = _mm512_cvtepi8_epi32
|
||||
( _mm_maskz_loadu_epi8( 0xFFFF, ctemp ) );
|
||||
S32_BETA_FMA( zmm8, selector1, selector2 );
|
||||
}
|
||||
}
|
||||
@@ -581,6 +627,9 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
}
|
||||
}
|
||||
|
||||
__m512 acc_8 = _mm512_setzero_ps();
|
||||
acc_8 = _mm512_cvtepi32_ps( zmm8 );
|
||||
|
||||
// Post Ops
|
||||
lpgemm_post_op *post_ops_list_temp = post_op;
|
||||
|
||||
@@ -589,108 +638,122 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
POST_OPS_BIAS_6x64:
|
||||
{
|
||||
__m512 b0 = _mm512_setzero_ps();
|
||||
|
||||
if ( post_ops_list_temp->stor_type == BF16 )
|
||||
{
|
||||
selector1 =
|
||||
_mm512_cvtps_epi32
|
||||
(
|
||||
( __m512 )( _mm512_sllv_epi32
|
||||
(
|
||||
_mm512_cvtepi16_epi32
|
||||
(
|
||||
_mm256_maskz_loadu_epi16
|
||||
(
|
||||
_cvtu32_mask16( 0xFFFF ),
|
||||
b0 = (__m512)_mm512_sllv_epi32(
|
||||
_mm512_cvtepi16_epi32(
|
||||
_mm256_maskz_loadu_epi16(
|
||||
_cvtu32_mask16( 0x0001 ),
|
||||
( ( bfloat16* )post_ops_list_temp->op_args1 )
|
||||
)
|
||||
), _mm512_set1_epi32( 16 )
|
||||
)
|
||||
)
|
||||
);
|
||||
) ), _mm512_set1_epi32( 16 ) );
|
||||
}
|
||||
else if ( post_ops_list_temp->stor_type == S8 )
|
||||
{
|
||||
selector1 =
|
||||
_mm512_cvtepi8_epi32
|
||||
(
|
||||
_mm_maskz_loadu_epi8
|
||||
(
|
||||
_cvtu32_mask16( 0xFFFF ),
|
||||
( ( int8_t* )post_ops_list_temp->op_args1 )
|
||||
)
|
||||
);
|
||||
b0 = _mm512_cvtepi32_ps(
|
||||
_mm512_cvtepi8_epi32(
|
||||
_mm_maskz_loadu_epi8(
|
||||
_cvtu32_mask16( 0x0001 ),
|
||||
( ( int8_t* )post_ops_list_temp->op_args1 )
|
||||
) ) );
|
||||
}
|
||||
else if ( post_ops_list_temp->stor_type == F32 )
|
||||
else if ( post_ops_list_temp->stor_type == S32 )
|
||||
{
|
||||
selector1 = _mm512_cvtps_epi32
|
||||
(
|
||||
( __m512 )_mm512_maskz_loadu_ps
|
||||
(
|
||||
_cvtu32_mask16( 0xFFFF ),
|
||||
( ( float* ) post_ops_list_temp->op_args1 )
|
||||
)
|
||||
);
|
||||
b0 = _mm512_cvtepi32_ps(
|
||||
_mm512_set1_epi32(
|
||||
*( ( int32_t* )post_ops_list_temp->op_args1) ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
selector1 =
|
||||
_mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args1) );
|
||||
b0 = _mm512_maskz_loadu_ps(
|
||||
_cvtu32_mask16( 0x0001 ),
|
||||
( ( float* ) post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
zmm8 = _mm512_add_epi32( selector1, zmm8 );
|
||||
acc_8 = _mm512_add_ps( b0, acc_8 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_RELU_6x64:
|
||||
{
|
||||
selector1 = _mm512_setzero_epi32();
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
zmm8 = _mm512_max_epi32( selector1, zmm8 );
|
||||
acc_8 = _mm512_max_ps( zero, acc_8 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_RELU_SCALE_6x64:
|
||||
{
|
||||
selector1 = _mm512_setzero_epi32();
|
||||
selector2 =
|
||||
_mm512_set1_epi32(
|
||||
*( ( int32_t* )post_ops_list_temp->op_args2 ) );
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
__m512 scale;
|
||||
|
||||
if ( ( post_ops_attr.c_stor_type == S32 ) ||
|
||||
( post_ops_attr.c_stor_type == U8 ) ||
|
||||
( post_ops_attr.c_stor_type == S8 ) )
|
||||
{
|
||||
scale = _mm512_cvtepi32_ps
|
||||
( _mm512_set1_epi32(
|
||||
*( ( int32_t* )post_ops_list_temp->op_args2 ) ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
scale = _mm512_set1_ps(
|
||||
*( ( float* )post_ops_list_temp->op_args2 ) );
|
||||
}
|
||||
|
||||
__mmask16 relu_cmp_mask;
|
||||
|
||||
RELU_SCALE_OP_S32_AVX512(zmm8)
|
||||
RELU_SCALE_OP_F32_AVX512(acc_8)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_GELU_TANH_6x64:
|
||||
{
|
||||
__m512 dn, z, x, r2, r, y, x_tanh;
|
||||
GELU_TANH_S32_AVX512( zmm8, y, r, r2, x,
|
||||
z, dn, x_tanh, selector1 )
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
GELU_TANH_F32_AVX512_DEF( acc_8, y, r, r2, x, z, dn, tmpout );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_GELU_ERF_6x64:
|
||||
{
|
||||
__m512 x, r, y, x_erf;
|
||||
__m512 y, r, r2;
|
||||
|
||||
GELU_ERF_S32_AVX512( zmm8, y, r, x, x_erf )
|
||||
GELU_ERF_F32_AVX512_DEF( acc_8, y, r, r2 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_CLIP_6x64:
|
||||
{
|
||||
__m512i min = _mm512_set1_epi32(
|
||||
*( int32_t* )post_ops_list_temp->op_args2 );
|
||||
__m512i max = _mm512_set1_epi32(
|
||||
*( int32_t* )post_ops_list_temp->op_args3 );
|
||||
__m512 min = _mm512_setzero_ps();
|
||||
__m512 max = _mm512_setzero_ps();
|
||||
|
||||
CLIP_S32_AVX512( zmm8, min, max )
|
||||
if ( ( post_ops_attr.c_stor_type == S32 ) ||
|
||||
( post_ops_attr.c_stor_type == U8 ) ||
|
||||
( post_ops_attr.c_stor_type == S8 ) )
|
||||
{
|
||||
min = _mm512_cvtepi32_ps
|
||||
(_mm512_set1_epi32( *( int32_t* )post_ops_list_temp->op_args2 ));
|
||||
max = _mm512_cvtepi32_ps
|
||||
(_mm512_set1_epi32( *( int32_t* )post_ops_list_temp->op_args3 ));
|
||||
}
|
||||
else
|
||||
{
|
||||
min = _mm512_set1_ps(
|
||||
*( ( float* )post_ops_list_temp->op_args2 ) );
|
||||
max = _mm512_set1_ps(
|
||||
*( ( float* )post_ops_list_temp->op_args3 ) );
|
||||
}
|
||||
|
||||
CLIP_F32_AVX512( acc_8, min, max )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_DOWNSCALE_6x64:
|
||||
{
|
||||
selector1 = ( __m512i )_mm512_set1_ps(
|
||||
__m512 scale0 = _mm512_setzero_ps();
|
||||
scale0 = _mm512_set1_ps(
|
||||
*( ( float* )post_ops_list_temp->scale_factor ) );
|
||||
|
||||
// Need to ensure sse not used to avoid avx512 -> sse transition.
|
||||
@@ -700,7 +763,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
zero_point0 = _mm_maskz_set1_epi8( 0xFFFF,
|
||||
*( ( int8_t* )post_ops_list_temp->op_args1 ) );
|
||||
|
||||
CVT_MULRND_CVT32(zmm8, selector1, zero_point0 );
|
||||
CVT_MULRND_F32(acc_8, scale0, zero_point0 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
@@ -715,6 +778,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
bool is_f32 = ( post_ops_list_temp->stor_type == F32 );
|
||||
|
||||
__m512 scl_fctr1 = _mm512_setzero_ps();
|
||||
__m512 t0 = _mm512_setzero_ps();
|
||||
|
||||
// Even though different registers are used for scalar in column and
|
||||
// row major case, all those registers will contain the same value.
|
||||
@@ -742,22 +806,19 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
BF16_S32_MATRIX_ADD_LOAD( k2, selector1, scl_fctr1, 0, 0 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
BF16_F32_MATRIX_ADD_LOAD( k2, t0, scl_fctr1, 0, 0 );
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
bfloat16 ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm512_sllv_epi32
|
||||
t0 = (__m512)_mm512_sllv_epi32
|
||||
(
|
||||
_mm512_cvtepi16_epi32
|
||||
(
|
||||
@@ -767,11 +828,8 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
)
|
||||
), _mm512_set1_epi32( 16 )
|
||||
);
|
||||
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
t0 = _mm512_mul_ps( t0, scl_fctr1 );
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
}
|
||||
else if ( is_f32 == TRUE )
|
||||
@@ -780,27 +838,21 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
F32_S32_MATRIX_ADD_LOAD( k2, selector1, scl_fctr1, 0, 0 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
F32_ACC_MATRIX_ADD_LOAD( k2, t0, scl_fctr1, 0, 0 );
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
float ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = ( __m512i )_mm512_maskz_loadu_ps( k2, ctemp );
|
||||
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_add_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
t0 = _mm512_maskz_loadu_ps( k2, ctemp );
|
||||
t0 = _mm512_mul_ps( t0, scl_fctr1 );
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
}
|
||||
else if ( is_s8 == TRUE )
|
||||
@@ -809,8 +861,8 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
S8_S32_MATRIX_ADD_LOAD( k2, selector1, scl_fctr1, 0, 0 )
|
||||
zmm8 = _mm512_add_epi32( selector1, zmm8 );
|
||||
S8_F32_MATRIX_ADD_LOAD( k2, t0, scl_fctr1, 0, 0 )
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -821,16 +873,13 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm512_cvtepi8_epi32
|
||||
( _mm_maskz_loadu_epi8( k2, ctemp ) );
|
||||
selector1 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps
|
||||
(
|
||||
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
|
||||
)
|
||||
);
|
||||
zmm8 = _mm512_add_epi32( selector1, zmm8 );
|
||||
t0 = _mm512_cvtepi32_ps(
|
||||
_mm512_cvtepi8_epi32(
|
||||
_mm_maskz_loadu_epi8( k2, ctemp ) ) );
|
||||
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
|
||||
);
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -839,8 +888,8 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
S32_S32_MATRIX_ADD_LOAD(k2, selector1, scl_fctr1, 0, 0 );
|
||||
zmm8 = _mm512_add_epi32( selector1, zmm8 );
|
||||
S32_F32_MATRIX_ADD_LOAD(k2, t0, scl_fctr1, 0, 0 );
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -851,15 +900,12 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm512_maskz_loadu_epi32( k2, ctemp );
|
||||
selector1 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps
|
||||
(
|
||||
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
|
||||
)
|
||||
);
|
||||
zmm8 = _mm512_add_epi32( selector1, zmm8 );
|
||||
t0 = _mm512_cvtepi32_ps(
|
||||
_mm512_maskz_loadu_epi32( k2, ctemp ) );
|
||||
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
|
||||
);
|
||||
acc_8 = _mm512_add_ps( t0, acc_8 );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -876,6 +922,7 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
bool is_f32 = ( post_ops_list_temp->stor_type == F32 );
|
||||
|
||||
__m512 scl_fctr1 = _mm512_setzero_ps();
|
||||
__m512 t0 = _mm512_setzero_ps();
|
||||
|
||||
// Even though different registers are used for scalar in column and
|
||||
// row major case, all those registers will contain the same value.
|
||||
@@ -903,38 +950,32 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
BF16_S32_MATRIX_MUL_LOAD( k2, selector1, scl_fctr1, 0, 0 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
|
||||
);
|
||||
BF16_F32_MATRIX_MUL_LOAD( k2, t0, scl_fctr1, 0, 0 );
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
bfloat16 ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm512_sllv_epi32
|
||||
(
|
||||
_mm512_cvtepi16_epi32
|
||||
(
|
||||
_mm256_maskz_loadu_epi16
|
||||
(
|
||||
k2 , ctemp
|
||||
)
|
||||
), _mm512_set1_epi32( 16 )
|
||||
);
|
||||
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
|
||||
);
|
||||
t0 = (__m512)_mm512_sllv_epi32
|
||||
(
|
||||
_mm512_cvtepi16_epi32
|
||||
(
|
||||
_mm256_maskz_loadu_epi16
|
||||
(
|
||||
k2 , ctemp
|
||||
)
|
||||
), _mm512_set1_epi32( 16 )
|
||||
);
|
||||
t0 = _mm512_mul_ps( t0, scl_fctr1 );
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
}
|
||||
else if ( is_f32 == TRUE )
|
||||
@@ -943,29 +984,23 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
F32_S32_MATRIX_MUL_LOAD( k2, selector1, scl_fctr1, 0, 0 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
|
||||
);
|
||||
F32_MATRIX_MUL_LOAD( k2, t0, scl_fctr1, 0, 0 );
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
float ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = ( __m512i )_mm512_maskz_loadu_ps( k2, ctemp );
|
||||
selector1 = ( __m512i )_mm512_mul_ps( ( __m512 )selector1, scl_fctr1 );
|
||||
zmm8 =
|
||||
_mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps( ( __m512 )selector1, _mm512_cvtepi32_ps( zmm8 ),
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) )
|
||||
);
|
||||
t0 = _mm512_maskz_loadu_ps( k2, ctemp );
|
||||
t0 = _mm512_mul_ps( t0, scl_fctr1 );
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
}
|
||||
else if ( is_s8 == TRUE )
|
||||
@@ -974,14 +1009,10 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
S8_S32_MATRIX_MUL_LOAD( k2, selector1, scl_fctr1, 0, 0 )
|
||||
S8_F32_MATRIX_MUL_LOAD( k2, t0, scl_fctr1, 0, 0 )
|
||||
|
||||
// mul_epi32 works on 64 bit lengths, with mul done for lower 32 bits.
|
||||
// We only need 32 bit mul to get 32 bit output, so using mul_ps.
|
||||
zmm8 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
|
||||
_mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -992,19 +1023,13 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm512_cvtepi8_epi32
|
||||
( _mm_maskz_loadu_epi8( k2, ctemp ) );
|
||||
selector1 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps
|
||||
(
|
||||
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
|
||||
)
|
||||
);
|
||||
zmm8 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
|
||||
_mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
t0 = _mm512_cvtepi32_ps(
|
||||
_mm512_cvtepi8_epi32
|
||||
( _mm_maskz_loadu_epi8( k2, ctemp ) ) );
|
||||
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1013,11 +1038,10 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
S32_S32_MATRIX_MUL_LOAD(k2, selector1, scl_fctr1, 0, 0 );
|
||||
zmm8 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
|
||||
_mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
S32_F32_MATRIX_MUL_LOAD(k2, t0, scl_fctr1, 0, 0 );
|
||||
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1028,18 +1052,12 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm512_maskz_loadu_epi32( k2, ctemp );
|
||||
selector1 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_round_ps
|
||||
(
|
||||
_mm512_cvtepi32_ps( selector1 ), scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC )
|
||||
)
|
||||
);
|
||||
zmm8 = _mm512_cvtps_epi32(
|
||||
_mm512_mul_ps( _mm512_cvtepi32_ps( selector1 ),
|
||||
_mm512_cvtepi32_ps( zmm8 ) )
|
||||
);
|
||||
t0 = _mm512_cvtepi32_ps(
|
||||
_mm512_maskz_loadu_epi32( k2, ctemp ) );
|
||||
t0 = _mm512_mul_round_ps( t0, scl_fctr1,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
acc_8 = _mm512_mul_round_ps( t0, acc_8,
|
||||
( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1047,29 +1065,44 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
}
|
||||
POST_OPS_SWISH_6x64:
|
||||
{
|
||||
selector1 =
|
||||
_mm512_set1_epi32( *( (int32_t*)post_ops_list_temp->op_args2 ) );
|
||||
__m512 scale;
|
||||
|
||||
__m512 al = _mm512_cvtepi32_ps( selector1 );
|
||||
if ( ( post_ops_attr.c_stor_type == S32 ) ||
|
||||
( post_ops_attr.c_stor_type == U8 ) ||
|
||||
( post_ops_attr.c_stor_type == S8 ) )
|
||||
{
|
||||
scale = _mm512_cvtepi32_ps
|
||||
(_mm512_set1_epi32(
|
||||
*( ( int32_t* )post_ops_list_temp->op_args2 ) ));
|
||||
}
|
||||
else
|
||||
{
|
||||
scale = _mm512_set1_ps(
|
||||
*( ( float* )post_ops_list_temp->op_args2 ) );
|
||||
}
|
||||
|
||||
__m512 fl_reg, al_in, r, r2, z, dn;
|
||||
__m512 al_in, r, r2, z, dn;
|
||||
__m512i temp;
|
||||
|
||||
SWISH_S32_AVX512( zmm8, fl_reg, al, al_in, r, r2, z, dn, selector2 );
|
||||
SWISH_F32_AVX512_DEF( acc_8, scale, al_in, r, r2, z, dn, temp );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_TANH_6x64:
|
||||
{
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
TANH_S32_AVX512( zmm8, y, r, r2, x, z, dn, selector1 );
|
||||
__m512 dn, z, x, r2, r;
|
||||
__m512i q;
|
||||
|
||||
TANHF_AVX512( acc_8, r, r2, x, z, dn, q );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_SIGMOID_6x64:
|
||||
{
|
||||
__m512 fl_reg, al_in, r, r2, z, dn;
|
||||
__m512 al_in, r, r2, z, dn;
|
||||
__m512i tmpout;
|
||||
|
||||
SIGMOID_S32_AVX512( zmm8, fl_reg, al_in, r, r2, z, dn, selector2 );
|
||||
SIGMOID_F32_AVX512_DEF( acc_8, al_in, r, r2, z, dn, tmpout );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
@@ -1083,28 +1116,79 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
{
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
CVT_STORE_S32_S8_MASK( zmm8, k2, 0, 0 );
|
||||
CVT_STORE_F32_S8_MASK( k2, acc_8, 0, 0 );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
CVT_STORE_F32_U8_MASK( k2, acc_8, 0, 0 );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
CVT_STORE_S32_BF16_MASK( zmm8, k2, 0, 0 );
|
||||
CVT_STORE_F32_BF16_MASK( k2, acc_8, 0, 0 );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == F32 )
|
||||
{
|
||||
CVT_STORE_S32_F32_MASK( zmm8, k2, 0, 0 );
|
||||
STORE_F32_MASK( k2, acc_8, 0, 0 );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
|
||||
_mm512_mask_cvtsepi32_storeu_epi8 ( ctemp, k2, zmm8 );
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
*( ( int8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
int8_t ctemp[16];
|
||||
|
||||
_mm512_mask_cvtsepi32_storeu_epi8 ( ctemp, k2,
|
||||
_mm512_cvtps_epi32( acc_8 ) );
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
*( ( int8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
|
||||
_mm512_mask_cvtusepi32_storeu_epi8 ( ctemp, k2,
|
||||
_mm512_cvtps_epu32(
|
||||
_mm512_max_ps( acc_8, _mm512_set1_ps( 0 ) )
|
||||
) );
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
*( ( uint8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == BF16 )
|
||||
{
|
||||
bfloat16 ctemp[16];
|
||||
|
||||
_mm256_mask_storeu_epi16( ctemp, k2,
|
||||
(__m256i)_mm512_cvtneps_pbh( acc_8 ) );
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
*( ( bfloat16* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == F32 )
|
||||
{
|
||||
float ctemp[16];
|
||||
|
||||
_mm512_mask_storeu_ps( ctemp, k2, acc_8 );
|
||||
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
*( ( float* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1112,14 +1196,16 @@ LPGEMV_N_EQ1_KERN(int8_t, int8_t, int32_t, s8s8s32os32)
|
||||
{
|
||||
if(rs_c == 1)
|
||||
{
|
||||
_mm512_mask_storeu_epi32(c_use, k2, zmm8);
|
||||
_mm512_mask_storeu_epi32(c_use, k2,
|
||||
_mm512_cvtps_epi32( acc_8 ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Store ZMM8 into ctemp buffer and store back
|
||||
// element by element into output buffer at strides
|
||||
int32_t ctemp[16];
|
||||
_mm512_mask_storeu_epi32(ctemp, k2, zmm8);
|
||||
_mm512_mask_storeu_epi32(ctemp, k2,
|
||||
_mm512_cvtps_epi32( acc_8 ) );
|
||||
for (dim_t i = 0; i < mr0; i++)
|
||||
{
|
||||
c_use[i * rs_c] = ctemp[i];
|
||||
|
||||
@@ -612,29 +612,76 @@ POST_OPS_BIAS_6x64:
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_add_ps( b0, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_add_ps( b1, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_add_ps( b2, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_add_ps( b3, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_add_ps( b0, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_add_ps( b1, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_add_ps( b2, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_add_ps( b3, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_add_ps( b0, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_add_ps( b1, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_add_ps( b2, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_add_ps( b3, acc_23 );
|
||||
|
||||
// c[3,0-15]
|
||||
acc_30 = _mm512_add_ps( b0, acc_30 );
|
||||
|
||||
// c[3,16-31]
|
||||
acc_31 = _mm512_add_ps( b1, acc_31 );
|
||||
|
||||
// c[3,32-47]
|
||||
acc_32 = _mm512_add_ps( b2, acc_32 );
|
||||
|
||||
// c[3,48-63]
|
||||
acc_33 = _mm512_add_ps( b3, acc_33 );
|
||||
|
||||
// c[4,0-15]
|
||||
acc_40 = _mm512_add_ps( b0, acc_40 );
|
||||
|
||||
// c[4,16-31]
|
||||
acc_41 = _mm512_add_ps( b1, acc_41 );
|
||||
|
||||
// c[4,32-47]
|
||||
acc_42 = _mm512_add_ps( b2, acc_42 );
|
||||
|
||||
// c[4,48-63]
|
||||
acc_43 = _mm512_add_ps( b3, acc_43 );
|
||||
|
||||
// c[5,0-15]
|
||||
acc_50 = _mm512_add_ps( b0, acc_50 );
|
||||
|
||||
// c[5,16-31]
|
||||
acc_51 = _mm512_add_ps( b1, acc_51 );
|
||||
|
||||
// c[5,32-47]
|
||||
acc_52 = _mm512_add_ps( b2, acc_52 );
|
||||
|
||||
// c[5,48-63]
|
||||
acc_53 = _mm512_add_ps( b3, acc_53 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -643,29 +690,76 @@ POST_OPS_RELU_6x64:
|
||||
{
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_max_ps( zero, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_max_ps( zero, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_max_ps( zero, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_max_ps( zero, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_max_ps( zero, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_max_ps( zero, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_max_ps( zero, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_max_ps( zero, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_max_ps( zero, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_max_ps( zero, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_max_ps( zero, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_max_ps( zero, acc_23 );
|
||||
|
||||
// c[3,0-15]
|
||||
acc_30 = _mm512_max_ps( zero, acc_30 );
|
||||
|
||||
// c[3,16-31]
|
||||
acc_31 = _mm512_max_ps( zero, acc_31 );
|
||||
|
||||
// c[3,32-47]
|
||||
acc_32 = _mm512_max_ps( zero, acc_32 );
|
||||
|
||||
// c[3,48-63]
|
||||
acc_33 = _mm512_max_ps( zero, acc_33 );
|
||||
|
||||
// c[4,0-15]
|
||||
acc_40 = _mm512_max_ps( zero, acc_40 );
|
||||
|
||||
// c[4,16-31]
|
||||
acc_41 = _mm512_max_ps( zero, acc_41 );
|
||||
|
||||
// c[4,32-47]
|
||||
acc_42 = _mm512_max_ps( zero, acc_42 );
|
||||
|
||||
// c[4,48-63]
|
||||
acc_43 = _mm512_max_ps( zero, acc_43 );
|
||||
|
||||
// c[5,0-15]
|
||||
acc_50 = _mm512_max_ps( zero, acc_50 );
|
||||
|
||||
// c[5,16-31]
|
||||
acc_51 = _mm512_max_ps( zero, acc_51 );
|
||||
|
||||
// c[5,32-47]
|
||||
acc_52 = _mm512_max_ps( zero, acc_52 );
|
||||
|
||||
// c[5,48-63]
|
||||
acc_53 = _mm512_max_ps( zero, acc_53 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -770,29 +864,76 @@ POST_OPS_GELU_TANH_6x64:
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
// c[0, 0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0, 16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0, 32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0, 48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1, 0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1, 16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1, 32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1, 48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2, 0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2, 16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2, 32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2, 48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3, 0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_30, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3, 16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_31, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3, 32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_32, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3, 48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_33, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4, 0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_40, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4, 16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_41, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4, 32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_42, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4, 48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_43, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[5, 0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_50, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[5, 16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_51, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[5, 32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_52, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[5, 48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_53, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -801,29 +942,76 @@ POST_OPS_GELU_ERF_6x64:
|
||||
{
|
||||
__m512 y, r, r2;
|
||||
|
||||
// c[0, 0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
|
||||
|
||||
// c[0, 16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
|
||||
|
||||
// c[0, 32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
|
||||
|
||||
// c[0, 48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
|
||||
|
||||
// c[1, 0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
|
||||
|
||||
// c[1, 16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
|
||||
|
||||
// c[1, 32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
|
||||
|
||||
// c[1, 48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
|
||||
|
||||
// c[2, 0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
|
||||
|
||||
// c[2, 16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
|
||||
|
||||
// c[2, 32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
|
||||
|
||||
// c[2, 48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
|
||||
|
||||
// c[3, 0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_30, y, r, r2)
|
||||
|
||||
// c[3, 16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_31, y, r, r2)
|
||||
|
||||
// c[3, 32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_32, y, r, r2)
|
||||
|
||||
// c[3, 48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_33, y, r, r2)
|
||||
|
||||
// c[4, 0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_40, y, r, r2)
|
||||
|
||||
// c[4, 16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_41, y, r, r2)
|
||||
|
||||
// c[4, 32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_42, y, r, r2)
|
||||
|
||||
// c[4, 48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_43, y, r, r2)
|
||||
|
||||
// c[5, 0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_50, y, r, r2)
|
||||
|
||||
// c[5, 16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_51, y, r, r2)
|
||||
|
||||
// c[5, 32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_52, y, r, r2)
|
||||
|
||||
// c[5, 48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_53, y, r, r2)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
|
||||
@@ -433,25 +433,64 @@ POST_OPS_BIAS_5x64:
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_add_ps( b0, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_add_ps( b1, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_add_ps( b2, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_add_ps( b3, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_add_ps( b0, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_add_ps( b1, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_add_ps( b2, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_add_ps( b3, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_add_ps( b0, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_add_ps( b1, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_add_ps( b2, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_add_ps( b3, acc_23 );
|
||||
|
||||
// c[3,0-15]
|
||||
acc_30 = _mm512_add_ps( b0, acc_30 );
|
||||
|
||||
// c[3,16-31]
|
||||
acc_31 = _mm512_add_ps( b1, acc_31 );
|
||||
|
||||
// c[3,32-47]
|
||||
acc_32 = _mm512_add_ps( b2, acc_32 );
|
||||
|
||||
// c[3,48-63]
|
||||
acc_33 = _mm512_add_ps( b3, acc_33 );
|
||||
|
||||
// c[4,0-15]
|
||||
acc_40 = _mm512_add_ps( b0, acc_40 );
|
||||
|
||||
// c[4,16-31]
|
||||
acc_41 = _mm512_add_ps( b1, acc_41 );
|
||||
|
||||
// c[4,32-47]
|
||||
acc_42 = _mm512_add_ps( b2, acc_42 );
|
||||
|
||||
// c[4,48-63]
|
||||
acc_43 = _mm512_add_ps( b3, acc_43 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -460,25 +499,64 @@ POST_OPS_RELU_5x64:
|
||||
{
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_max_ps( zero, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_max_ps( zero, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_max_ps( zero, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_max_ps( zero, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_max_ps( zero, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_max_ps( zero, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_max_ps( zero, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_max_ps( zero, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_max_ps( zero, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_max_ps( zero, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_max_ps( zero, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_max_ps( zero, acc_23 );
|
||||
|
||||
// c[3,0-15]
|
||||
acc_30 = _mm512_max_ps( zero, acc_30 );
|
||||
|
||||
// c[3,16-31]
|
||||
acc_31 = _mm512_max_ps( zero, acc_31 );
|
||||
|
||||
// c[3,32-47]
|
||||
acc_32 = _mm512_max_ps( zero, acc_32 );
|
||||
|
||||
// c[3,48-63]
|
||||
acc_33 = _mm512_max_ps( zero, acc_33 );
|
||||
|
||||
// c[4,0-15]
|
||||
acc_40 = _mm512_max_ps( zero, acc_40 );
|
||||
|
||||
// c[4,16-31]
|
||||
acc_41 = _mm512_max_ps( zero, acc_41 );
|
||||
|
||||
// c[4,32-47]
|
||||
acc_42 = _mm512_max_ps( zero, acc_42 );
|
||||
|
||||
// c[4,48-63]
|
||||
acc_43 = _mm512_max_ps( zero, acc_43 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -569,25 +647,64 @@ POST_OPS_GELU_TANH_5x64:
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_30, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_31, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_32, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_33, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_40, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_41, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_42, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[4,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_43, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -596,25 +713,64 @@ POST_OPS_GELU_ERF_5x64:
|
||||
{
|
||||
__m512 y, r, r2;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
|
||||
|
||||
// c[2,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
|
||||
|
||||
// c[2,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
|
||||
|
||||
// c[2,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
|
||||
|
||||
// c[2,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
|
||||
|
||||
// c[3,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_30, y, r, r2)
|
||||
|
||||
// c[3,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_31, y, r, r2)
|
||||
|
||||
// c[3,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_32, y, r, r2)
|
||||
|
||||
// c[3,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_33, y, r, r2)
|
||||
|
||||
// c[4,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_40, y, r, r2)
|
||||
|
||||
// c[4,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_41, y, r, r2)
|
||||
|
||||
// c[4,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_42, y, r, r2)
|
||||
|
||||
// c[4,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_43, y, r, r2)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -2270,21 +2426,52 @@ POST_OPS_BIAS_4x64:
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_add_ps( b0, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_add_ps( b1, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_add_ps( b2, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_add_ps( b3, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_add_ps( b0, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_add_ps( b1, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_add_ps( b2, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_add_ps( b3, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_add_ps( b0, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_add_ps( b1, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_add_ps( b2, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_add_ps( b3, acc_23 );
|
||||
|
||||
// c[3,0-15]
|
||||
acc_30 = _mm512_add_ps( b0, acc_30 );
|
||||
|
||||
// c[3,16-31]
|
||||
acc_31 = _mm512_add_ps( b1, acc_31 );
|
||||
|
||||
// c[3,32-47]
|
||||
acc_32 = _mm512_add_ps( b2, acc_32 );
|
||||
|
||||
// c[3,48-63]
|
||||
acc_33 = _mm512_add_ps( b3, acc_33 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -2293,21 +2480,52 @@ POST_OPS_RELU_4x64:
|
||||
{
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_max_ps( zero, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_max_ps( zero, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_max_ps( zero, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_max_ps( zero, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_max_ps( zero, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_max_ps( zero, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_max_ps( zero, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_max_ps( zero, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_max_ps( zero, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_max_ps( zero, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_max_ps( zero, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_max_ps( zero, acc_23 );
|
||||
|
||||
// c[3,0-15]
|
||||
acc_30 = _mm512_max_ps( zero, acc_30 );
|
||||
|
||||
// c[3,16-31]
|
||||
acc_31 = _mm512_max_ps( zero, acc_31 );
|
||||
|
||||
// c[3,32-47]
|
||||
acc_32 = _mm512_max_ps( zero, acc_32 );
|
||||
|
||||
// c[3,48-63]
|
||||
acc_33 = _mm512_max_ps( zero, acc_33 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -2386,21 +2604,52 @@ POST_OPS_GELU_TANH_4x64:
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_30, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_31, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_32, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[3,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_33, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -2409,21 +2658,52 @@ POST_OPS_GELU_ERF_4x64:
|
||||
{
|
||||
__m512 y, r, r2;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
|
||||
|
||||
// c[2,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
|
||||
|
||||
// c[2,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
|
||||
|
||||
// c[2,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
|
||||
|
||||
// c[2,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
|
||||
|
||||
// c[3,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_30, y, r, r2)
|
||||
|
||||
// c[3,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_31, y, r, r2)
|
||||
|
||||
// c[3,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_32, y, r, r2)
|
||||
|
||||
// c[3,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_33, y, r, r2)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -3829,17 +4109,40 @@ POST_OPS_BIAS_3x64:
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_add_ps( b0, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_add_ps( b1, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_add_ps( b2, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_add_ps( b3, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_add_ps( b0, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_add_ps( b1, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_add_ps( b2, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_add_ps( b3, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_add_ps( b0, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_add_ps( b1, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_add_ps( b2, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_add_ps( b3, acc_23 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -3848,17 +4151,40 @@ POST_OPS_RELU_3x64:
|
||||
{
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_max_ps( zero, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_max_ps( zero, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_max_ps( zero, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_max_ps( zero, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_max_ps( zero, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_max_ps( zero, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_max_ps( zero, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_max_ps( zero, acc_13 );
|
||||
|
||||
// c[2,0-15]
|
||||
acc_20 = _mm512_max_ps( zero, acc_20 );
|
||||
|
||||
// c[2,16-31]
|
||||
acc_21 = _mm512_max_ps( zero, acc_21 );
|
||||
|
||||
// c[2,32-47]
|
||||
acc_22 = _mm512_max_ps( zero, acc_22 );
|
||||
|
||||
// c[2,48-63]
|
||||
acc_23 = _mm512_max_ps( zero, acc_23 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -3925,17 +4251,40 @@ POST_OPS_GELU_TANH_3x64:
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_20, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_21, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_22, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[2,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_23, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -3944,17 +4293,40 @@ POST_OPS_GELU_ERF_3x64:
|
||||
{
|
||||
__m512 y, r, r2;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
|
||||
|
||||
// c[2,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_20, y, r, r2)
|
||||
|
||||
// c[2,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_21, y, r, r2)
|
||||
|
||||
// c[2,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_22, y, r, r2)
|
||||
|
||||
// c[2,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_23, y, r, r2)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -5117,13 +5489,28 @@ POST_OPS_BIAS_2x64:
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_add_ps( b0, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_add_ps( b1, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_add_ps( b2, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_add_ps( b3, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_add_ps( b0, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_add_ps( b1, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_add_ps( b2, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_add_ps( b3, acc_13 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -5132,13 +5519,28 @@ POST_OPS_RELU_2x64:
|
||||
{
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_max_ps( zero, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_max_ps( zero, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_max_ps( zero, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_max_ps( zero, acc_03 );
|
||||
|
||||
// c[1,0-15]
|
||||
acc_10 = _mm512_max_ps( zero, acc_10 );
|
||||
|
||||
// c[1,16-31]
|
||||
acc_11 = _mm512_max_ps( zero, acc_11 );
|
||||
|
||||
// c[1,32-47]
|
||||
acc_12 = _mm512_max_ps( zero, acc_12 );
|
||||
|
||||
// c[1,48-63]
|
||||
acc_13 = _mm512_max_ps( zero, acc_13 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -5193,13 +5595,28 @@ POST_OPS_GELU_TANH_2x64:
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_10, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_11, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_12, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_13, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -5208,13 +5625,28 @@ POST_OPS_GELU_ERF_2x64:
|
||||
{
|
||||
__m512 y, r, r2;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
|
||||
|
||||
// c[1,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_10, y, r, r2)
|
||||
|
||||
// c[1,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_11, y, r, r2)
|
||||
|
||||
// c[1,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_12, y, r, r2)
|
||||
|
||||
// c[1,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_13, y, r, r2)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -6121,9 +6553,16 @@ POST_OPS_BIAS_1x64:
|
||||
post_ops_attr.post_op_c_j + ( 3 * 16 ) );
|
||||
}
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_add_ps( b0, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_add_ps( b1, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_add_ps( b2, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_add_ps( b3, acc_03 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -6132,9 +6571,16 @@ POST_OPS_RELU_1x64:
|
||||
{
|
||||
__m512 zero = _mm512_setzero_ps();
|
||||
|
||||
// c[0,0-15]
|
||||
acc_00 = _mm512_max_ps( zero, acc_00 );
|
||||
|
||||
// c[0,16-31]
|
||||
acc_01 = _mm512_max_ps( zero, acc_01 );
|
||||
|
||||
// c[0,32-47]
|
||||
acc_02 = _mm512_max_ps( zero, acc_02 );
|
||||
|
||||
// c[0,48-63]
|
||||
acc_03 = _mm512_max_ps( zero, acc_03 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -6177,9 +6623,16 @@ POST_OPS_GELU_TANH_1x64:
|
||||
__m512 dn, z, x, r2, r, y;
|
||||
__m512i tmpout;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_00, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_01, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_02, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_TANH_F32_AVX512_DEF(acc_03, y, r, r2, x, z, dn, tmpout)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
@@ -6188,9 +6641,16 @@ POST_OPS_GELU_ERF_1x64:
|
||||
{
|
||||
__m512 y, r, r2;
|
||||
|
||||
// c[0,0-15]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_00, y, r, r2)
|
||||
|
||||
// c[0,16-31]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_01, y, r, r2)
|
||||
|
||||
// c[0,32-47]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_02, y, r, r2)
|
||||
|
||||
// c[0,48-63]
|
||||
GELU_ERF_F32_AVX512_DEF(acc_03, y, r, r2)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
|
||||
Reference in New Issue
Block a user