Added AVX512 and AVX2 FP32 RD Kernels

- Added FP32 RD (dot-product) kernels for both, AVX512 and AVX2 ISAs.
- The FP32 AVX512 primary RD kernel has blocking of dimensions 6x64
  (MRxNR) whereas it is 6x16 (MRxNR) for the AVX2 primary RD kernel.
- Updatd f32 framework to accomodate rd kernels in case of B trans
  with thresholds
- Updated data gen python script
TODO:
    - Post-Ops not yet supported.

Change-Id: Ibf282741f58a1446321273d5b8044db993f23714
This commit is contained in:
Arnav Sharma
2025-03-28 12:18:53 +05:30
committed by Nallani Bhaskar
parent e0b86c69af
commit c68c258fad
9 changed files with 9863 additions and 10 deletions

View File

@@ -399,8 +399,7 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
if ( c_downscale < F32 )
{
post_ops_attr.buf_downscale = c;
}
else
}else
{
post_ops_attr.buf_downscale = NULL;
}
@@ -415,6 +414,21 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
dim_t jc_start, jc_end;
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
// Compute the IC loop thread range for the current thread.
dim_t ic_start, ic_end;
bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );
// Update the kernel pointer with right kernel
lpgemm_rowvar_f32 ker_ptr = (lpgemm_rowvar_f32) lcntx->kern_fun_ptr;
// Avoid packing of B in transb cases where rd kernels performs better
// than rv + pack. rv kernel calls rd when rs_b==1.
if( (n < 64) && (rs_b == 1) &&
(mtag_b == PACK) && (mtag_a == UNPACKED))
{
mtag_b = UNPACKED;
}
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
{
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
@@ -545,9 +559,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
ps_b_use = 1;
}
dim_t ic_start, ic_end;
bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );
for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
{
dim_t mc0 = bli_min( ( ic_end - ic ), MC );
@@ -584,7 +595,7 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
(
pack_a_buffer_f32f32f32of32,
( a + ( rs_a * ic ) + ( pc * cs_a) ),
rs_a, cs_a,
rs_a, cs_a,
mc0, kc0,
&rs_a_use, &cs_a_use
);
@@ -611,8 +622,8 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
post_ops_attr.post_op_c_j = ( jc + jr );
post_ops_attr.rs_c_downscale = rs_c_downscale;
// Reordered/unpacked B, reordered/unpacked A.
( ( lpgemm_rowvar_f32 )lcntx->kern_fun_ptr )
// Call the micro-kernel
ker_ptr
(
mc0, nr0, kc0,
( float* )a_use, rs_a_use, cs_a_use, ps_a_use,

View File

@@ -89,7 +89,15 @@ void lpgemm_rowvar_ ## LP_SFX \
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x8m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x4m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x2m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x1m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x48m_rd);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x32m_rd);
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
@@ -143,6 +151,51 @@ void lpgemm_rowvar_ ## LP_SFX \
LPGEMM_MAIN_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6x64m_sym_quant);
#define LPGEMM_M_RD_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
const dim_t k0, \
const A_type* a, \
const dim_t rs_a, \
const dim_t cs_a, \
const B_type* b, \
const dim_t rs_b, \
const dim_t cs_b, \
C_type* c, \
const dim_t rs_c, \
const dim_t cs_c, \
const C_type alpha, \
const C_type beta, \
lpgemm_post_op* post_ops_list, \
lpgemm_post_op_attr post_ops_attr \
) \
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x16_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x16_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x8_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x8_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x4_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x4_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x2_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x1_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x2_rd);
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x1_rd);
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -256,7 +309,7 @@ void lpgemm_rowvar_ ## LP_SFX \
const B_type* b, \
const dim_t rs_b, \
const dim_t cs_b, \
float* c, \
float* c, \
const dim_t rs_c, \
const C_type alpha, \
const C_type beta, \