Implemented a new set of kernels for f32 using 32 YMM regs

Details:
- These kernels are picked from the cntx when GEMM is invoked
  on machines that support AVX512 instructions and the AVX2 path
  is forced at run-time by setting AOCL_ENABLE_INSTRUCTIONS=AVX2.
- This path uses the same block sizes and pack kernels as the
  AVX512 path.
- GEMV is currently disabled, as AVX2 kernels for GEMV are not
  implemented yet.

AMD-Internal: [SWLCSG-3519]
Change-Id: I75401fac48478fe99edb8e71fa44d36dd7513ae5
This commit is contained in:
Meghana Vankadari
2025-04-01 09:10:51 +00:00
parent 48c7452b08
commit 4745cf876e
11 changed files with 4722 additions and 192 deletions

View File

@@ -148,7 +148,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
rs_b = 1;
cs_b = ldb;
}
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
@@ -168,7 +168,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32)
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ( ( is_column_major == TRUE ) && ( ( mtag_b == REORDERED ) || (mtag_a == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.",
bli_print_msg(" Reordering of column major matrices is not supported.",
__FILE__, __LINE__ );
goto err_hndl;
}

View File

@@ -57,7 +57,7 @@
XMACRO(F32OBF16, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
#define LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN \
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 64, 1, 6, 64, 1) \
// The STMACRO follows the format MT, NT, KT which are SUP switch thresholds.
// ID = One of the AOCL_OPERATION_TYPE enum.

View File

@@ -239,8 +239,6 @@ static void _lpgemm_cntx_init_func_map()
}
#endif
#endif
// If arch is updated at runtime, it is expected to be honoured.
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
@@ -248,6 +246,8 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
}
#endif
}
else if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
@@ -256,7 +256,6 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI
#endif
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
@@ -264,6 +263,7 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
}
#endif
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{

View File

@@ -53,7 +53,7 @@
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_256_6x64m) \
KMACRO(BF16BF16F32OF32, NULL) \
KMACRO(BF16S4F32OF32, NULL) \
@@ -76,7 +76,7 @@
PBMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI_BF16 \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
@@ -110,7 +110,7 @@
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_256_6x64m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
@@ -134,7 +134,7 @@
PBSMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI \
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \
@@ -151,7 +151,7 @@
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_256_6x64m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
@@ -173,7 +173,7 @@
PBSMACRO(BF16S4F32OF32, NULL)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512 \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)

View File

@@ -318,6 +318,7 @@ LPGEMV(float, float, float, f32f32f32of32)
LPGEMM_5LOOP(float, float, float, f32f32f32of32)
{
#ifdef BLIS_KERNELS_ZEN4
// Handle using LPGEMV when m and/or n is equal to 1.
// The avx512 check will be removed when avx2 kernels are added in the future.
@@ -424,8 +425,9 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
// Avoid packing of B in transb cases where rd kernels perform
// better than rv + pack. The rv kernel calls rd when rs_b==1.
bool invoke_rd = FALSE;
if( ( ( n < 48 ) || ( m < 16 ) ) &&
( rs_b == 1 ) && ( mtag_b == PACK ) &&
if( ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) &&
( ( n < 48 ) || ( m < 16 ) ) && ( rs_b == 1 ) && ( mtag_b == PACK ) &&
( mtag_a == UNPACKED ) )
{
invoke_rd = TRUE;

View File

@@ -94,6 +94,7 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x8m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x4m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x2m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x1m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_256_6x64m);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x48m_rd);
@@ -267,6 +268,12 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_5x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_4x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_3x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_2x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_256_1x32);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64);

View File

@@ -720,36 +720,36 @@ POST_OPS_MATRIX_ADD_5x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
// c[4:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
// c[4:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
}
}
else
@@ -760,36 +760,36 @@ POST_OPS_MATRIX_ADD_5x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -903,36 +903,36 @@ POST_OPS_MATRIX_MUL_5x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -1676,30 +1676,30 @@ POST_OPS_MATRIX_ADD_4x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
}
}
else
@@ -1710,30 +1710,30 @@ POST_OPS_MATRIX_ADD_4x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -1835,30 +1835,30 @@ POST_OPS_MATRIX_MUL_4x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -2495,24 +2495,24 @@ POST_OPS_MATRIX_ADD_3x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
}
}
else
@@ -2523,24 +2523,24 @@ POST_OPS_MATRIX_ADD_3x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -2630,24 +2630,24 @@ POST_OPS_MATRIX_MUL_3x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3172,18 +3172,18 @@ POST_OPS_MATRIX_ADD_2x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
}
}
else
@@ -3194,18 +3194,18 @@ POST_OPS_MATRIX_ADD_2x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3283,18 +3283,18 @@ POST_OPS_MATRIX_MUL_2x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3723,12 +3723,12 @@ POST_OPS_MATRIX_ADD_1x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
}
}
else
@@ -3739,12 +3739,12 @@ POST_OPS_MATRIX_ADD_1x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -3813,12 +3813,12 @@ POST_OPS_MATRIX_MUL_1x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -4399,36 +4399,36 @@ POST_OPS_MATRIX_ADD_5x8F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,1,6);
// c[2:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,2,8);
// c[3:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,3,10);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,3,10);
// c[4:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,4,12);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,4,12);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr2,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr2,1,6);
// c[2:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr3,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr3,2,8);
// c[3:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr4,3,10);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr4,3,10);
// c[4:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr5,4,12);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr5,4,12);
}
}
else
@@ -5163,30 +5163,30 @@ POST_OPS_MATRIX_ADD_4x8F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,1,6);
// c[2:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,2,8);
// c[3:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,3,10);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,3,10);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr2,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr2,1,6);
// c[2:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr3,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr3,2,8);
// c[3:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr4,3,10);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr4,3,10);
}
}
else
@@ -5829,24 +5829,24 @@ POST_OPS_MATRIX_ADD_3x8F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,1,6);
// c[2:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,2,8);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr2,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr2,1,6);
// c[2:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr3,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr3,2,8);
}
}
else
@@ -6398,18 +6398,18 @@ POST_OPS_MATRIX_ADD_2x8F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,1,6);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr2,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr2,1,6);
}
}
else
@@ -6866,12 +6866,12 @@ POST_OPS_MATRIX_ADD_1x8F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
}
}
else

View File

@@ -266,6 +266,12 @@ multiply with Beta, and add to alpha*A*B*/
ymm ## r_ind0 = _mm256_add_ps( scr0, ymm ## r_ind0 ); \
ymm ## r_ind1 = _mm256_add_ps( scr1, ymm ## r_ind1 ); \
#define F32_MATRIX_ADD_4COL_YMM(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
ymm ## r_ind0 = _mm256_add_ps( scr0, ymm ## r_ind0 ); \
ymm ## r_ind1 = _mm256_add_ps( scr1, ymm ## r_ind1 ); \
ymm ## r_ind2 = _mm256_add_ps( scr2, ymm ## r_ind2 ); \
ymm ## r_ind3 = _mm256_add_ps( scr3, ymm ## r_ind3 ); \
#define F32_F32_MATRIX_ADD_LOAD_XMM_1ELE(scr,scl_fct,m_ind,n_ind) \
scr = ( __m128 )_mm_load_ss \
( \
@@ -317,11 +323,18 @@ multiply with Beta, and add to alpha*A*B*/
#ifdef F32_F32_MATRIX_ADD_2COL
#undef F32_F32_MATRIX_ADD_2COL
#endif
#define F32_F32_MATRIX_ADD_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
#define F32_F32_MATRIX_ADD_2COL_YMM(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
F32_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
F32_MATRIX_ADD_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define F32_F32_MATRIX_ADD_4COL_YMM(scr0,scr1,scr2,scr3,scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
F32_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
F32_F32_MATRIX_ADD_LOAD_YMM(scr2,scl_fct2,m_ind,2); \
F32_F32_MATRIX_ADD_LOAD_YMM(scr3,scl_fct3,m_ind,3); \
F32_MATRIX_ADD_4COL_YMM(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1, r_ind2, r_ind3); \
//Matrix-Add helpers for BF16 input.
#define BF16_F32_MATRIX_ADD_LOAD_YMM(scr,scl_fct,m_ind,n_ind) \
scr = (__m256)( _mm256_sllv_epi32 \
@@ -338,15 +351,12 @@ multiply with Beta, and add to alpha*A*B*/
); \
scr = _mm256_mul_ps( scr, scl_fct ); \
#ifdef BF16_F32_MATRIX_ADD_2COL
#undef BF16_F32_MATRIX_ADD_2COL
#endif
#define BF16_F32_MATRIX_ADD_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
#define BF16_F32_MATRIX_ADD_2COL_YMM(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
BF16_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
BF16_F32_MATRIX_ADD_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
F32_MATRIX_ADD_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define BF16_F32_MATRIX_ADD_1COL(scr0,scl_fct0,m_ind,r_ind0) \
#define BF16_F32_MATRIX_ADD_1COL_YMM(scr0,scl_fct0,m_ind,r_ind0) \
BF16_F32_MATRIX_ADD_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_MATRIX_ADD_1COL_YMM(scr0,m_ind,r_ind0); \
@@ -424,6 +434,12 @@ multiply with Beta, and add to alpha*A*B*/
ymm ## r_ind0 = _mm256_mul_ps( scr0, ymm ## r_ind0 ); \
ymm ## r_ind1 = _mm256_mul_ps( scr1, ymm ## r_ind1 ); \
#define F32_MATRIX_MUL_4COL_YMM(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
ymm ## r_ind0 = _mm256_mul_ps( scr0, ymm ## r_ind0 ); \
ymm ## r_ind1 = _mm256_mul_ps( scr1, ymm ## r_ind1 ); \
ymm ## r_ind2 = _mm256_mul_ps( scr2, ymm ## r_ind2 ); \
ymm ## r_ind3 = _mm256_mul_ps( scr3, ymm ## r_ind3 ); \
#define F32_F32_MATRIX_MUL_LOAD_XMM_1ELE(scr,scl_fct,m_ind,n_ind) \
scr = ( __m128 )_mm_load_ss \
( \
@@ -472,14 +488,18 @@ multiply with Beta, and add to alpha*A*B*/
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_MATRIX_MUL_1COL_YMM(scr0,m_ind,r_ind0); \
#ifdef F32_F32_MATRIX_MUL_2COL
#undef F32_F32_MATRIX_MUL_2COL
#endif
#define F32_F32_MATRIX_MUL_2COL(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
#define F32_F32_MATRIX_MUL_2COL_YMM(scr0,scr1,scl_fct0,scl_fct1,m_ind,r_ind0,r_ind1) \
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
F32_MATRIX_MUL_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1); \
#define F32_F32_MATRIX_MUL_4COL_YMM(scr0,scr1,scr2,scr3,scl_fct0,scl_fct1,scl_fct2,scl_fct3,m_ind,r_ind0,r_ind1,r_ind2,r_ind3) \
F32_F32_MATRIX_MUL_LOAD_YMM(scr0,scl_fct0,m_ind,0); \
F32_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
F32_F32_MATRIX_MUL_LOAD_YMM(scr2,scl_fct2,m_ind,2); \
F32_F32_MATRIX_MUL_LOAD_YMM(scr3,scl_fct3,m_ind,3); \
F32_MATRIX_MUL_4COL_YMM(scr0,scr1,scr2,scr3,m_ind,r_ind0,r_ind1, r_ind2,r_ind3); \
//BF16->F32 Matrix Mul Helpers
#define BF16_F32_MATRIX_MUL_LOAD_XMM_1ELE(scr,scl_fct,m_ind,n_ind) \
BF16_F32_MATRIX_ADD_LOAD_XMM_1ELE(scr,scl_fct,m_ind,n_ind) \
@@ -517,6 +537,7 @@ multiply with Beta, and add to alpha*A*B*/
BF16_F32_MATRIX_MUL_LOAD_YMM(scr1,scl_fct1,m_ind,1); \
F32_MATRIX_MUL_2COL_YMM(scr0,scr1,m_ind,r_ind0,r_ind1); \
// TANH
#define TANH_F32S_AVX2(reg, r, r2, x, z, dn, q) \
\

View File

@@ -919,42 +919,42 @@ POST_OPS_MATRIX_ADD_6x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
// c[4:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
// c[5:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,5,14,15);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,5,14,15);
}
else
{
// c[0:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
// c[4:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
// c[5:0-15]
BF16_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr6,scl_fctr6,5,14,15);
BF16_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr6,scl_fctr6,5,14,15);
}
}
else
@@ -965,42 +965,42 @@ POST_OPS_MATRIX_ADD_6x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
// c[5:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,5,14,15);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,5,14,15);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
// c[5:0-15]
F32_F32_MATRIX_ADD_2COL(ymm1,ymm2,scl_fctr6,scl_fctr6,5,14,15);
F32_F32_MATRIX_ADD_2COL_YMM(ymm1,ymm2,scl_fctr6,scl_fctr6,5,14,15);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -1126,42 +1126,42 @@ POST_OPS_MATRIX_MUL_6x16F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,4,12,13);
// c[5:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr2,5,14,15);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr2,5,14,15);
}
else
{
// c[0:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr1,scl_fctr1,0,4,5);
// c[1:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr2,scl_fctr2,1,6,7);
// c[2:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr3,scl_fctr3,2,8,9);
// c[3:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr4,scl_fctr4,3,10,11);
// c[4:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr5,scl_fctr5,4,12,13);
// c[5:0-15]
F32_F32_MATRIX_MUL_2COL(ymm1,ymm2,scl_fctr6,scl_fctr6,5,14,15);
F32_F32_MATRIX_MUL_2COL_YMM(ymm1,ymm2,scl_fctr6,scl_fctr6,5,14,15);
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
@@ -1962,42 +1962,42 @@ POST_OPS_MATRIX_ADD_6x8F:
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// c[0:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,1,6);
// c[2:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,2,8);
// c[3:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,3,10);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,3,10);
// c[4:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,4,12);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,4,12);
// c[5:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,5,14);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,5,14);
}
else
{
// c[0:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr1,0,4);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr1,0,4);
// c[1:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr2,1,6);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr2,1,6);
// c[2:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr3,2,8);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr3,2,8);
// c[3:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr4,3,10);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr4,3,10);
// c[4:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr5,4,12);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr5,4,12);
// c[5:0-7]
BF16_F32_MATRIX_ADD_1COL(ymm1,scl_fctr6,5,14);
BF16_F32_MATRIX_ADD_1COL_YMM(ymm1,scl_fctr6,5,14);
}
}
else

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff