Added Post-Ops Support for F32 RD Kernels

- Support for Post-Ops has been added for all F32 RD AVX512 and AVX2
  kernels.

AMD-Internal: [SWLCSG-3497]
Change-Id: Ia2967417303d8278c547957878d93c42c887109e
This commit is contained in:
Arnav Sharma
2025-04-07 16:10:15 +05:30
committed by Arnav Sharma
parent a5f11a1540
commit 267aae80ea
7 changed files with 23546 additions and 5642 deletions

View File

@@ -92,19 +92,67 @@
reg = _mm_mul_ps(reg, selector); \
reg = _mm_add_ps(reg, zero_point); \
// Zero-out the four given YMM accumulator registers.
// Expands to four separate assignments of _mm256_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0): under an
// unbraced if/else only the first assignment would be conditional.
// NOTE(review): this macro is defined again further down with the same
// tokens; C only accepts the repeat while both copies stay identical.
#define ZERO_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps(); \
ymm3 = _mm256_setzero_ps();
/* Zero-out all YMM registers.
 * Takes no arguments: assigns _mm256_setzero_ps() to the identifiers ymm0
 * through ymm15, which must already be declared in the scope where the
 * macro is expanded. Expands to 16 separate statements with no
 * do { } while (0) wrapper, so it must not be used as the body of an
 * unbraced if/else. */
#define ZERO_YMM_ALL \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps(); \
ymm3 = _mm256_setzero_ps(); \
ymm4 = _mm256_setzero_ps(); \
ymm5 = _mm256_setzero_ps(); \
ymm6 = _mm256_setzero_ps(); \
ymm7 = _mm256_setzero_ps(); \
ymm8 = _mm256_setzero_ps(); \
ymm9 = _mm256_setzero_ps(); \
ymm10 = _mm256_setzero_ps(); \
ymm11 = _mm256_setzero_ps(); \
ymm12 = _mm256_setzero_ps(); \
ymm13 = _mm256_setzero_ps(); \
ymm14 = _mm256_setzero_ps(); \
ymm15 = _mm256_setzero_ps();
// Zero-out the four given XMM accumulator registers.
// Expands to four separate assignments of _mm_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0).
// NOTE(review): this macro is defined again further down with the same
// tokens; C only accepts the repeat while both copies stay identical.
#define ZERO_ACC_XMM_4_REG(xmm0,xmm1,xmm2,xmm3) \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps(); \
xmm2 = _mm_setzero_ps(); \
xmm3 = _mm_setzero_ps();
/* Zero-out all XMM registers used by the kernels.
 * Takes no arguments: assigns _mm_setzero_ps() to the eight identifiers
 * xmm0 through xmm7, which must already be declared in the scope where
 * the macro is expanded. Expands to eight separate statements with no
 * do { } while (0) wrapper, so it must not be used as the body of an
 * unbraced if/else. */
#define ZERO_XMM_ALL \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps(); \
xmm2 = _mm_setzero_ps(); \
xmm3 = _mm_setzero_ps(); \
xmm4 = _mm_setzero_ps(); \
xmm5 = _mm_setzero_ps(); \
xmm6 = _mm_setzero_ps(); \
xmm7 = _mm_setzero_ps();
// Zero-out the four given YMM accumulator registers.
// Expands to four separate assignments of _mm256_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0): under an
// unbraced if/else only the first assignment would be conditional.
#define ZERO_ACC_YMM_4_REG(ymm0, ymm1, ymm2, ymm3) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps(); \
ymm3 = _mm256_setzero_ps();
// Three-register variant: assigns _mm256_setzero_ps() to each of the three
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_YMM_3_REG(ymm0,ymm1,ymm2) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps();
// Two-register variant: assigns _mm256_setzero_ps() to each of the two
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_YMM_2_REG(ymm0,ymm1) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps();
// Zero-out the four given XMM accumulator registers.
// Expands to four separate assignments of _mm_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0): under an
// unbraced if/else only the first assignment would be conditional.
#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps(); \
xmm2 = _mm_setzero_ps(); \
xmm3 = _mm_setzero_ps();
// Three-register variant: assigns _mm_setzero_ps() to each of the three
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_XMM_3_REG(xmm0,xmm1,xmm2) \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps(); \
xmm2 = _mm_setzero_ps();
// Two-register variant: assigns _mm_setzero_ps() to each of the two
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_XMM_2_REG(xmm0,xmm1) \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps();
/*Multiply alpha with accumulator registers and store back*/
#define ALPHA_MUL_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3,alpha) \
@@ -266,6 +314,9 @@ multiply with Beta, and add to alpha*A*B*/
F32_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_MATRIX_ADD_1COL_YMM(scr0,m_ind,r_ind0); \
#ifdef F32_F32_MATRIX_ADD_2COL
#undef F32_F32_MATRIX_ADD_2COL
#endif
#define F32_F32_MATRIX_ADD_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
F32_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
@@ -287,6 +338,9 @@ multiply with Beta, and add to alpha*A*B*/
); \
scr = _mm256_mul_ps( scr, scl_fct ); \
#ifdef BF16_F32_MATRIX_ADD_2COL
#undef BF16_F32_MATRIX_ADD_2COL
#endif
#define BF16_F32_MATRIX_ADD_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
BF16_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
@@ -418,6 +472,9 @@ multiply with Beta, and add to alpha*A*B*/
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0); \
#ifdef F32_F32_MATRIX_MUL_2COL
#undef F32_F32_MATRIX_MUL_2COL
#endif
#define F32_F32_MATRIX_MUL_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
@@ -452,6 +509,9 @@ multiply with Beta, and add to alpha*A*B*/
BF16_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0); \
#ifdef BF16_F32_MATRIX_MUL_2COL
#undef BF16_F32_MATRIX_MUL_2COL
#endif
#define BF16_F32_MATRIX_MUL_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
BF16_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -40,6 +40,72 @@
#include "../sigmoid_avx512.h"
#include "../math_utils_avx512.h"
/* Zero-out all ZMM registers.
 * Takes no arguments: assigns _mm512_setzero_ps() to the identifiers zmm0
 * through zmm31, which must already be declared in the scope where the
 * macro is expanded. Expands to 32 separate statements with no
 * do { } while (0) wrapper, so it must not be used as the body of an
 * unbraced if/else. */
#define ZERO_ZMM_ALL \
zmm0 = _mm512_setzero_ps(); \
zmm1 = _mm512_setzero_ps(); \
zmm2 = _mm512_setzero_ps(); \
zmm3 = _mm512_setzero_ps(); \
zmm4 = _mm512_setzero_ps(); \
zmm5 = _mm512_setzero_ps(); \
zmm6 = _mm512_setzero_ps(); \
zmm7 = _mm512_setzero_ps(); \
zmm8 = _mm512_setzero_ps(); \
zmm9 = _mm512_setzero_ps(); \
zmm10 = _mm512_setzero_ps(); \
zmm11 = _mm512_setzero_ps(); \
zmm12 = _mm512_setzero_ps(); \
zmm13 = _mm512_setzero_ps(); \
zmm14 = _mm512_setzero_ps(); \
zmm15 = _mm512_setzero_ps(); \
zmm16 = _mm512_setzero_ps(); \
zmm17 = _mm512_setzero_ps(); \
zmm18 = _mm512_setzero_ps(); \
zmm19 = _mm512_setzero_ps(); \
zmm20 = _mm512_setzero_ps(); \
zmm21 = _mm512_setzero_ps(); \
zmm22 = _mm512_setzero_ps(); \
zmm23 = _mm512_setzero_ps(); \
zmm24 = _mm512_setzero_ps(); \
zmm25 = _mm512_setzero_ps(); \
zmm26 = _mm512_setzero_ps(); \
zmm27 = _mm512_setzero_ps(); \
zmm28 = _mm512_setzero_ps(); \
zmm29 = _mm512_setzero_ps(); \
zmm30 = _mm512_setzero_ps(); \
zmm31 = _mm512_setzero_ps();
/* Zero-out all YMM registers.
 * Takes no arguments: assigns _mm256_setzero_ps() to the identifiers ymm0
 * through ymm15, which must already be declared in the scope where the
 * macro is expanded. Expands to 16 separate statements with no
 * do { } while (0) wrapper, so it must not be used as the body of an
 * unbraced if/else. */
#define ZERO_YMM_ALL \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps(); \
ymm3 = _mm256_setzero_ps(); \
ymm4 = _mm256_setzero_ps(); \
ymm5 = _mm256_setzero_ps(); \
ymm6 = _mm256_setzero_ps(); \
ymm7 = _mm256_setzero_ps(); \
ymm8 = _mm256_setzero_ps(); \
ymm9 = _mm256_setzero_ps(); \
ymm10 = _mm256_setzero_ps(); \
ymm11 = _mm256_setzero_ps(); \
ymm12 = _mm256_setzero_ps(); \
ymm13 = _mm256_setzero_ps(); \
ymm14 = _mm256_setzero_ps(); \
ymm15 = _mm256_setzero_ps();
/* Zero-out all XMM registers used by the kernels.
 * Takes no arguments: assigns _mm_setzero_ps() to the eight identifiers
 * xmm0 through xmm7, which must already be declared in the scope where
 * the macro is expanded. Expands to eight separate statements with no
 * do { } while (0) wrapper, so it must not be used as the body of an
 * unbraced if/else. */
#define ZERO_XMM_ALL \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps(); \
xmm2 = _mm_setzero_ps(); \
xmm3 = _mm_setzero_ps(); \
xmm4 = _mm_setzero_ps(); \
xmm5 = _mm_setzero_ps(); \
xmm6 = _mm_setzero_ps(); \
xmm7 = _mm_setzero_ps();
/* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0 */
#define RELU_SCALE_OP_F32S_AVX512(reg) \
/* Generate index of elements <= 0.*/ \
@@ -62,26 +128,60 @@
\
reg = _mm512_min_ps( _mm512_max_ps( reg, min ), max ); \
// Zero-out the four given ZMM accumulator registers.
// Expands to four separate assignments of _mm512_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0).
// NOTE(review): this macro is defined again immediately below with the
// same tokens; C only accepts the repeat while both copies stay identical.
#define ZERO_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3) \
zmm0 = _mm512_setzero_ps(); \
zmm1 = _mm512_setzero_ps(); \
zmm2 = _mm512_setzero_ps(); \
zmm3 = _mm512_setzero_ps();
// Zero-out the four given ZMM accumulator registers.
// Expands to four separate assignments of _mm512_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0): under an
// unbraced if/else only the first assignment would be conditional.
#define ZERO_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3) \
zmm0 = _mm512_setzero_ps(); \
zmm1 = _mm512_setzero_ps(); \
zmm2 = _mm512_setzero_ps(); \
zmm3 = _mm512_setzero_ps();
// Three-register variant: assigns _mm512_setzero_ps() to each of the three
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_ZMM_3_REG(zmm0,zmm1,zmm2) \
zmm0 = _mm512_setzero_ps(); \
zmm1 = _mm512_setzero_ps(); \
zmm2 = _mm512_setzero_ps();
// Two-register variant: assigns _mm512_setzero_ps() to each of the two
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_ZMM_2_REG(zmm0,zmm1) \
zmm0 = _mm512_setzero_ps(); \
zmm1 = _mm512_setzero_ps();
// Zero-out the four given YMM accumulator registers.
// Expands to four separate assignments of _mm256_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0): under an
// unbraced if/else only the first assignment would be conditional.
#define ZERO_ACC_YMM_4_REG(ymm0, ymm1, ymm2, ymm3) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps(); \
ymm3 = _mm256_setzero_ps();
// Three-register variant: assigns _mm256_setzero_ps() to each of the three
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_YMM_3_REG(ymm0,ymm1,ymm2) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps(); \
ymm2 = _mm256_setzero_ps();
// Two-register variant: assigns _mm256_setzero_ps() to each of the two
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_YMM_2_REG(ymm0,ymm1) \
ymm0 = _mm256_setzero_ps(); \
ymm1 = _mm256_setzero_ps();
// Zero-out the four given XMM accumulator registers.
// Expands to four separate assignments of _mm_setzero_ps(), one per
// argument, so every argument must be an assignable lvalue. The expansion
// already ends in ';' and is not wrapped in do { } while (0): under an
// unbraced if/else only the first assignment would be conditional.
#define ZERO_ACC_XMM_4_REG(xmm0, xmm1, xmm2, xmm3) \
  xmm0 = _mm_setzero_ps(); \
  xmm1 = _mm_setzero_ps(); \
  xmm2 = _mm_setzero_ps(); \
  xmm3 = _mm_setzero_ps();
// Three-register variant: assigns _mm_setzero_ps() to each of the three
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_XMM_3_REG(xmm0,xmm1,xmm2) \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps(); \
xmm2 = _mm_setzero_ps();
// Two-register variant: assigns _mm_setzero_ps() to each of the two
// given lvalues as separate statements (no do { } while (0) wrapper).
#define ZERO_ACC_XMM_2_REG(xmm0,xmm1) \
xmm0 = _mm_setzero_ps(); \
xmm1 = _mm_setzero_ps();
/* Multiply alpha with accumulator registers and store back.
 * Each of the four register arguments is overwritten with
 * _mm512_mul_ps(reg, alpha), exactly once per register. `alpha` is
 * expanded in all four statements, so pass a side-effect-free expression.
 * The expansion is four separate statements ending in ';' with no
 * do { } while (0) wrapper. */
#define ALPHA_MUL_ACC_ZMM_4_REG(zmm0,zmm1,zmm2,zmm3,alpha) \
  zmm0 = _mm512_mul_ps(zmm0,alpha); \
  zmm1 = _mm512_mul_ps(zmm1,alpha); \
  zmm2 = _mm512_mul_ps(zmm2,alpha); \
  zmm3 = _mm512_mul_ps(zmm3,alpha);
// BF16 bias helper macros.
#define BF16_F32_BIAS_LOAD(scr,mask,n_ind) \

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff