mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Added Post-Ops Support for F32 RD Kernels
- Support for Post-Ops has been added for all F32 RD AVX512 and AVX2 kernels. AMD-Internal: [SWLCSG-3497] Change-Id: Ia2967417303d8278c547957878d93c42c887109e
This commit is contained in:
committed by
Arnav Sharma
parent
a5f11a1540
commit
267aae80ea
@@ -92,19 +92,67 @@
|
||||
reg = _mm_mul_ps(reg, selector); \
|
||||
reg = _mm_add_ps(reg, zero_point); \
|
||||
|
||||
//Zero-out the given YMM accumulator registers
|
||||
#define ZERO_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps(); \
|
||||
ymm3 = _mm256_setzero_ps();
|
||||
/* Zero-out all YMM registers */
|
||||
#define ZERO_YMM_ALL \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps(); \
|
||||
ymm3 = _mm256_setzero_ps(); \
|
||||
ymm4 = _mm256_setzero_ps(); \
|
||||
ymm5 = _mm256_setzero_ps(); \
|
||||
ymm6 = _mm256_setzero_ps(); \
|
||||
ymm7 = _mm256_setzero_ps(); \
|
||||
ymm8 = _mm256_setzero_ps(); \
|
||||
ymm9 = _mm256_setzero_ps(); \
|
||||
ymm10 = _mm256_setzero_ps(); \
|
||||
ymm11 = _mm256_setzero_ps(); \
|
||||
ymm12 = _mm256_setzero_ps(); \
|
||||
ymm13 = _mm256_setzero_ps(); \
|
||||
ymm14 = _mm256_setzero_ps(); \
|
||||
ymm15 = _mm256_setzero_ps();
|
||||
|
||||
//Zero-out the given XMM accumulator registers
|
||||
#define ZERO_ACC_XMM_4_REG(xmm0,xmm1,xmm2,xmm3) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps(); \
|
||||
xmm3 = _mm_setzero_ps();
|
||||
/* Zero-out all XMM registers */
|
||||
#define ZERO_XMM_ALL \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps(); \
|
||||
xmm3 = _mm_setzero_ps(); \
|
||||
xmm4 = _mm_setzero_ps(); \
|
||||
xmm5 = _mm_setzero_ps(); \
|
||||
xmm6 = _mm_setzero_ps(); \
|
||||
xmm7 = _mm_setzero_ps();
|
||||
|
||||
// Zero-out the given YMM accumulator registers
|
||||
#define ZERO_ACC_YMM_4_REG(ymm0, ymm1, ymm2, ymm3) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps(); \
|
||||
ymm3 = _mm256_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_YMM_3_REG(ymm0,ymm1,ymm2) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_YMM_2_REG(ymm0,ymm1) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps();
|
||||
|
||||
// Zero-out the given YMM accumulator registers
|
||||
#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps(); \
|
||||
xmm3 = _mm_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_XMM_3_REG(xmm0,xmm1,xmm2) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_XMM_2_REG(xmm0,xmm1) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps();
|
||||
|
||||
/*Multiply alpha with accumulator registers and store back*/
|
||||
#define ALPHA_MUL_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3,alpha) \
|
||||
@@ -266,6 +314,9 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
F32_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
F32_MATRIX_ADD_1COL_YMM(scr0,m_ind,r_ind0); \
|
||||
|
||||
#ifdef F32_F32_MATRIX_ADD_2COL
|
||||
#undef F32_F32_MATRIX_ADD_2COL
|
||||
#endif
|
||||
#define F32_F32_MATRIX_ADD_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
F32_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
F32_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
|
||||
@@ -287,6 +338,9 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
); \
|
||||
scr = _mm256_mul_ps( scr, scl_fct ); \
|
||||
|
||||
#ifdef BF16_F32_MATRIX_ADD_2COL
|
||||
#undef BF16_F32_MATRIX_ADD_2COL
|
||||
#endif
|
||||
#define BF16_F32_MATRIX_ADD_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
BF16_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
|
||||
@@ -418,6 +472,9 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0); \
|
||||
|
||||
#ifdef F32_F32_MATRIX_MUL_2COL
|
||||
#undef F32_F32_MATRIX_MUL_2COL
|
||||
#endif
|
||||
#define F32_F32_MATRIX_MUL_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
F32_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
|
||||
@@ -452,6 +509,9 @@ multiply with Beta, and add to alpha*A*B*/
|
||||
BF16_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0); \
|
||||
|
||||
#ifdef BF16_F32_MATRIX_MUL_2COL
|
||||
#undef BF16_F32_MATRIX_MUL_2COL
|
||||
#endif
|
||||
#define BF16_F32_MATRIX_MUL_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
|
||||
BF16_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
|
||||
BF16_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -40,6 +40,72 @@
|
||||
#include "../sigmoid_avx512.h"
|
||||
#include "../math_utils_avx512.h"
|
||||
|
||||
/* Zero-out all ZMM registers */
|
||||
#define ZERO_ZMM_ALL \
|
||||
zmm0 = _mm512_setzero_ps(); \
|
||||
zmm1 = _mm512_setzero_ps(); \
|
||||
zmm2 = _mm512_setzero_ps(); \
|
||||
zmm3 = _mm512_setzero_ps(); \
|
||||
zmm4 = _mm512_setzero_ps(); \
|
||||
zmm5 = _mm512_setzero_ps(); \
|
||||
zmm6 = _mm512_setzero_ps(); \
|
||||
zmm7 = _mm512_setzero_ps(); \
|
||||
zmm8 = _mm512_setzero_ps(); \
|
||||
zmm9 = _mm512_setzero_ps(); \
|
||||
zmm10 = _mm512_setzero_ps(); \
|
||||
zmm11 = _mm512_setzero_ps(); \
|
||||
zmm12 = _mm512_setzero_ps(); \
|
||||
zmm13 = _mm512_setzero_ps(); \
|
||||
zmm14 = _mm512_setzero_ps(); \
|
||||
zmm15 = _mm512_setzero_ps(); \
|
||||
zmm16 = _mm512_setzero_ps(); \
|
||||
zmm17 = _mm512_setzero_ps(); \
|
||||
zmm18 = _mm512_setzero_ps(); \
|
||||
zmm19 = _mm512_setzero_ps(); \
|
||||
zmm20 = _mm512_setzero_ps(); \
|
||||
zmm21 = _mm512_setzero_ps(); \
|
||||
zmm22 = _mm512_setzero_ps(); \
|
||||
zmm23 = _mm512_setzero_ps(); \
|
||||
zmm24 = _mm512_setzero_ps(); \
|
||||
zmm25 = _mm512_setzero_ps(); \
|
||||
zmm26 = _mm512_setzero_ps(); \
|
||||
zmm27 = _mm512_setzero_ps(); \
|
||||
zmm28 = _mm512_setzero_ps(); \
|
||||
zmm29 = _mm512_setzero_ps(); \
|
||||
zmm30 = _mm512_setzero_ps(); \
|
||||
zmm31 = _mm512_setzero_ps();
|
||||
|
||||
/* Zero-out all YMM registers */
|
||||
#define ZERO_YMM_ALL \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps(); \
|
||||
ymm3 = _mm256_setzero_ps(); \
|
||||
ymm4 = _mm256_setzero_ps(); \
|
||||
ymm5 = _mm256_setzero_ps(); \
|
||||
ymm6 = _mm256_setzero_ps(); \
|
||||
ymm7 = _mm256_setzero_ps(); \
|
||||
ymm8 = _mm256_setzero_ps(); \
|
||||
ymm9 = _mm256_setzero_ps(); \
|
||||
ymm10 = _mm256_setzero_ps(); \
|
||||
ymm11 = _mm256_setzero_ps(); \
|
||||
ymm12 = _mm256_setzero_ps(); \
|
||||
ymm13 = _mm256_setzero_ps(); \
|
||||
ymm14 = _mm256_setzero_ps(); \
|
||||
ymm15 = _mm256_setzero_ps();
|
||||
|
||||
/* Zero-out all XMM registers */
|
||||
#define ZERO_XMM_ALL \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps(); \
|
||||
xmm3 = _mm_setzero_ps(); \
|
||||
xmm4 = _mm_setzero_ps(); \
|
||||
xmm5 = _mm_setzero_ps(); \
|
||||
xmm6 = _mm_setzero_ps(); \
|
||||
xmm7 = _mm_setzero_ps();
|
||||
|
||||
|
||||
/* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0 */
|
||||
#define RELU_SCALE_OP_F32S_AVX512(reg) \
|
||||
/* Generate indenx of elements <= 0.*/ \
|
||||
@@ -62,26 +128,60 @@
|
||||
\
|
||||
reg = _mm512_min_ps( _mm512_max_ps( reg, min ), max ); \
|
||||
|
||||
//Zero-out the given ZMM accumulator registers
|
||||
#define ZERO_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3) \
|
||||
zmm0 = _mm512_setzero_ps(); \
|
||||
zmm1 = _mm512_setzero_ps(); \
|
||||
zmm2 = _mm512_setzero_ps(); \
|
||||
zmm3 = _mm512_setzero_ps();
|
||||
|
||||
// Zero-out the given ZMM accumulator registers
|
||||
#define ZERO_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3) \
|
||||
zmm0 = _mm512_setzero_ps(); \
|
||||
zmm1 = _mm512_setzero_ps(); \
|
||||
zmm2 = _mm512_setzero_ps(); \
|
||||
zmm3 = _mm512_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_ZMM_3_REG(zmm0,zmm1,zmm2) \
|
||||
zmm0 = _mm512_setzero_ps(); \
|
||||
zmm1 = _mm512_setzero_ps(); \
|
||||
zmm2 = _mm512_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_ZMM_2_REG(zmm0,zmm1) \
|
||||
zmm0 = _mm512_setzero_ps(); \
|
||||
zmm1 = _mm512_setzero_ps();
|
||||
|
||||
// Zero-out the given YMM accumulator registers
|
||||
#define ZERO_ACC_YMM_4_REG(ymm0, ymm1, ymm2, ymm3) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps(); \
|
||||
ymm3 = _mm256_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_YMM_3_REG(ymm0,ymm1,ymm2) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps(); \
|
||||
ymm2 = _mm256_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_YMM_2_REG(ymm0,ymm1) \
|
||||
ymm0 = _mm256_setzero_ps(); \
|
||||
ymm1 = _mm256_setzero_ps();
|
||||
|
||||
// Zero-out the given YMM accumulator registers
|
||||
#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps(); \
|
||||
xmm3 = _mm_setzero_ps();
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps(); \
|
||||
xmm3 = _mm_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_XMM_3_REG(xmm0,xmm1,xmm2) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps(); \
|
||||
xmm2 = _mm_setzero_ps();
|
||||
|
||||
#define ZERO_ACC_XMM_2_REG(xmm0,xmm1) \
|
||||
xmm0 = _mm_setzero_ps(); \
|
||||
xmm1 = _mm_setzero_ps();
|
||||
|
||||
/*Multiply alpha with accumulator registers and store back*/
|
||||
#define ALPHA_MUL_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3,alpha) \
|
||||
zmm0 = _mm512_mul_ps(zmm0,alpha); \
|
||||
zmm1 = _mm512_mul_ps(zmm1,alpha); \
|
||||
zmm2 = _mm512_mul_ps(zmm2,alpha); \
|
||||
zmm3 = _mm512_mul_ps(zmm3,alpha);
|
||||
zmm0 = _mm512_mul_ps(zmm0,alpha); \
|
||||
zmm1 = _mm512_mul_ps(zmm1,alpha); \
|
||||
zmm2 = _mm512_mul_ps(zmm2,alpha); \
|
||||
zmm3 = _mm512_mul_ps(zmm3,alpha);
|
||||
|
||||
// BF16 bias helper macros.
|
||||
#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user