mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Implemented f32tobf16 reorder function
Description: The aocl_reorder_f32obf16 function reorders an input weight matrix of data type float into bfloat16. The reordered layout matches the input requirements of the aocl_gemm_bf16bf16f32o<f32|bf16> APIs. The objective is to allow a model/matrix that is still stored as f32 to be converted to bf16 and processed on machines that support the bf16 FMA instruction _mm512_dpbf16_ps.
Change-Id: Ib7c743d52d01a1ac09e84ac120577ec9e02f90f5
This commit is contained in:
committed by
sireesha.sanga
parent
880a971dc5
commit
e6b79a4060
@@ -210,6 +210,117 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32)
|
||||
reorderb_nr64_bf16bf16f32of32( &b, &b_reorder, &rntm_g, lcntx_g );
|
||||
}
|
||||
|
||||
/*
 * aocl_reorder_f32obf16
 *
 * Converts a float (f32) B matrix to bfloat16 and writes it into
 * reorder_buf_addr, using the layout produced by
 * reorderb_mxp_nr64_f32obf16().
 *
 * Parameters (declared by the AOCL_GEMM_REORDER_MXP macro):
 *   order            - 'r'/'R' for row-major, 'c'/'C' for column-major input.
 *   trans            - BLAS transpose character describing the input.
 *   mat_type         - matrix selector character; an A matrix is rejected.
 *   input_buf_addr   - source f32 matrix (read only).
 *   reorder_buf_addr - destination bf16 buffer; this function does not
 *                      allocate it.
 *   k, n             - B is treated as k rows by n columns.
 *   ldb              - leading dimension of the input buffer.
 *
 * No status is returned. On invalid arguments, on an unsupported processor
 * or for an A matrix the function returns before writing anything to
 * reorder_buf_addr.
 */
AOCL_GEMM_REORDER_MXP(float, bfloat16, f32obf16)
{
	trans_t blis_trans;

	/* Map BLAS chars to their corresponding BLIS enumerated type value. */
	bli_param_map_netlib_to_blis_trans(trans, &blis_trans);

	if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) ||
	    (k <= 0) || (n <= 0))
	{
		return; // Error.
	}

	/* Derive row/column strides of the k x n matrix from order, trans and
	 * ldb, rejecting a leading dimension smaller than the contiguous
	 * extent it has to cover. */
	inc_t rs_b, cs_b;
	if ((order == 'r') || (order == 'R'))
	{
		if ((bli_is_notrans(blis_trans) && (ldb < n)) ||
		    (bli_is_trans(blis_trans) && (ldb < k)))
		{
			return; // Error.
		}
		else
		{
			rs_b = bli_is_notrans(blis_trans) ? ldb : 1;
			cs_b = bli_is_notrans(blis_trans) ? 1 : ldb;
		}
	}
	else if ((order == 'c') || (order == 'C'))
	{
		if ((bli_is_notrans(blis_trans) && (ldb < k)) ||
		    (bli_is_trans(blis_trans) && (ldb < n)))
		{
			return; // Error.
		}
		else
		{
			rs_b = bli_is_notrans(blis_trans) ? 1 : ldb;
			cs_b = bli_is_notrans(blis_trans) ? ldb : 1;
		}
	}
	else
	{
		return; // Error
	}

	// Check if avx512_bf16 ISA is supported; the bf16 reorder path is
	// only taken when it is available.
	if (bli_cpuid_is_avx512bf16_supported() == FALSE)
	{
		bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
		              "cannot perform f32 to bf16 reorder.",
		              __FILE__, __LINE__);
		return; // Error.
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	AOCL_MATRIX_TYPE input_mat_type;
	bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);

	if (input_mat_type == A_MATRIX)
	{
		return; // A reorder not supported.
	}

#if (defined(BLIS_KERNELS_ZEN4))
	if (n == 1)
	{
		// Single-column B: no panel layout is produced, the k elements
		// are only converted and stored contiguously. Bytes 2..3 of each
		// float are copied, i.e. the upper 16 bits on a little-endian
		// target, so the conversion truncates the mantissa.
		// NOTE(review): confirm that truncation here is consistent with
		// the rounding used by the packing kernel for n > 1.
		// The rs_b == 1 case needs no separate loop: it is the strided
		// copy below with a stride of one.
		for (dim_t k0 = 0; k0 < k; k0++)
		{
			memcpy(&reorder_buf_addr[k0],
			       (const char *)(&input_buf_addr[k0 * rs_b]) + 2,
			       sizeof(bfloat16));
		}
		return;
	}
#endif

	// Initialize a local runtime with global settings.
	rntm_t rntm_g;
	bli_rntm_init_from_global(&rntm_g);
	bli_pba_rntm_set_pba(&rntm_g);

	lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(F32OBF16);

	// Create dummy b_reorder obj; only the destination buffer is set here,
	// the reorder routine fills in rs, cs and mtag.
	lpgemm_obj_t b_reorder;
	b_reorder.storage.aligned_buffer = reorder_buf_addr;

	// Create dummy original b obj carrying the source buffer, strides and
	// dimensions.
	lpgemm_obj_t b;
	b.storage.aligned_buffer = (void *)input_buf_addr;
	b.rs = rs_b;
	b.cs = cs_b;
	b.width = n;
	b.length = k;

	reorderb_mxp_nr64_f32obf16(&b, &b_reorder, &rntm_g, lcntx_g);
}
|
||||
|
||||
AOCL_GEMM_UNREORDER(bfloat16, bf16bf16f32of32)
|
||||
{
|
||||
|
||||
@@ -83,6 +83,21 @@ AOCL_GEMM_REORDER(int8_t,s8s8s16os16);
|
||||
AOCL_GEMM_REORDER(int8_t,u8s4s32os32);
|
||||
AOCL_GEMM_REORDER(int8_t, bf16s4f32of32);
|
||||
|
||||
/*
 * AOCL_GEMM_REORDER_MXP(A_type, B_type, LP_SFX)
 *
 * Expands to the signature of the exported function aocl_reorder_<LP_SFX>.
 * Unlike a single-type reorder, the input and output element types differ:
 *   A_type - element type of the (const) input buffer input_buf_addr.
 *   B_type - element type of the output buffer reorder_buf_addr.
 *   LP_SFX - suffix appended to "aocl_reorder_" to form the function name.
 * The generated function takes order, trans and mat_type characters, the
 * two buffers, the dimensions k and n and the leading dimension ldb, and
 * returns void. The macro is used both for the declaration below and for
 * the matching function definition.
 */
#define AOCL_GEMM_REORDER_MXP(A_type,B_type,LP_SFX) \
|
||||
BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \
|
||||
( \
|
||||
const char order, \
|
||||
const char trans, \
|
||||
const char mat_type, \
|
||||
const A_type* input_buf_addr, \
|
||||
B_type* reorder_buf_addr, \
|
||||
const dim_t k, \
|
||||
const dim_t n, \
|
||||
const dim_t ldb \
|
||||
) \
|
||||
|
||||
/* Declares aocl_reorder_f32obf16: float input, bfloat16 reordered output. */
AOCL_GEMM_REORDER_MXP(float,bfloat16,f32obf16);
|
||||
|
||||
#define AOCL_GEMM_UNREORDER(B_type, LP_SFX) \
|
||||
BLIS_EXPORT_ADDON void aocl_unreorder_ ## LP_SFX \
|
||||
( \
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(F32OBF16, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
|
||||
#define LPGEMM_BLKSZ_MAP_ZEN \
|
||||
XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
@@ -57,6 +58,7 @@
|
||||
XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(F32OBF16, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
|
||||
#define LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN \
|
||||
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
|
||||
|
||||
@@ -142,6 +142,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
#define KMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR;
|
||||
#define PAMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packa_fun_ptr = FUNC_PTR;
|
||||
#define PBMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packb_fun_ptr = FUNC_PTR;
|
||||
#define PBMXPMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].packb_mxp_fun_ptr = FUNC_PTR;
|
||||
#define UBMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].unpackb_fun_ptr = FUNC_PTR;
|
||||
#define PBSMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].packsclb_fun_ptr = FUNC_PTR;
|
||||
#define JITMACRO(ID, FUNC_PTR) global_cntx_t_list[ID].jit_kernel = FUNC_PTR;
|
||||
@@ -155,6 +156,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
global_cntx_t_list[F32F32F32OF32].kern_fun_ptr = NULL;
|
||||
global_cntx_t_list[BF16BF16F32OF32].kern_fun_ptr = NULL;
|
||||
global_cntx_t_list[BF16S4F32OF32].kern_fun_ptr = NULL;
|
||||
global_cntx_t_list[F32OBF16].kern_fun_ptr = NULL;
|
||||
|
||||
// Kernel dispatch object factory.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == TRUE )
|
||||
@@ -163,6 +165,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16
|
||||
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16
|
||||
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16
|
||||
LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI_BF16
|
||||
LPGEMM_UNPACKB_FUNC_MAP_AVX512_VNNI_BF16
|
||||
LPGEMM_PACKSCLB_FUNC_MAP_AVX512_VNNI_BF16
|
||||
|
||||
@@ -210,6 +213,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
LPGEMM_KERN_FUNC_MAP_AVX512_VNNI
|
||||
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI
|
||||
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI
|
||||
LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI
|
||||
#endif
|
||||
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
@@ -226,9 +230,10 @@ static void _lpgemm_cntx_init_func_map()
|
||||
LPGEMM_PACKB_FUNC_MAP_AVX2
|
||||
#endif
|
||||
}
|
||||
|
||||
// If built with a config not supporting zen3/zen4/amdzen, error out
|
||||
// since reference kernels are not available.
|
||||
if ( global_cntx_t_list[F32F32F32OF32].kern_fun_ptr == NULL )
|
||||
if (global_cntx_t_list[F32F32F32OF32].kern_fun_ptr == NULL)
|
||||
{
|
||||
bli_print_msg( "AOCL_GEMM is not compiled using correct Zen config."
|
||||
" Compile using zen3/zen4/amdzen config.",
|
||||
@@ -237,6 +242,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
}
|
||||
|
||||
#undef PBMACRO
|
||||
#undef PBMXPMACRO
|
||||
#undef PAMACRO
|
||||
#undef KMACRO
|
||||
}
|
||||
|
||||
@@ -65,8 +65,8 @@
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
|
||||
@@ -78,6 +78,9 @@
|
||||
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
|
||||
/* Mixed-precision pack-B map for the AVX512_VNNI_BF16 configuration:
 * associates the F32OBF16 operation with packb_mxp_nr64_f32obf16. It is
 * meant to be expanded where PBMXPMACRO(ID, FUNC_PTR) is defined. */
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
|
||||
|
||||
#define LPGEMM_UNPACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
UBMACRO(BF16BF16F32OF32, unpackb_nr64_bf16bf16f32of32)
|
||||
|
||||
@@ -119,7 +122,10 @@
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
|
||||
|
||||
/* Mixed-precision pack-B map for the AVX512_VNNI configuration:
 * associates the F32OBF16 operation with packb_mxp_nr64_f32obf16. It is
 * meant to be expanded where PBMXPMACRO(ID, FUNC_PTR) is defined. */
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI \
|
||||
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
@@ -170,10 +176,10 @@
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
|
||||
PBMACRO(BF16S4F32OF32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, NULL)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_TO_AVX2 \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512 \
|
||||
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
|
||||
|
||||
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512 \
|
||||
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \
|
||||
|
||||
@@ -254,11 +254,13 @@ void unreorderb_nr64_bf16bf16f32of32
|
||||
}
|
||||
}
|
||||
|
||||
void reorderb_nr64_bf16s4f32of32(
|
||||
lpgemm_obj_t *b,
|
||||
lpgemm_obj_t *b_reorder,
|
||||
rntm_t *rntm,
|
||||
lpgemm_cntx_t *lcntx)
|
||||
void reorderb_nr64_bf16s4f32of32
|
||||
(
|
||||
lpgemm_obj_t *b,
|
||||
lpgemm_obj_t *b_reorder,
|
||||
rntm_t *rntm,
|
||||
lpgemm_cntx_t *lcntx
|
||||
)
|
||||
{
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
@@ -377,3 +379,129 @@ void reorderb_nr64_bf16s4f32of32(
|
||||
b_reorder->cs = cs_b_reorder;
|
||||
b_reorder->mtag = REORDERED;
|
||||
}
|
||||
|
||||
void reorderb_mxp_nr64_f32obf16
|
||||
(
|
||||
lpgemm_obj_t *b,
|
||||
lpgemm_obj_t *b_reorder,
|
||||
rntm_t *rntm,
|
||||
lpgemm_cntx_t *lcntx
|
||||
)
|
||||
{
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
// Extracting the matrix properties from the lpgemm object
|
||||
dim_t rs_b = b->rs;
|
||||
dim_t cs_b = b->cs;
|
||||
dim_t n = b->width;
|
||||
dim_t k = b->length;
|
||||
|
||||
dim_t rs_b_reorder;
|
||||
dim_t cs_b_reorder;
|
||||
|
||||
// k needs to be a multiple of 2 so that it can be used with dpbf
|
||||
// instruction. Padding is added in cases this condition is not
|
||||
// satisfied, and therefore the k offset used for packed/reordered
|
||||
// buffer needs to be updated.
|
||||
dim_t k_updated = k;
|
||||
k_updated += (k_updated & 0x1);
|
||||
|
||||
dim_t n_threads = bli_rntm_num_threads(rntm);
|
||||
n_threads = (n_threads > 0) ? n_threads : 1;
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
_Pragma("omp parallel num_threads(n_threads)")
|
||||
{
|
||||
// Initialise a local thrinfo obj for work split across threads.
|
||||
thrinfo_t thread_jc;
|
||||
bli_thrinfo_set_n_way(n_threads, &thread_jc);
|
||||
bli_thrinfo_set_work_id(omp_get_thread_num(), &thread_jc);
|
||||
#else
|
||||
{
|
||||
// Initialise a local thrinfo obj for work split across threads.
|
||||
thrinfo_t thread_jc;
|
||||
bli_thrinfo_set_n_way(1, &thread_jc);
|
||||
bli_thrinfo_set_work_id(0, &thread_jc);
|
||||
#endif
|
||||
// Compute the JC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
|
||||
|
||||
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
|
||||
{
|
||||
dim_t nc0 = bli_min((jc_end - jc), NC);
|
||||
|
||||
dim_t jc_cur_loop = jc;
|
||||
dim_t jc_cur_loop_rem = 0;
|
||||
dim_t n_sub_updated;
|
||||
|
||||
get_B_panel_reordered_start_offset_width(
|
||||
jc, n, NC, 16,
|
||||
&jc_cur_loop, &jc_cur_loop_rem,
|
||||
&nc0, &n_sub_updated);
|
||||
|
||||
for (dim_t pc = 0; pc < k; pc += KC)
|
||||
{
|
||||
dim_t kc0 = bli_min((k - pc), KC);
|
||||
|
||||
// k needs to be a multiple of 2 so that it can be used with dpbf
|
||||
// instruction. Padding is added in cases this condition is not
|
||||
// satisfied, and therefore the k offset used for packed/reordered
|
||||
// buffer needs to be updated.
|
||||
dim_t kc0_updated = kc0;
|
||||
kc0_updated += (kc0_updated & 0x1);
|
||||
|
||||
// The offsets are calculated in such a way that it resembles
|
||||
// the reorder buffer traversal in single threaded reordering.
|
||||
// The panel boundaries (KCxNC) remain as it is accessed in
|
||||
// single thread, and as a consequence a thread with jc_start
|
||||
// inside the panel cannot consider NC range for reorder. It
|
||||
// has to work with NC' < NC, and the offset is calulated using
|
||||
// prev NC panels spanning k dim + cur NC panel spaning pc loop
|
||||
// cur iteration + (NC - NC') spanning current kc0 (<= KC).
|
||||
//
|
||||
// Eg: Consider the following reordered buffer diagram:
|
||||
// t1 t2
|
||||
// | |
|
||||
// | |..NC..|
|
||||
// | | |
|
||||
// |.NC. |.NC. |NC'|NC"
|
||||
// pc=0-+-----+-----+---+--+
|
||||
// KC| | | | |
|
||||
// | 1 | 3 | 5 |
|
||||
// pc=KC-+-----+-----+---st-+
|
||||
// KC| | | | |
|
||||
// | 2 | 4 | 6 | 7|
|
||||
// pc=k=2KC-+-----+-----+---+--+
|
||||
// |jc=0 |jc=NC|jc=2NC|
|
||||
//
|
||||
// The numbers 1,2..6,7 denotes the order in which reordered
|
||||
// KCxNC blocks are stored in memory, ie: block 1 followed by 2
|
||||
// followed by 3, etc. Given two threads t1 and t2, and t2 needs
|
||||
// to acces point st in the reorder buffer to write the data:
|
||||
// The offset calulation logic will be:
|
||||
// jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC,
|
||||
// n_sub_updated = NC, k = 2KC, kc0_updated = KC
|
||||
//
|
||||
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
|
||||
// + ( n_sub_updated * pc ) <traverse block 5>
|
||||
// + ( NC' * kc0_updated) <traverse block 6>
|
||||
((pack_f32bf16)lcntx->packb_mxp_fun_ptr)(
|
||||
(((bfloat16 *)b_reorder->storage.aligned_buffer) +
|
||||
(jc_cur_loop * k_updated) + (n_sub_updated * pc) +
|
||||
(jc_cur_loop_rem * kc0_updated)),
|
||||
(((float *)b->storage.aligned_buffer) +
|
||||
(rs_b * pc) + (jc * cs_b)),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder);
|
||||
}
|
||||
|
||||
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
|
||||
}
|
||||
}
|
||||
|
||||
b_reorder->rs = rs_b_reorder;
|
||||
b_reorder->cs = cs_b_reorder;
|
||||
b_reorder->mtag = REORDERED;
|
||||
}
|
||||
@@ -53,6 +53,14 @@ void reorderb_nr64_bf16s4f32of32
|
||||
lpgemm_cntx_t* lcntx
|
||||
);
|
||||
|
||||
/*
 * Reorders the float matrix described by `b` into the bfloat16 buffer
 * attached to `b_reorder`.
 *   b         - source matrix object (strides, dimensions, f32 buffer).
 *   b_reorder - destination object; its rs, cs and mtag are set on return.
 *   rntm      - runtime object used to obtain the thread count.
 *   lcntx     - lpgemm context providing block sizes and the pack routine.
 */
void reorderb_mxp_nr64_f32obf16
|
||||
(
|
||||
lpgemm_obj_t * b,
|
||||
lpgemm_obj_t * b_reorder,
|
||||
rntm_t* rntm,
|
||||
lpgemm_cntx_t* lcntx
|
||||
);
|
||||
|
||||
void unreorderb_nr64_bf16bf16f32of32
|
||||
(
|
||||
lpgemm_obj_t * b,
|
||||
|
||||
@@ -69,10 +69,11 @@ typedef enum
|
||||
BF16BF16F32OF32 = 3, // bf16 - A, bf16 - B, float - C
|
||||
S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C
|
||||
S8S8S16OS16 = 5, // int8_t - A, int8_t - B, int16_t - C
|
||||
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
|
||||
BF16S4F32OF32 = 7 // Only used for reordering int4_t B matrix.
|
||||
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
|
||||
BF16S4F32OF32 = 7, // Only used for reordering int4_t B matrix.
|
||||
F32OBF16 = 8 // Only used for reordering input float matrix to bf16 reorder
|
||||
} AOCL_OPERATION_TYPE;
|
||||
#define AOCL_OPERATION_TYPE_LEN 8
|
||||
#define AOCL_OPERATION_TYPE_LEN 9
|
||||
|
||||
typedef enum
|
||||
{
|
||||
@@ -160,6 +161,7 @@ typedef struct
|
||||
lpgemm_block_size_t blksz;
|
||||
void_fp kern_fun_ptr;
|
||||
void_fp packa_fun_ptr;
|
||||
void_fp packb_mxp_fun_ptr;
|
||||
void_fp packb_fun_ptr;
|
||||
void_fp unpackb_fun_ptr;
|
||||
void_fp packsclb_fun_ptr;
|
||||
|
||||
@@ -47,7 +47,20 @@ BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR()
|
||||
return 16;
|
||||
}
|
||||
|
||||
typedef void (*pack_s4bf16)(
|
||||
/*
 * Function pointer type for a pack routine that reads float data and
 * writes bfloat16 data. The parameter order matches
 * packb_mxp_nr64_f32obf16.
 */
typedef void (*pack_f32bf16)
|
||||
(
|
||||
bfloat16*, /* destination (packed) bf16 buffer */
|
||||
const float*, /* source f32 matrix */
|
||||
const dim_t, /* row stride of the source */
|
||||
const dim_t, /* column stride of the source */
|
||||
const dim_t, /* NC: panel extent along n */
|
||||
const dim_t, /* KC: panel extent along k */
|
||||
dim_t*, /* out: row stride of the packed panel */
|
||||
dim_t* /* out: column stride of the packed panel */
|
||||
);
|
||||
|
||||
typedef void (*pack_s4bf16)
|
||||
(
|
||||
bfloat16 *,
|
||||
const int8_t *,
|
||||
const dim_t,
|
||||
@@ -93,6 +106,18 @@ typedef void (*pack_s4)
|
||||
lpgemm_pre_op*
|
||||
);
|
||||
|
||||
/*
 * Packs one panel of a float B matrix into a bfloat16 buffer.
 *   pack_b_buffer_bf16bf16f32of32 - destination bf16 buffer.
 *   b                             - start of the source f32 panel.
 *   rs_b, cs_b                    - row and column strides of the source.
 *   NC, KC                        - panel extents along n and k.
 *   rs_p, cs_p                    - out: strides of the packed data.
 * NOTE(review): the implementation is not visible here; the name suggests
 * the packed layout uses NR = 64 - confirm against the kernel source.
 */
void packb_mxp_nr64_f32obf16
|
||||
(
|
||||
bfloat16 *pack_b_buffer_bf16bf16f32of32,
|
||||
const float *b,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t *rs_p,
|
||||
dim_t *cs_p
|
||||
);
|
||||
|
||||
void packb_nr64_bf16bf16f32of32
|
||||
(
|
||||
bfloat16* pack_b_buffer_bf16bf16f32of32,
|
||||
|
||||
1249
kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_f32obf16_amd512vnni.c
Normal file
1249
kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_f32obf16_amd512vnni.c
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user