mirror of
https://github.com/amd/blis.git
synced 2026-05-05 15:01:13 +00:00
Added low precision GEMM - bf16bf16f32of32
Feature Addition: Added a new variant of low precision GEMM to addon - BFloat16. The kernel takes bf16 type inputs and perform BF16 GEMM operations. The intermediate accumulation and output are in float. 1. Compute kernels will perform computations only if B matrix is reordered in accordance with the usage of AVX-512 BF16 instruction - dpbf16_ps 2. Kernel for packing B matrix is provided Change-Id: If5d08213068869eff060c9998596d2d2703a6793
This commit is contained in:
@@ -36,6 +36,7 @@
|
||||
#define BLIS_LPGEMM_KERN_H
|
||||
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "aocl_bf16_type.h"
|
||||
|
||||
#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
@@ -63,6 +64,7 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
|
||||
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
|
||||
LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32);
|
||||
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
|
||||
|
||||
#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
@@ -94,6 +96,12 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32);
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32);
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32);
|
||||
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64);
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64);
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64);
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64);
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64);
|
||||
|
||||
#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -122,6 +130,10 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48);
|
||||
|
||||
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16);
|
||||
|
||||
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16);
|
||||
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32);
|
||||
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
|
||||
|
||||
#define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -149,6 +161,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16);
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
|
||||
|
||||
#define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -189,6 +203,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16);
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16);
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16);
|
||||
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x32);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x32);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x32);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x32);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x32);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x48);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x48);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48);
|
||||
|
||||
#define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -220,4 +250,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16);
|
||||
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16);
|
||||
|
||||
#endif //BLIS_LPGEMM_KERN_H
|
||||
|
||||
Reference in New Issue
Block a user