Implemented f32tobf16 reorder function

Description:
aocl_reorder_f32obf16 function is implemented to
reorder input weight matrix of data type float to
bfloat16.

The reordering is done to match the input requirements
of API aocl_gemm_bf16bf16f32o<f32|bf16>.

The objective of the API is to convert a model/matrix
of type f32 to bf16 and process when machine supports
bf16 FMA instruction _mm512_dpbf16_ps but the model
is still in float

Change-Id: Ib7c743d52d01a1ac09e84ac120577ec9e02f90f5
This commit is contained in:
Nallani Bhaskar
2024-10-28 07:22:35 +00:00
committed by sireesha.sanga
parent 880a971dc5
commit e6b79a4060
10 changed files with 1568 additions and 16 deletions

View File

@@ -210,6 +210,117 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32)
reorderb_nr64_bf16bf16f32of32( &b, &b_reorder, &rntm_g, lcntx_g );
}
AOCL_GEMM_REORDER_MXP(float, bfloat16, f32obf16)
{
trans_t blis_trans;
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(trans, &blis_trans);
if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) ||
(k <= 0) || (n <= 0))
{
return; // Error.
}
inc_t rs_b, cs_b;
if ((order == 'r') || (order == 'R'))
{
if ((bli_is_notrans(blis_trans) && (ldb < n)) ||
(bli_is_trans(blis_trans) && (ldb < k)))
{
return; // Error.
}
else
{
rs_b = bli_is_notrans(blis_trans) ? ldb : 1;
cs_b = bli_is_notrans(blis_trans) ? 1 : ldb;
}
}
else if ((order == 'c') || (order == 'C'))
{
if ((bli_is_notrans(blis_trans) && (ldb < k)) ||
(bli_is_trans(blis_trans) && (ldb < n)))
{
return; // Error.
}
else
{
rs_b = bli_is_notrans(blis_trans) ? 1 : ldb;
cs_b = bli_is_notrans(blis_trans) ? ldb : 1;
}
}
else
{
return; // Error
}
// Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.",
__FILE__, __LINE__);
return; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
if (input_mat_type == A_MATRIX)
{
return; // A reorder not supported.
}
#if (defined(BLIS_KERNELS_ZEN4))
if (n == 1)
{
if (rs_b == 1)
{
for (dim_t k0 = 0; k0 < k; k0++)
{
memcpy(&reorder_buf_addr[k0], (char *)(&input_buf_addr[k0]) + 2, sizeof(bfloat16));
}
}
else
{
for (dim_t k0 = 0; k0 < k; k0++)
{
memcpy(&reorder_buf_addr[k0], (char *)(&input_buf_addr[k0 * rs_b]) + 2, sizeof(bfloat16));
}
}
return;
}
#endif
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(F32OBF16);
// Create dummy b_reorder obj.
lpgemm_obj_t b_reorder;
b_reorder.storage.aligned_buffer = reorder_buf_addr;
// Create dummy original b obj;
lpgemm_obj_t b;
b.storage.aligned_buffer = (void *)input_buf_addr;
b.rs = rs_b;
b.cs = cs_b;
b.width = n;
b.length = k;
reorderb_mxp_nr64_f32obf16(&b, &b_reorder, &rntm_g, lcntx_g);
}
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32)
{

View File

@@ -83,6 +83,21 @@ AOCL_GEMM_REORDER(int8_t,s8s8s16os16);
AOCL_GEMM_REORDER(int8_t,u8s4s32os32);
AOCL_GEMM_REORDER(int8_t, bf16s4f32of32);
#define AOCL_GEMM_REORDER_MXP(A_type,B_type,LP_SFX) \
BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \
( \
const char order, \
const char trans, \
const char mat_type, \
const A_type* input_buf_addr, \
B_type* reorder_buf_addr, \
const dim_t k, \
const dim_t n, \
const dim_t ldb \
) \
AOCL_GEMM_REORDER_MXP(float,bfloat16,f32obf16);
#define AOCL_GEMM_UNREORDER(B_type, LP_SFX) \
BLIS_EXPORT_ADDON void aocl_unreorder_ ## LP_SFX \
( \

View File

@@ -47,6 +47,7 @@
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(F32OBF16, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
#define LPGEMM_BLKSZ_MAP_ZEN \
XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
@@ -57,6 +58,7 @@
XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(F32OBF16, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
#define LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN \
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \

View File

@@ -142,6 +142,7 @@ static void _lpgemm_cntx_init_func_map()
#define KMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR;
#define PAMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packa_fun_ptr = FUNC_PTR;
#define PBMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packb_fun_ptr = FUNC_PTR;
#define PBMXPMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].packb_mxp_fun_ptr = FUNC_PTR;
#define UBMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].unpackb_fun_ptr = FUNC_PTR;
#define PBSMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].packsclb_fun_ptr = FUNC_PTR;
#define JITMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].jit_kernel = FUNC_PTR;
@@ -155,6 +156,7 @@ static void _lpgemm_cntx_init_func_map()
global_cntx_t_list[F32F32F32OF32].kern_fun_ptr = NULL;
global_cntx_t_list[BF16BF16F32OF32].kern_fun_ptr = NULL;
global_cntx_t_list[BF16S4F32OF32].kern_fun_ptr = NULL;
global_cntx_t_list[F32OBF16].kern_fun_ptr = NULL;
// Kernel dispatch object factory.
if ( bli_cpuid_is_avx512bf16_supported() == TRUE )
@@ -163,6 +165,7 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16
LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI_BF16
LPGEMM_UNPACKB_FUNC_MAP_AVX512_VNNI_BF16
LPGEMM_PACKSCLB_FUNC_MAP_AVX512_VNNI_BF16
@@ -210,6 +213,7 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_KERN_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI
#endif
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
@@ -226,9 +230,10 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_PACKB_FUNC_MAP_AVX2
#endif
}
// If built with a config not supporting zen3/zen4/amdzen, error out
// since reference kernels are not available.
if ( global_cntx_t_list[F32F32F32OF32].kern_fun_ptr == NULL )
if (global_cntx_t_list[F32F32F32OF32].kern_fun_ptr == NULL)
{
bli_print_msg( "AOCL_GEMM is not compiled using correct Zen config."
" Compile using zen3/zen4/amdzen config.",
@@ -237,6 +242,7 @@ static void _lpgemm_cntx_init_func_map()
}
#undef PBMACRO
#undef PBMXPMACRO
#undef PAMACRO
#undef KMACRO
}

View File

@@ -65,8 +65,8 @@
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
@@ -78,6 +78,9 @@
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI_BF16 \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
#define LPGEMM_UNPACKB_FUNC_MAP_AVX512_VNNI_BF16 \
UBMACRO(BF16BF16F32OF32, unpackb_nr64_bf16bf16f32of32)
@@ -119,7 +122,10 @@
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
@@ -170,10 +176,10 @@
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
PBMACRO(BF16S4F32OF32, NULL) \
PBSMACRO(BF16S4F32OF32, NULL) \
PBSMACRO(BF16S4F32OF32, NULL)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512 \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512 \
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \

View File

@@ -254,11 +254,13 @@ void unreorderb_nr64_bf16bf16f32of32
}
}
void reorderb_nr64_bf16s4f32of32(
lpgemm_obj_t *b,
lpgemm_obj_t *b_reorder,
rntm_t *rntm,
lpgemm_cntx_t *lcntx)
void reorderb_nr64_bf16s4f32of32
(
lpgemm_obj_t *b,
lpgemm_obj_t *b_reorder,
rntm_t *rntm,
lpgemm_cntx_t *lcntx
)
{
dim_t NC = lcntx->blksz.NC;
dim_t KC = lcntx->blksz.KC;
@@ -377,3 +379,129 @@ void reorderb_nr64_bf16s4f32of32(
b_reorder->cs = cs_b_reorder;
b_reorder->mtag = REORDERED;
}
void reorderb_mxp_nr64_f32obf16
(
lpgemm_obj_t *b,
lpgemm_obj_t *b_reorder,
rntm_t *rntm,
lpgemm_cntx_t *lcntx
)
{
dim_t NC = lcntx->blksz.NC;
dim_t KC = lcntx->blksz.KC;
dim_t NR = lcntx->blksz.NR;
// Extracting the matrix properties from the lpgemm object
dim_t rs_b = b->rs;
dim_t cs_b = b->cs;
dim_t n = b->width;
dim_t k = b->length;
dim_t rs_b_reorder;
dim_t cs_b_reorder;
// k needs to be a multiple of 2 so that it can be used with dpbf
// instruction. Padding is added in cases this condition is not
// satisfied, and therefore the k offset used for packed/reordered
// buffer needs to be updated.
dim_t k_updated = k;
k_updated += (k_updated & 0x1);
dim_t n_threads = bli_rntm_num_threads(rntm);
n_threads = (n_threads > 0) ? n_threads : 1;
#ifdef BLIS_ENABLE_OPENMP
_Pragma("omp parallel num_threads(n_threads)")
{
// Initialise a local thrinfo obj for work split across threads.
thrinfo_t thread_jc;
bli_thrinfo_set_n_way(n_threads, &thread_jc);
bli_thrinfo_set_work_id(omp_get_thread_num(), &thread_jc);
#else
{
// Initialise a local thrinfo obj for work split across threads.
thrinfo_t thread_jc;
bli_thrinfo_set_n_way(1, &thread_jc);
bli_thrinfo_set_work_id(0, &thread_jc);
#endif
// Compute the JC loop thread range for the current thread.
dim_t jc_start, jc_end;
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
{
dim_t nc0 = bli_min((jc_end - jc), NC);
dim_t jc_cur_loop = jc;
dim_t jc_cur_loop_rem = 0;
dim_t n_sub_updated;
get_B_panel_reordered_start_offset_width(
jc, n, NC, 16,
&jc_cur_loop, &jc_cur_loop_rem,
&nc0, &n_sub_updated);
for (dim_t pc = 0; pc < k; pc += KC)
{
dim_t kc0 = bli_min((k - pc), KC);
// k needs to be a multiple of 2 so that it can be used with dpbf
// instruction. Padding is added in cases this condition is not
// satisfied, and therefore the k offset used for packed/reordered
// buffer needs to be updated.
dim_t kc0_updated = kc0;
kc0_updated += (kc0_updated & 0x1);
// The offsets are calculated in such a way that it resembles
// the reorder buffer traversal in single threaded reordering.
// The panel boundaries (KCxNC) remain as it is accessed in
// single thread, and as a consequence a thread with jc_start
// inside the panel cannot consider NC range for reorder. It
// has to work with NC' < NC, and the offset is calulated using
// prev NC panels spanning k dim + cur NC panel spaning pc loop
// cur iteration + (NC - NC') spanning current kc0 (<= KC).
//
// Eg: Consider the following reordered buffer diagram:
// t1 t2
// | |
// | |..NC..|
// | | |
// |.NC. |.NC. |NC'|NC"
// pc=0-+-----+-----+---+--+
// KC| | | | |
// | 1 | 3 | 5 |
// pc=KC-+-----+-----+---st-+
// KC| | | | |
// | 2 | 4 | 6 | 7|
// pc=k=2KC-+-----+-----+---+--+
// |jc=0 |jc=NC|jc=2NC|
//
// The numbers 1,2..6,7 denotes the order in which reordered
// KCxNC blocks are stored in memory, ie: block 1 followed by 2
// followed by 3, etc. Given two threads t1 and t2, and t2 needs
// to acces point st in the reorder buffer to write the data:
// The offset calulation logic will be:
// jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC,
// n_sub_updated = NC, k = 2KC, kc0_updated = KC
//
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
// + ( n_sub_updated * pc ) <traverse block 5>
// + ( NC' * kc0_updated) <traverse block 6>
((pack_f32bf16)lcntx->packb_mxp_fun_ptr)(
(((bfloat16 *)b_reorder->storage.aligned_buffer) +
(jc_cur_loop * k_updated) + (n_sub_updated * pc) +
(jc_cur_loop_rem * kc0_updated)),
(((float *)b->storage.aligned_buffer) +
(rs_b * pc) + (jc * cs_b)),
rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder);
}
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
}
}
b_reorder->rs = rs_b_reorder;
b_reorder->cs = cs_b_reorder;
b_reorder->mtag = REORDERED;
}

View File

@@ -53,6 +53,14 @@ void reorderb_nr64_bf16s4f32of32
lpgemm_cntx_t* lcntx
);
void reorderb_mxp_nr64_f32obf16
(
lpgemm_obj_t * b,
lpgemm_obj_t * b_reorder,
rntm_t* rntm,
lpgemm_cntx_t* lcntx
);
void unreorderb_nr64_bf16bf16f32of32
(
lpgemm_obj_t * b,

View File

@@ -69,10 +69,11 @@ typedef enum
BF16BF16F32OF32 = 3, // bf16 - A, bf16 - B, float - C
S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C
S8S8S16OS16 = 5, // int8_t - A, int8_t - B, int16_t - C
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
BF16S4F32OF32 = 7 // Only used for reordering int4_t B matrix.
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
BF16S4F32OF32 = 7, // Only used for reordering int4_t B matrix.
F32OBF16 = 8 // Only used for reordering input float matrix to bf16 reorder
} AOCL_OPERATION_TYPE;
#define AOCL_OPERATION_TYPE_LEN 8
#define AOCL_OPERATION_TYPE_LEN 9
typedef enum
{
@@ -160,6 +161,7 @@ typedef struct
lpgemm_block_size_t blksz;
void_fp kern_fun_ptr;
void_fp packa_fun_ptr;
void_fp packb_mxp_fun_ptr;
void_fp packb_fun_ptr;
void_fp unpackb_fun_ptr;
void_fp packsclb_fun_ptr;

View File

@@ -47,7 +47,20 @@ BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR()
return 16;
}
typedef void (*pack_s4bf16)(
typedef void (*pack_f32bf16)
(
bfloat16*,
const float*,
const dim_t,
const dim_t,
const dim_t,
const dim_t,
dim_t*,
dim_t*
);
typedef void (*pack_s4bf16)
(
bfloat16 *,
const int8_t *,
const dim_t,
@@ -93,6 +106,18 @@ typedef void (*pack_s4)
lpgemm_pre_op*
);
void packb_mxp_nr64_f32obf16
(
bfloat16 *pack_b_buffer_bf16bf16f32of32,
const float *b,
const dim_t rs_b,
const dim_t cs_b,
const dim_t NC,
const dim_t KC,
dim_t *rs_p,
dim_t *cs_p
);
void packb_nr64_bf16bf16f32of32
(
bfloat16* pack_b_buffer_bf16bf16f32of32,

File diff suppressed because it is too large Load Diff