Added low precision GEMM - bf16bf16f32of32

Feature Addition: Added a new variant of low precision GEMM to addon - BFloat16. The kernel takes bf16 type inputs and perform BF16 GEMM operations. The intermediate accumulation and output are in float.

1. Compute kernels will perform computations only if B matrix is reordered in accordance with the usage of AVX-512 BF16 instruction - dpbf16_ps
2. Kernel for packing B matrix is provided

Change-Id: If5d08213068869eff060c9998596d2d2703a6793
This commit is contained in:
eashdash
2022-08-17 08:25:30 +00:00
committed by Eashan Dash
parent 219c41ded9
commit 4e3e00fb7e
23 changed files with 7942 additions and 734 deletions

View File

@@ -334,6 +334,58 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading
}
}
BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
(
dim_t* n_threads,
dim_t* ic_ways,
dim_t* jc_ways,
dim_t m,
dim_t n,
dim_t k,
rntm_t* rntm_g
)
{
*n_threads = bli_rntm_num_threads( rntm_g );
*jc_ways = bli_rntm_jc_ways( rntm_g );
*ic_ways = bli_rntm_ic_ways( rntm_g );
if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) )
{
// If BLIS_IC_NT or JC_NT are set.
// Default cases.
*ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1;
*jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1;
*n_threads = ( *jc_ways ) * ( *ic_ways );
}
else if ( ( *n_threads ) > 1 )
{
dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 );
if ( n <= NR )
{
// If n is less than micro panel dimension, allocating all threads
// to ic resulted in gains.
( *ic_ways ) = ( *n_threads );
( *jc_ways ) = 1;
}
else
{
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
}
}
else
{
// Setting all the values to 1 in case n_threads <= 1. This ensures
// the threading parameters are valid.
*n_threads = 1;
*jc_ways = 1;
*ic_ways = 1;
}
}
// Some aspects of sgemm smart threading incorporated here. Eventually this
// will be redirected to the sgemm smart threading API.
BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
@@ -496,6 +548,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
#else
@@ -564,6 +617,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32)
#endif

View File

@@ -37,6 +37,7 @@
#include "lpgemm_types.h"
#include "lpgemm_post_ops.h"
#include "aocl_bf16_type.h"
#ifdef BLIS_ENABLE_OPENMP
@@ -64,6 +65,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32)
#else
@@ -92,6 +94,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s16o16)
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32)
#endif