Added bf16s4f32 kernels to handle m=4 cases

Details:
- In WOQ, if m = 4, special case kernels are added where
  s4->bf16 conversion happens inside the compute kernel and
  packing is avoided. For all other cases, B matrix is
  dequantized and packed at KC loop level and native bf16
  kernels are re-used at compute level.
- Fixes in bench to avoid accuracy failures when datatype of
  output is bf16.

Change-Id: Ie8db42da536891693d5e82a5336b66514a50ccb2
This commit is contained in:
Meghana Vankadari
2024-09-02 10:39:49 +00:00
committed by Nallani Bhaskar
parent 711dce14d0
commit 2e1cc2f14a
8 changed files with 4961 additions and 17 deletions

View File

@@ -92,6 +92,7 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32);
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -177,6 +178,75 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32);
#define LPGEMM_MAIN_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
const dim_t m0, \
const dim_t n0, \
const dim_t k0, \
const A_type* a, \
const dim_t rs_a, \
const dim_t cs_a, \
const dim_t ps_a, \
const B_type* b, \
const dim_t rs_b, \
const dim_t cs_b, \
C_type* c, \
const dim_t rs_c, \
const dim_t cs_c, \
const C_type alpha, \
const C_type beta, \
lpgemm_post_op* post_ops_list, \
lpgemm_post_op_attr post_ops_attr \
) \
LPGEMM_MAIN_KERN1(bfloat16,int8_t,float,bf16s4f32of32_4x64);
#define LPGEMM_M_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
const dim_t k0, \
const A_type* a, \
const dim_t rs_a, \
const dim_t cs_a, \
const B_type* b, \
const dim_t rs_b, \
const dim_t cs_b, \
C_type* c, \
const dim_t rs_c, \
const C_type alpha, \
const C_type beta, \
lpgemm_post_op* post_ops_list, \
lpgemm_post_op_attr post_ops_attr \
) \
LPGEMM_M_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4x48 );
LPGEMM_M_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4x32 );
LPGEMM_M_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4x16 );
#define LPGEMM_N_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
const dim_t k0, \
const A_type* a, \
const dim_t rs_a, \
const dim_t cs_a, \
const B_type* b, \
const dim_t rs_b, \
const dim_t cs_b, \
C_type* c, \
const dim_t rs_c, \
const C_type alpha, \
const C_type beta, \
const dim_t n0_rem, \
lpgemm_post_op* post_ops_list, \
lpgemm_post_op_attr post_ops_attr \
) \
LPGEMM_N_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4xlt16 );
#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \