Implemented group level static quantization for s8s8s32of32|bf16 APIs

Details:
- Group quantization is a technique to improve accuracy
  in which the scale factors used to quantize inputs and weights
  vary at the group level instead of at the per-channel
  or per-tensor level.
- Added new bench files to test GEMM with symmetric static
  quantization.
- Added new get_size and reorder functions to account for
  storing the sum of column values separately for each group.
- Added new framework code and kernels to support the same.
- The scale factors can be of type float or bf16.

AMD-Internal:[SWLCSG-3274]

Change-Id: I3e69ecd56faa2679a4f084031d35ffb76556230f
This commit is contained in:
Meghana Vankadari
2025-02-19 05:59:07 +00:00
parent 99770558bb
commit 7243a5d521
25 changed files with 41770 additions and 52 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -153,7 +153,31 @@
F32_S32_BETA_OP(c_int32_ ## m_ind ## p2,m_ir,m_ind,2,scratch1,scratch2); \
F32_S32_BETA_OP(c_int32_ ## m_ind ## p3,m_ir,m_ind,3,scratch1,scratch2); \
// Downscale BF16 beta op
// Downscale F32 beta op
// Loads one unmasked vector of floats from
//   c + rs_c * ( m_ir + m_ind ) + n_ind * 16
// into scratch1, then forwards reg, scratch1 and scratch2 to F32_BETA_FMA.
// `c` and `rs_c` are not parameters; they must be visible where the macro
// is expanded.
// The expansion already ends with ';' (supplied by F32_BETA_FMA), and the
// definition has no trailing line continuation, so the line that follows it
// can never be spliced into the macro body.
#define F32_F32_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
    scratch1 = _mm512_loadu_ps( c + ( rs_c * ( ( m_ir ) + ( m_ind ) ) ) + ( ( n_ind ) * 16 ) ); \
    F32_BETA_FMA(reg,scratch1,scratch2)
// Masked variant of the F32 beta op.
// Performs a zero-masking load ( _mm512_maskz_loadu_ps ) with `mask` from
//   ptr + rs_c * ( m_ir + m_ind ) + n_ind * 16
// into scratch1, then forwards reg, scratch1 and scratch2 to F32_BETA_FMA.
// `rs_c` is not a parameter; it must be visible where the macro is expanded.
// The expansion already ends with ';' (supplied by F32_BETA_FMA); the
// definition carries no trailing line continuation.
#define F32_F32_BETA_OP_NLT16F_MASK(ptr, mask, reg, m_ir, m_ind, n_ind, scratch1, scratch2) \
    scratch1 = _mm512_maskz_loadu_ps( ( mask ), ( ptr ) + ( rs_c * ( ( m_ir ) + ( m_ind ) ) ) + ( ( n_ind ) * 16 ) ); \
    F32_BETA_FMA(reg, scratch1, scratch2)
// Applies F32_F32_BETA_OP to the four accumulators acc_<m_ind>0 .. acc_<m_ind>3,
// passing n_ind = 0, 1, 2, 3 respectively.
// m_ind is token-pasted into the accumulator names, so it must be a single
// literal token (e.g. 0), not an expression.
// The definition ends on its last statement; there is no trailing line
// continuation that could pull the following line into the macro.
#define F32_F32_BETA_OP4(m_ir,m_ind,scratch1,scratch2) \
    F32_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
    F32_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
    F32_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2); \
    F32_F32_BETA_OP(acc_ ## m_ind ## 3,m_ir,m_ind,3,scratch1,scratch2);
// Applies F32_F32_BETA_OP to the three accumulators acc_<m_ind>0 .. acc_<m_ind>2,
// passing n_ind = 0, 1, 2 respectively.
// m_ind is token-pasted into the accumulator names, so it must be a single
// literal token, not an expression.
// The definition ends on its last statement (no trailing line continuation).
#define F32_F32_BETA_OP3(m_ir,m_ind,scratch1,scratch2) \
    F32_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
    F32_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
    F32_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2);
// Applies F32_F32_BETA_OP to the two accumulators acc_<m_ind>0 and acc_<m_ind>1,
// passing n_ind = 0 and 1 respectively.
// m_ind is token-pasted into the accumulator names, so it must be a single
// literal token, not an expression.
// The definition ends on its last statement (no trailing line continuation).
#define F32_F32_BETA_OP2(m_ir,m_ind,scratch1,scratch2) \
    F32_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
    F32_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2);
// Downscale BF16 beta op
#define BF16_S32_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
scratch1 = \
_mm512_cvtps_epi32 \
@@ -189,6 +213,60 @@
BF16_S32_BETA_OP(c_int32_ ## m_ind ## p2,m_ir,m_ind,2,scratch1,scratch2); \
BF16_S32_BETA_OP(c_int32_ ## m_ind ## p3,m_ir,m_ind,3,scratch1,scratch2); \
// Multiply-then-accumulate step shared by the F32 beta-op macros:
//   scratch1 = scratch2 * scratch1;
//   reg      = scratch1 + reg;
// reg      - accumulator, updated in place.
// scratch1 - values to be scaled on entry; overwritten with the product.
// scratch2 - multiplier; only read.
// Expands to two complete statements including the final ';', so call sites
// may omit the semicolon. The definition has no trailing line continuation,
// so the line after it is never spliced into the macro body.
#define F32_BETA_FMA(reg,scratch1,scratch2) \
    scratch1 = _mm512_mul_ps( scratch2, scratch1 ); \
    reg = _mm512_add_ps( scratch1, reg );
// BF16 -> F32 beta op.
// Loads 16 bfloat16 elements (load mask 0xFFFF) from
//   ( bfloat16* )post_ops_attr.buf_downscale
//     + post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind )
//     + post_ops_attr.post_op_c_j + n_ind * 16,
// widens each element to 32 bits and shifts it left by 16 so the loaded bits
// occupy the upper half of each 32-bit lane, reinterprets the lanes as float
// into scratch1, then forwards reg, scratch1 and scratch2 to F32_BETA_FMA.
// m_ir is accepted but not used by this macro.
// `post_ops_attr` is not a parameter; it must be visible at the expansion site.
// _mm512_castsi512_ps is used for the integer->float reinterpretation instead
// of a GNU-style vector cast; the bit pattern is unchanged.
// The expansion already ends with ';' (supplied by F32_BETA_FMA); the
// definition carries no trailing line continuation.
#define BF16_F32_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
    scratch1 = _mm512_castsi512_ps \
    ( \
      _mm512_sllv_epi32 \
      ( \
        _mm512_cvtepi16_epi32 \
        ( \
          _mm256_maskz_loadu_epi16 \
          ( \
            0xFFFF, \
            ( bfloat16* )post_ops_attr.buf_downscale + \
            ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + ( m_ind ) ) ) + \
            post_ops_attr.post_op_c_j + ( ( n_ind ) * 16 ) \
          ) \
        ), _mm512_set1_epi32( 16 ) \
      ) \
    ); \
    F32_BETA_FMA(reg,scratch1,scratch2)
// Masked variant of the BF16 -> F32 beta op.
// Performs a zero-masking load ( _mm256_maskz_loadu_epi16 ) with `lmask` of
// bfloat16 elements from
//   ( bfloat16* )post_ops_attr.buf_downscale
//     + post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind )
//     + post_ops_attr.post_op_c_j + n_ind * 16,
// widens each element to 32 bits, shifts it left by 16, reinterprets the
// lanes as float into scratch1, then forwards reg, scratch1 and scratch2 to
// F32_BETA_FMA.
// `post_ops_attr` is not a parameter; it must be visible at the expansion site.
// _mm512_castsi512_ps is used for the integer->float reinterpretation instead
// of a GNU-style vector cast; the bit pattern is unchanged.
// The expansion already ends with ';' (supplied by F32_BETA_FMA); the
// definition carries no trailing line continuation.
#define BF16_F32_BETA_OP_NLT16F_MASK(lmask,reg,m_ind,n_ind,scratch1,scratch2) \
    scratch1 = _mm512_castsi512_ps \
    ( \
      _mm512_sllv_epi32 \
      ( \
        _mm512_cvtepi16_epi32 \
        ( \
          _mm256_maskz_loadu_epi16 \
          ( \
            ( lmask ), \
            ( bfloat16* )post_ops_attr.buf_downscale + \
            ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + ( m_ind ) ) ) + \
            post_ops_attr.post_op_c_j + ( ( n_ind ) * 16 ) \
          ) \
        ), _mm512_set1_epi32( 16 ) \
      ) \
    ); \
    F32_BETA_FMA(reg,scratch1,scratch2)
// Applies BF16_F32_BETA_OP to the four accumulators acc_<m_ind>0 .. acc_<m_ind>3,
// passing n_ind = 0, 1, 2, 3 respectively.
// m_ind is token-pasted into the accumulator names, so it must be a single
// literal token (e.g. 0), not an expression.
// The definition ends on its last statement (no trailing line continuation).
#define BF16_F32_BETA_OP4(m_ir,m_ind,scratch1,scratch2) \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2); \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 3,m_ir,m_ind,3,scratch1,scratch2);
// Applies BF16_F32_BETA_OP to the three accumulators acc_<m_ind>0 .. acc_<m_ind>2,
// passing n_ind = 0, 1, 2 respectively.
// m_ind is token-pasted into the accumulator names, so it must be a single
// literal token, not an expression.
// The definition ends on its last statement (no trailing line continuation).
#define BF16_F32_BETA_OP3(m_ir,m_ind,scratch1,scratch2) \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2); \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 2,m_ir,m_ind,2,scratch1,scratch2);
// Applies BF16_F32_BETA_OP to the two accumulators acc_<m_ind>0 and acc_<m_ind>1,
// passing n_ind = 0 and 1 respectively.
// m_ind is token-pasted into the accumulator names, so it must be a single
// literal token, not an expression.
// The definition ends on its last statement (no trailing line continuation).
#define BF16_F32_BETA_OP2(m_ir,m_ind,scratch1,scratch2) \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 0,m_ir,m_ind,0,scratch1,scratch2); \
    BF16_F32_BETA_OP(acc_ ## m_ind ## 1,m_ir,m_ind,1,scratch1,scratch2);
// Default n < 16 beta macro
#define S32_S32_BETA_OP_NLT16F(reg,buf_,scratch1,scratch2) \
scratch1 = _mm512_loadu_si512( buf_ ); \
@@ -307,8 +385,82 @@
) \
); \
// s8 SYMM static quantization scale helper macros.
// Vector load of bf16 scale factors, widened to f32.
// Performs a zero-masking load with `mask` of bfloat16 elements from
//   ( bfloat16* )ptr + n_ind * 16,
// widens each element to 32 bits, shifts it left by 16 and stores the lanes,
// reinterpreted as float, in src.
// `ptr` is substituted directly behind the ( bfloat16* ) cast, so pass a plain
// pointer name or an already parenthesized expression.
// _mm512_castsi512_ps is used for the integer->float reinterpretation instead
// of a GNU-style vector cast; the bit pattern is unchanged.
// Expands to one complete statement including ';'; the definition carries no
// trailing line continuation.
#define SYM_QUANT_BF16_F32_SCL_LOAD(src,ptr,mask,n_ind) \
    src = _mm512_castsi512_ps \
    ( \
      _mm512_sllv_epi32 \
      ( \
        _mm512_cvtepi16_epi32 \
        ( \
          _mm256_maskz_loadu_epi16 \
          ( \
            ( mask ), \
            ( ( bfloat16* )ptr ) + ( ( n_ind ) * 16 ) \
          ) \
        ), _mm512_set1_epi32( 16 ) \
      ) \
    );
// s8 SYMM static quantization scale helper macros.
// Vector load of f32 scale factors.
// Performs a zero-masking load ( _mm512_maskz_loadu_ps ) with `mask` from
//   ( float* )ptr + n_ind * 16
// and stores the result in src.
// `ptr` is substituted directly behind the ( float* ) cast, so pass a plain
// pointer name or an already parenthesized expression.
// Expands to one complete statement including ';'; the definition carries no
// trailing line continuation.
#define SYM_QUANT_F32_F32_SCL_LOAD(src,ptr,mask,n_ind) \
    src = _mm512_maskz_loadu_ps \
    ( \
      ( mask ), \
      ( ( float* )ptr ) + ( ( n_ind ) * 16 ) \
    );
// s8 SYMM static quantization scale helper macros.
// Broadcast of a single bf16 scale factor, widened to f32.
// Reads the bfloat16 element at
//   ( bfloat16* )ptr + m_ind * grp_post_ops_attr.grp_post_op_lda,
// replicates it across all lanes, widens to 32 bits, shifts left by 16 and
// stores the lanes, reinterpreted as float, in src.
// `grp_post_ops_attr` is not a parameter; it must be visible at the
// expansion site. `ptr` is substituted directly behind the ( bfloat16* )
// cast, so pass a plain pointer name or an already parenthesized expression.
// _mm512_castsi512_ps is used for the integer->float reinterpretation instead
// of a GNU-style vector cast; the bit pattern is unchanged.
// Expands to one complete statement including ';'; the definition carries no
// trailing line continuation.
#define SYM_QUANT_BF16_F32_SCL_BCST(src,ptr,m_ind) \
    src = _mm512_castsi512_ps \
    ( \
      _mm512_sllv_epi32 \
      ( \
        _mm512_cvtepi16_epi32 \
        ( \
          _mm256_set1_epi16 \
          ( \
            *( ( ( bfloat16* )ptr ) + ( ( m_ind ) * grp_post_ops_attr.grp_post_op_lda ) ) \
          ) \
        ), _mm512_set1_epi32( 16 ) \
      ) \
    );
// s8 SYMM static quantization scale helper macros.
// Broadcast of a single f32 scale factor.
// Reads the float at
//   ( float* )ptr + m_ind * grp_post_ops_attr.grp_post_op_lda
// and replicates it across all lanes of src ( _mm512_set1_ps ).
// `grp_post_ops_attr` is not a parameter; it must be visible at the
// expansion site. `ptr` is substituted directly behind the ( float* ) cast,
// so pass a plain pointer name or an already parenthesized expression.
// Expands to one complete statement including ';'; the definition carries no
// trailing line continuation.
#define SYM_QUANT_F32_F32_SCL_BCST(src,ptr,m_ind) \
    src = _mm512_set1_ps \
    ( \
      *( ( ( float* )ptr ) + ( ( m_ind ) * grp_post_ops_attr.grp_post_op_lda ) ) \
    );
// Converts one int32 accumulator to float, applies both scale vectors and
// adds the result to the matching float accumulator:
//   flt_reg_pfx<m_ind><n_ind> +=
//       ( float( int_reg_pfx<m_ind>p<n_ind> ) * b_scl<n_ind> ) * a_scl<a_scl_ind>
// All five arguments are token-pasted into variable names, so each must be a
// single token (a prefix or a literal index), not an expression.
// b_scl<n_ind> and a_scl<a_scl_ind> are not parameters; they must be visible
// at the expansion site.
// Expands to one complete statement including ';'; the definition carries no
// trailing line continuation.
#define CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, n_ind) \
    flt_reg_pfx ## m_ind ## n_ind = _mm512_add_ps \
    ( \
      flt_reg_pfx ## m_ind ## n_ind, \
      _mm512_mul_ps \
      ( \
        _mm512_mul_ps \
        ( \
          _mm512_cvtepi32_ps( int_reg_pfx ## m_ind ## p ## n_ind ), \
          b_scl ## n_ind \
        ), \
        a_scl ## a_scl_ind \
      ) \
    );
// Invokes CVT_ACCUM_REG_APPLY_SCALES_M_N for n_ind = 0, 1, 2, 3 with the
// given register prefixes, a_scl index and m_ind.
// No ';' is added between the invocations because
// CVT_ACCUM_REG_APPLY_SCALES_M_N already expands to a terminated statement.
// The definition ends on its last invocation (no trailing line continuation).
#define CVT_ACCUM_REG_APPLY_SCALES_4COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 1) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 2) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 3)
// Invokes CVT_ACCUM_REG_APPLY_SCALES_M_N for n_ind = 0, 1, 2 with the given
// register prefixes, a_scl index and m_ind.
// No ';' is added between the invocations because
// CVT_ACCUM_REG_APPLY_SCALES_M_N already expands to a terminated statement.
// The definition ends on its last invocation (no trailing line continuation).
#define CVT_ACCUM_REG_APPLY_SCALES_3COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 1) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 2)
// Invokes CVT_ACCUM_REG_APPLY_SCALES_M_N for n_ind = 0 and 1 with the given
// register prefixes, a_scl index and m_ind.
// No ';' is added between the invocations because
// CVT_ACCUM_REG_APPLY_SCALES_M_N already expands to a terminated statement.
// The definition ends on its last invocation (no trailing line continuation).
#define CVT_ACCUM_REG_APPLY_SCALES_2COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 1)
// Invokes CVT_ACCUM_REG_APPLY_SCALES_M_N once, for n_ind = 0, with the given
// register prefixes, a_scl index and m_ind.
// No ';' is added because CVT_ACCUM_REG_APPLY_SCALES_M_N already expands to
// a terminated statement.
// The definition ends on that invocation (no trailing line continuation).
#define CVT_ACCUM_REG_APPLY_SCALES_1COL( flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind ) \
    CVT_ACCUM_REG_APPLY_SCALES_M_N(flt_reg_pfx, int_reg_pfx, a_scl_ind, m_ind, 0)
// BF16 bias helper macros.
#define BF16_S32_BIAS_LOAD(scr,mask,n_ind) \
scr = _mm512_cvtps_epi32 \
( \