mirror of
https://github.com/amd/blis.git
synced 2026-05-04 22:41:11 +00:00
Added bf16s4f32 kernels to handle m=4 cases
Details: - In WOQ, if m = 4, special case kernels are added where s4->bf16 conversion happens inside the compute kernel and packing is avoided. For all other cases, B matrix is dequantized and packed at KC loop level and native bf16 kernels are re-used at compute level. - Fixes in bench to avoid accuracy failures when datatype of output is bf16. Change-Id: Ie8db42da536891693d5e82a5336b66514a50ccb2
This commit is contained in:
committed by
Nallani Bhaskar
parent
711dce14d0
commit
2e1cc2f14a
@@ -92,6 +92,7 @@ LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
|
||||
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
|
||||
LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32);
|
||||
|
||||
|
||||
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -177,6 +178,75 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32);
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32);
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32);
|
||||
|
||||
#define LPGEMM_MAIN_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t m0, \
|
||||
const dim_t n0, \
|
||||
const dim_t k0, \
|
||||
const A_type* a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const dim_t ps_a, \
|
||||
const B_type* b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
C_type* c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
lpgemm_post_op* post_ops_list, \
|
||||
lpgemm_post_op_attr post_ops_attr \
|
||||
) \
|
||||
|
||||
LPGEMM_MAIN_KERN1(bfloat16,int8_t,float,bf16s4f32of32_4x64);
|
||||
|
||||
|
||||
#define LPGEMM_M_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t k0, \
|
||||
const A_type* a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const B_type* b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
C_type* c, \
|
||||
const dim_t rs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
lpgemm_post_op* post_ops_list, \
|
||||
lpgemm_post_op_attr post_ops_attr \
|
||||
) \
|
||||
|
||||
LPGEMM_M_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4x48 );
|
||||
LPGEMM_M_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4x32 );
|
||||
LPGEMM_M_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4x16 );
|
||||
|
||||
#define LPGEMM_N_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t k0, \
|
||||
const A_type* a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const B_type* b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
C_type* c, \
|
||||
const dim_t rs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
const dim_t n0_rem, \
|
||||
lpgemm_post_op* post_ops_list, \
|
||||
lpgemm_post_op_attr post_ops_attr \
|
||||
) \
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN1( bfloat16, int8_t, float, bf16s4f32of32_4xlt16 );
|
||||
|
||||
|
||||
#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
|
||||
Reference in New Issue
Block a user