Added low precision GEMM - bf16bf16f32of32

Feature Addition: Added a new variant of low precision GEMM to addon - BFloat16. The kernel takes bf16 type inputs and perform BF16 GEMM operations. The intermediate accumulation and output are in float.

1. Compute kernels will perform computations only if B matrix is reordered in accordance with the usage of AVX-512 BF16 instruction - dpbf16_ps
2. Kernel for packing B matrix is provided

Change-Id: If5d08213068869eff060c9998596d2d2703a6793
This commit is contained in:
eashdash
2022-08-17 08:25:30 +00:00
committed by Eashan Dash
parent 219c41ded9
commit 4e3e00fb7e
23 changed files with 7942 additions and 734 deletions

View File

@@ -36,6 +36,7 @@
#define BLIS_LPGEMM_KERN_H
#include "lpgemm_post_ops.h"
#include "aocl_bf16_type.h"
#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
@@ -63,6 +64,7 @@ void lpgemm_rowvar_ ## LP_SFX \
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32);
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
@@ -94,6 +96,12 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64);
#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -122,6 +130,10 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48);
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16);
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16);
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32);
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
#define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -149,6 +161,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
#define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -189,6 +203,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x32);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x32);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x32);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x32);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x32);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x48);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x48);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48);
#define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -220,4 +250,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16);
#endif //BLIS_LPGEMM_KERN_H