Adding support for AOCL_ENABLE_INSTRUCTIONS for f32 LPGEMM API.

-Currently lpgemm sets the context (block sizes and micro-kernels) based
on the ISA of the machine it is being executed on. However this approach
does not give the flexibility to select a different context at runtime.
In order to enable runtime selection of context, the context
initialization is modified to read the AOCL_ENABLE_INSTRUCTIONS env
variable and set the context based on the same. As part of this commit,
only f32 context selection is enabled.
-Bug fixes in scale ops in f32 micro-kernels and GEMV path selection.
-Added vectorized f32 packing kernels for NR=16(AVX2) and NR=64(AVX512).
This is only for B matrix and helps remove dependency of f32 lpgemm api
on the BLIS packing framework.

AMD Internal: [CPUPL-5959]

Change-Id: I4b459aaf33c54423952f89905ba43cf119ce20f6
This commit is contained in:
Mithun Mohan
2024-10-28 06:38:57 +00:00
parent 9ce2696fc9
commit 097cda9f9e
18 changed files with 1374 additions and 439 deletions

View File

@@ -55,6 +55,8 @@ static lpgemm_eltwise_ops_cntx_t
global_eltwise_ops_cntx_t_list[AOCL_ELTWISE_OPS_OPERATION_TYPE_LEN] \
__attribute__((aligned(64))); //Post-ops only utils without gemm.
static arch_t global_lpgemm_enable_arch = BLIS_ARCH_ERROR;
// This array is to store function pointers to jit generated kernels.
static void* global_jit_kernels[ LPGEMM_BF16_MR ]
[ ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1 ]
@@ -67,6 +69,25 @@ static void* global_jit_kernels[ LPGEMM_BF16_MR ]
static bli_pthread_once_t once_check_lpgemm_func_map_init = BLIS_PTHREAD_ONCE_INIT;
static void _lpgemm_init_enable_arch()
{
arch_t arch_id = bli_arch_query_id();
bool enbl_instr = bli_aocl_enable_instruction_query();
if ( ( enbl_instr == TRUE ) &&
( ( arch_id == BLIS_ARCH_ZEN3 ) ||
( arch_id == BLIS_ARCH_ZEN2 ) ||
( arch_id == BLIS_ARCH_ZEN ) ) )
{
global_lpgemm_enable_arch = BLIS_ARCH_ZEN3;
}
}
arch_t lpgemm_get_enabled_arch()
{
return global_lpgemm_enable_arch;
}
static void _lpgemm_util_cntx_init_func_map()
{
#define UMACRO(ID,FUNC_PTR) global_util_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR;
@@ -175,6 +196,13 @@ static void _lpgemm_cntx_init_func_map()
}
#endif
#endif
// If arch is updated at runtime, it is expeceted to be honoured.
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
}
}
else if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
@@ -183,6 +211,12 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI
#endif
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
}
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
@@ -263,6 +297,11 @@ static void _lpgemm_cntx_init_blksz_map()
if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
LPGEMM_BLKSZ_MAP_ZEN4
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN
}
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
@@ -276,6 +315,45 @@ static void _lpgemm_cntx_init_blksz_map()
#undef XMACRO
}
BLIS_INLINE void lpgemm_set_sup_thres_global_cntx
(
AOCL_OPERATION_TYPE op_type,
dim_t MT,
dim_t NT,
dim_t KT
)
{
global_cntx_t_list[op_type].sup_thres.MT = MT;
global_cntx_t_list[op_type].sup_thres.NT = NT;
global_cntx_t_list[op_type].sup_thres.KT = KT;
}
static void _lpgemm_cntx_init_sup_thres_map()
{
#define STMACRO(ID,MT,NT,KT) \
lpgemm_set_sup_thres_global_cntx(ID, MT, NT, KT); \
if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
LPGEMM_SUP_THRES_MAP_ZEN4
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_SUP_THRES_UPD_MAP_ZEN4_TO_ZEN
}
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
LPGEMM_SUP_THRES_MAP_ZEN
}
else
{
LPGEMM_SUP_THRES_MAP_ZEN
}
#undef STMACRO
}
BLIS_INLINE void lpgemm_set_block_sizes_global_eltwise_ops_cntx
(
AOCL_ELTWISE_OPS_OPERATION_TYPE op_type,
@@ -317,8 +395,10 @@ static void _lpgemm_eltwise_ops_cntx_init_blksz_map()
static void lpgemm_cntx_init_map()
{
_lpgemm_init_enable_arch();
_lpgemm_cntx_init_func_map();
_lpgemm_cntx_init_blksz_map();
_lpgemm_cntx_init_sup_thres_map();
_lpgemm_eltwise_ops_cntx_init_blksz_map();
_lpgemm_eltwise_ops_cntx_init_func_map();
_lpgemm_util_cntx_init_func_map();
@@ -375,6 +455,21 @@ dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type )
return global_cntx_t_list[op_type].blksz.MR;
}
dim_t lpgemm_get_sup_thres_MT_global_cntx( AOCL_OPERATION_TYPE op_type )
{
return global_cntx_t_list[op_type].sup_thres.MT;
}
dim_t lpgemm_get_sup_thres_NT_global_cntx( AOCL_OPERATION_TYPE op_type )
{
return global_cntx_t_list[op_type].sup_thres.NT;
}
dim_t lpgemm_get_sup_thres_KT_global_cntx( AOCL_OPERATION_TYPE op_type )
{
return global_cntx_t_list[op_type].sup_thres.KT;
}
void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs )
{
*rs = lcntx->pack_s.packa_rs;