mirror of
https://github.com/amd/blis.git
synced 2026-05-04 14:31:12 +00:00
Added low precision GEMM - bf16bf16f32of32
Feature addition: added a new BFloat16 variant of low-precision GEMM to the addon. The kernel takes bf16-typed inputs and performs BF16 GEMM operations; the intermediate accumulation and the output are in float. 1. Compute kernels perform computations only if the B matrix is reordered in accordance with the usage of the AVX-512 BF16 instruction dpbf16_ps. 2. A kernel for packing the B matrix is provided. Change-Id: If5d08213068869eff060c9998596d2d2703a6793
This commit is contained in:
@@ -334,6 +334,58 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading
|
||||
}
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
*n_threads = bli_rntm_num_threads( rntm_g );
|
||||
*jc_ways = bli_rntm_jc_ways( rntm_g );
|
||||
*ic_ways = bli_rntm_ic_ways( rntm_g );
|
||||
|
||||
if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) )
|
||||
{
|
||||
// If BLIS_IC_NT or JC_NT are set.
|
||||
// Default cases.
|
||||
*ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1;
|
||||
*jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1;
|
||||
|
||||
*n_threads = ( *jc_ways ) * ( *ic_ways );
|
||||
}
|
||||
else if ( ( *n_threads ) > 1 )
|
||||
{
|
||||
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 );
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
// If n is less than micro panel dimension, allocating all threads
|
||||
// to ic resulted in gains.
|
||||
( *ic_ways ) = ( *n_threads );
|
||||
( *jc_ways ) = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Setting all the values to 1 in case n_threads <= 1. This ensures
|
||||
// the threading parameters are valid.
|
||||
*n_threads = 1;
|
||||
*jc_ways = 1;
|
||||
*ic_ways = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Some aspects of sgemm smart threading incorporated here. Eventually this
|
||||
// will be redirected to the sgemm smart threading API.
|
||||
BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
@@ -496,6 +548,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
|
||||
// One GEN_LPGEMM_OPENMP_DECORATOR instantiation per LPGEMM variant; the last
// argument is the suffix pasted into lpgemm_<suffix>_openmp_thread_decorator.
// NOTE(review): the macro body is not visible here - the first three
// arguments look like the A, B and C element types; confirm against the
// macro definition.
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
|
||||
|
||||
#else
|
||||
@@ -564,6 +617,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
|
||||
// One GEN_LPGEMM_DECORATOR instantiation per LPGEMM variant, in the #else
// branch that follows the OpenMP instantiations; the last argument is the
// suffix pasted into lpgemm_<suffix>_thread_decorator.
// NOTE(review): the macro body is not visible here - the first three
// arguments look like the A, B and C element types; confirm against the
// macro definition.
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "aocl_bf16_type.h"
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
|
||||
@@ -64,6 +65,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
|
||||
// One GEN_LPGEMM_OPENMP_DECORATOR_FN instantiation per LPGEMM variant, under
// BLIS_ENABLE_OPENMP; the last argument is the suffix pasted into
// lpgemm_<suffix>_openmp_thread_decorator.
// NOTE(review): presumably these expand to the prototypes matching the
// GEN_LPGEMM_OPENMP_DECORATOR definitions, with the first three arguments
// being the A, B and C element types - confirm against the macro definition.
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32)
|
||||
|
||||
#else
|
||||
@@ -92,6 +94,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
|
||||
// One GEN_LPGEMM_DECORATOR_FN instantiation per LPGEMM variant for the
// non-OpenMP build; the last argument is the suffix pasted into
// lpgemm_<suffix>_thread_decorator.
// The third type argument for u8s8s16o16 is int16_t so that it agrees with
// every other u8s8s16o16 instantiation (GEN_LPGEMM_DECORATOR,
// GEN_LPGEMM_OPENMP_DECORATOR and GEN_LPGEMM_OPENMP_DECORATOR_FN all pass
// int16_t); it was previously int32_t here only.
// NOTE(review): the macro body is not visible here - the first three
// arguments look like the A, B and C element types; confirm against the
// macro definition.
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32)
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user