mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Added AVX512 and AVX2 FP32 RD Kernels
- Added FP32 RD (dot-product) kernels for both, AVX512 and AVX2 ISAs.
- The FP32 AVX512 primary RD kernel has blocking of dimensions 6x64
(MRxNR) whereas it is 6x16 (MRxNR) for the AVX2 primary RD kernel.
- Updatd f32 framework to accomodate rd kernels in case of B trans
with thresholds
- Updated data gen python script
TODO:
- Post-Ops not yet supported.
Change-Id: Ibf282741f58a1446321273d5b8044db993f23714
This commit is contained in:
committed by
Nallani Bhaskar
parent
e0b86c69af
commit
c68c258fad
@@ -399,8 +399,7 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
if ( c_downscale < F32 )
|
||||
{
|
||||
post_ops_attr.buf_downscale = c;
|
||||
}
|
||||
else
|
||||
}else
|
||||
{
|
||||
post_ops_attr.buf_downscale = NULL;
|
||||
}
|
||||
@@ -415,6 +414,21 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
dim_t jc_start, jc_end;
|
||||
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
|
||||
|
||||
// Compute the IC loop thread range for the current thread.
|
||||
dim_t ic_start, ic_end;
|
||||
bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );
|
||||
|
||||
// Update the kernel pointer with right kernel
|
||||
lpgemm_rowvar_f32 ker_ptr = (lpgemm_rowvar_f32) lcntx->kern_fun_ptr;
|
||||
|
||||
// Avoid packing of B in transb cases where rd kernels performs better
|
||||
// than rv + pack. rv kernel calls rd when rs_b==1.
|
||||
if( (n < 64) && (rs_b == 1) &&
|
||||
(mtag_b == PACK) && (mtag_a == UNPACKED))
|
||||
{
|
||||
mtag_b = UNPACKED;
|
||||
}
|
||||
|
||||
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
|
||||
{
|
||||
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
|
||||
@@ -545,9 +559,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
ps_b_use = 1;
|
||||
}
|
||||
|
||||
dim_t ic_start, ic_end;
|
||||
bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );
|
||||
|
||||
for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
|
||||
{
|
||||
dim_t mc0 = bli_min( ( ic_end - ic ), MC );
|
||||
@@ -584,7 +595,7 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
(
|
||||
pack_a_buffer_f32f32f32of32,
|
||||
( a + ( rs_a * ic ) + ( pc * cs_a) ),
|
||||
rs_a, cs_a,
|
||||
rs_a, cs_a,
|
||||
mc0, kc0,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
@@ -611,8 +622,8 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
post_ops_attr.post_op_c_j = ( jc + jr );
|
||||
post_ops_attr.rs_c_downscale = rs_c_downscale;
|
||||
|
||||
// Reordered/unpacked B, reordered/unpacked A.
|
||||
( ( lpgemm_rowvar_f32 )lcntx->kern_fun_ptr )
|
||||
// Call the micro-kernel
|
||||
ker_ptr
|
||||
(
|
||||
mc0, nr0, kc0,
|
||||
( float* )a_use, rs_a_use, cs_a_use, ps_a_use,
|
||||
|
||||
@@ -89,7 +89,15 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
|
||||
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x8m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x4m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x2m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x1m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x48m_rd);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x32m_rd);
|
||||
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
|
||||
|
||||
|
||||
@@ -143,6 +151,51 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
|
||||
LPGEMM_MAIN_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_6x64m_sym_quant);
|
||||
|
||||
#define LPGEMM_M_RD_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t k0, \
|
||||
const A_type* a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const B_type* b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
C_type* c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
lpgemm_post_op* post_ops_list, \
|
||||
lpgemm_post_op_attr post_ops_attr \
|
||||
) \
|
||||
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x16_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x16_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x8_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x8_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x4_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x4_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x2_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_2x1_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x2_rd);
|
||||
LPGEMM_M_RD_FRINGE_KERN(float,float,float,f32f32f32of32_1x1_rd);
|
||||
|
||||
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -256,7 +309,7 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
const B_type* b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
float* c, \
|
||||
float* c, \
|
||||
const dim_t rs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
|
||||
Reference in New Issue
Block a user