Added 6x16 and 6xlt16 main kernels for f32 using AVX512 instructions (#38)

* Implemented 6xlt8 AVX2 kernel for n<8 inputs

* Implemented fringe kernels for 6x16 and 6xlt16 AVX512 kernels for FP32

* Implemented m-fringe kernels for 6xlt8 kernel for AVX2

* Implemented m-fringe kernels for 6xlt8 kernel for AVX2

* Added the deleted kernels and fixed bias bug

AMD-Internal: SWLCSG-3556
This commit is contained in:
Vankadari, Meghana
2025-06-05 15:17:02 +05:30
committed by GitHub
parent 14e46ad83b
commit 37efbd284e
11 changed files with 12500 additions and 425 deletions

View File

@@ -64,6 +64,43 @@ typedef void (*lpgemm_m_fringe_f32_ker_ft)
lpgemm_post_op_attr post_ops_attr
);
typedef void (*lpgemm_n_fringe_f32_ker_ft)
(
const dim_t m0,
const dim_t k0,
const float* a,
const dim_t rs_a,
const dim_t cs_a,
const dim_t ps_a,
const float* b,
const dim_t rs_b,
const dim_t cs_b,
float* c,
const dim_t rs_c,
const float alpha,
const float beta,
lpgemm_post_op* post_ops_list,
lpgemm_post_op_attr post_ops_attr
);
typedef void (*lpgemm_mn_fringe_f32_mask_ker_ft)
(
const dim_t k0,
const float* a,
const dim_t rs_a,
const dim_t cs_a,
const float* b,
const dim_t rs_b,
const dim_t cs_b,
float* c,
const dim_t rs_c,
const float alpha,
const float beta,
const dim_t n0_rem,
lpgemm_post_op* post_ops_list,
lpgemm_post_op_attr post_ops_attr
);
#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -242,6 +279,11 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16);
LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16);
@@ -363,6 +405,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m);
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m);
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x16m);
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m);
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m);
LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m);
@@ -472,6 +515,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6xlt16m);
LPGEMM_N_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_6xlt8m);
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16);
@@ -674,6 +719,18 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_5xlt8);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_4xlt8);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_3xlt8);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_2xlt8);
LPGEMM_MN_LT_NR0_FRINGE_KERN(float,float,float,f32f32f32of32_1xlt8);
#define LPGEMM_MN_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \