mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Implemented group level static quantization for s8s8s32of32|bf16 APIs
Details: - Group quantization is a technique to improve accuracy in which the scale factors used to quantize inputs and weights vary at the group level instead of at the per-channel or per-tensor level. - Added new bench files to test GEMM with symmetric static quantization. - Added new get_size and reorder functions to account for storing the sum of column values separately per group. - Added a new framework and kernels to support the same. - The scale factors can be of type float or bf16. AMD-Internal:[SWLCSG-3274] Change-Id: I3e69ecd56faa2679a4f084031d35ffb76556230f
This commit is contained in:
2677
kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_grp_amd512vnni.c
Normal file
2677
kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_grp_amd512vnni.c
Normal file
File diff suppressed because it is too large
Load Diff
7609
kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_grp_amd512vnni.c
Normal file
7609
kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_grp_amd512vnni.c
Normal file
File diff suppressed because it is too large
Load Diff
20916
kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_grp_amd512vnni.c
Normal file
20916
kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_grp_amd512vnni.c
Normal file
File diff suppressed because it is too large
Load Diff
6938
kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_grp_amd512vnni.c
Normal file
6938
kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_grp_amd512vnni.c
Normal file
File diff suppressed because it is too large
Load Diff
@@ -153,7 +153,31 @@
|
||||
F32_S32_BETA_OP(c_int32_ ## m_ind ## p2,m_ir,m_ind,2,scratch1,scratch2); \
|
||||
F32_S32_BETA_OP(c_int32_ ## m_ind ## p3,m_ir,m_ind,3,scratch1,scratch2); \
|
||||
|
||||
// Downscale BF16 beta op
|
||||
// Downscale F32 beta op
|
||||
#define F32_F32_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
|
||||
scratch1 = _mm512_loadu_ps( c + ( rs_c * ( m_ir + m_ind ) ) + ( n_ind * 16 ) ); \
|
||||
F32_BETA_FMA(reg,scratch1,scratch2) \
|
||||
|
||||
#define F32_F32_BETA_OP_NLT16F_MASK(ptr, mask, reg, m_ir, m_ind, n_ind, scratch1, scratch2) \
|
||||
scratch1 = _mm512_maskz_loadu_ps( mask, ptr + ( rs_c * ( m_ir + m_ind ) ) + ( n_ind * 16 ) ); \
|
||||
F32_BETA_FMA(reg, scratch1, scratch2) \
|
||||
|
||||
#define F32_F32_BETA_OP4(m_ir,m_ind,scratch1,scratch2) \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2); \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 3,m_ir,m_ind,3,scratch1,scratch2); \
|
||||
|
||||
#define F32_F32_BETA_OP3(m_ir,m_ind,scratch1,scratch2) \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2); \
|
||||
|
||||
#define F32_F32_BETA_OP2(m_ir,m_ind,scratch1,scratch2) \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
|
||||
F32_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
|
||||
|
||||
// Downscale BF16 beta op
|
||||
#define BF16_S32_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
|
||||
scratch1 = \
|
||||
_mm512_cvtps_epi32 \
|
||||
@@ -189,6 +213,60 @@
|
||||
BF16_S32_BETA_OP(c_int32_ ## m_ind ## p2,m_ir,m_ind,2,scratch1,scratch2); \
|
||||
BF16_S32_BETA_OP(c_int32_ ## m_ind ## p3,m_ir,m_ind,3,scratch1,scratch2); \
|
||||
|
||||
#define F32_BETA_FMA(reg,scratch1,scratch2) \
|
||||
scratch1 = _mm512_mul_ps( scratch2, scratch1 ); \
|
||||
reg = _mm512_add_ps( scratch1, reg ); \
|
||||
|
||||
#define BF16_F32_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
|
||||
scratch1 = \
|
||||
(__m512)_mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_loadu_epi16 \
|
||||
( \
|
||||
0xFFFF, \
|
||||
( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
); \
|
||||
F32_BETA_FMA(reg,scratch1,scratch2) \
|
||||
|
||||
#define BF16_F32_BETA_OP_NLT16F_MASK(lmask,reg,m_ind,n_ind,scratch1,scratch2) \
|
||||
scratch1 = (__m512)_mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_loadu_epi16 \
|
||||
( \
|
||||
lmask, \
|
||||
( bfloat16* )post_ops_attr.buf_downscale + \
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
); \
|
||||
F32_BETA_FMA(reg,scratch1,scratch2) \
|
||||
|
||||
|
||||
#define BF16_F32_BETA_OP4(m_ir,m_ind,scratch1,scratch2) \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2); \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 3,m_ir,m_ind,3,scratch1,scratch2); \
|
||||
|
||||
#define BF16_F32_BETA_OP3(m_ir,m_ind,scratch1,scratch2) \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2); \
|
||||
|
||||
#define BF16_F32_BETA_OP2(m_ir,m_ind,scratch1,scratch2) \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
|
||||
BF16_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
|
||||
|
||||
|
||||
// Default n < 16 beta macro
|
||||
#define S32_S32_BETA_OP_NLT16F(reg,buf_,scratch1,scratch2) \
|
||||
scratch1 = _mm512_loadu_si512( buf_ ); \
|
||||
@@ -307,8 +385,82 @@
|
||||
) \
|
||||
); \
|
||||
|
||||
// s8 SYMM static quantization scale helper macros.
|
||||
#define SYM_QUANT_BF16_F32_SCL_LOAD(src,ptr,mask,n_ind) \
|
||||
src = (__m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_maskz_loadu_epi16 \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( bfloat16* )ptr ) + ( n_ind * 16 ) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// BF16 bias helper macros.
|
||||
#define SYM_QUANT_F32_F32_SCL_LOAD(src,ptr,mask,n_ind) \
|
||||
src = ( _mm512_maskz_loadu_ps \
|
||||
( \
|
||||
( mask ), \
|
||||
( ( float* )ptr ) + ( n_ind * 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// s8 SYMM static quantization scale helper macros.
|
||||
#define SYM_QUANT_BF16_F32_SCL_BCST(src,ptr,m_ind) \
|
||||
src = (__m512)( _mm512_sllv_epi32 \
|
||||
( \
|
||||
_mm512_cvtepi16_epi32 \
|
||||
( \
|
||||
_mm256_set1_epi16 \
|
||||
( \
|
||||
*(( ( bfloat16* )ptr ) + ( m_ind * grp_post_ops_attr.grp_post_op_lda )) \
|
||||
) \
|
||||
), _mm512_set1_epi32( 16 ) \
|
||||
) \
|
||||
); \
|
||||
|
||||
// s8 SYMM static quantization scale helper macros.
|
||||
#define SYM_QUANT_F32_F32_SCL_BCST(src,ptr,m_ind) \
|
||||
src = _mm512_set1_ps \
|
||||
( \
|
||||
*(( ( float* )ptr ) + ( m_ind * grp_post_ops_attr.grp_post_op_lda )) \
|
||||
); \
|
||||
|
||||
#define CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, n_ind) \
|
||||
flt_reg_pfx ## m_ind ## n_ind = _mm512_add_ps \
|
||||
( flt_reg_pfx ## m_ind ## n_ind, \
|
||||
_mm512_mul_ps( \
|
||||
_mm512_mul_ps \
|
||||
( \
|
||||
_mm512_cvtepi32_ps( \
|
||||
int_reg_pfx ## m_ind ## p ## n_ind ), \
|
||||
b_scl ## n_ind ), \
|
||||
a_scl ## a_scl_ind ) \
|
||||
); \
|
||||
|
||||
|
||||
#define CVT_ACCUM_REG_APPLY_SCALES_4COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 1) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 2) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 3) \
|
||||
|
||||
#define CVT_ACCUM_REG_APPLY_SCALES_3COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 1) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 2) \
|
||||
|
||||
#define CVT_ACCUM_REG_APPLY_SCALES_2COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 1) \
|
||||
|
||||
#define CVT_ACCUM_REG_APPLY_SCALES_1COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
|
||||
CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
|
||||
|
||||
// BF16 bias helper macros.
|
||||
#define BF16_S32_BIAS_LOAD(scr,mask,n_ind) \
|
||||
scr = _mm512_cvtps_epi32 \
|
||||
( \
|
||||
|
||||
Reference in New Issue
Block a user