Adding support for AOCL_ENABLE_INSTRUCTIONS for f32 LPGEMM API.

-Currently lpgemm sets the context (block sizes and micro-kernels) based
on the ISA of the machine it is executed on. However, this approach does
not provide the flexibility to select a different context at runtime. To
enable runtime selection of the context, the context initialization is
modified to read the AOCL_ENABLE_INSTRUCTIONS environment variable and
set the context accordingly. As part of this commit, only f32 context
selection is enabled.
-Bug fixes in scale ops in f32 micro-kernels and GEMV path selection.
-Added vectorized f32 packing kernels for NR=16 (AVX2) and NR=64 (AVX512).
These kernels are only for the B matrix and remove the dependency of the
f32 lpgemm API on the BLIS packing framework.
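
A minimal usage sketch of the runtime selection (assumptions: the value
"AVX2" for AOCL_ENABLE_INSTRUCTIONS and the aocl_gemm_f32f32f32of32 entry
point; only the environment-variable mechanism added here is the point):

#include <stdlib.h>
#include "blis.h"

int main( void )
{
    // Request the AVX2 (zen3) context even on an AVX512-capable machine.
    // Must be set before the first lpgemm call, since the context map is
    // initialized only once (bli_pthread_once).
    setenv( "AOCL_ENABLE_INSTRUCTIONS", "AVX2", 1 );

    // ... invoke the f32 lpgemm API (e.g. aocl_gemm_f32f32f32of32) here.
    // Block sizes, SUP thresholds and micro-kernels are then taken from
    // the AVX2 context instead of the AVX512 (zen4) one.
    return 0;
}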

AMD Internal: [CPUPL-5959]

Change-Id: I4b459aaf33c54423952f89905ba43cf119ce20f6
Author: Mithun Mohan
Date: 2024-10-28 06:38:57 +00:00
Committed by: sireesha.sanga
Parent: 0d5c09d042
Commit: 880a971dc5
18 changed files with 1374 additions and 439 deletions


@@ -58,11 +58,6 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32)
// Initialize lpgemm context.
aocl_lpgemm_init_global_cntx();
// Query the global cntx.
cntx_t* cntx = bli_gks_query_cntx();
num_t dt = BLIS_FLOAT;
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );
@@ -71,15 +66,18 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32)
return 0; // A reorder not supported.
}
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
const dim_t NR = lpgemm_get_block_size_NR_global_cntx( F32F32F32OF32 );
// Extra space since packing pads the width to a multiple of NR.
dim_t n_reorder;
if(n == 1)
#ifdef BLIS_KERNELS_ZEN4
if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
{
//When n == 1, LPGEMV doesn't expect B to be reordered.
n_reorder = 1;
}else
}
else
#endif
{
n_reorder = ( ( n + NR - 1 ) / NR ) * NR;
}
@@ -134,11 +132,6 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
// Initialize lpgemm context.
aocl_lpgemm_init_global_cntx();
// Query the global cntx.
cntx_t* cntx = bli_gks_query_cntx();
num_t dt = BLIS_FLOAT;
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );
@@ -147,20 +140,14 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
return; // A reorder not supported.
}
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx );
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx );
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
// Query the context for various blocksizes.
lpgemm_cntx_t* lcntx = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
dim_t NC = lcntx->blksz.NC;
dim_t KC = lcntx->blksz.KC;
dim_t NR = lcntx->blksz.NR;
inc_t rs_p = NR;
float one_local = *PASTEMAC(s,1);
float* restrict kappa_cast = &one_local;
// Set the schema to "row stored column panels" to indicate packing to
// conventional row-stored column panels.
pack_t schema = BLIS_PACKED_COL_PANELS;
trans_t transc = BLIS_NO_TRANSPOSE;
conj_t conjc = bli_extract_conj( transc );
dim_t rs_b_reorder = 0;
dim_t cs_b_reorder = 0;
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
@@ -170,9 +157,10 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
dim_t n_threads = bli_rntm_num_threads( &rntm_g );
n_threads = ( n_threads > 0 ) ? n_threads : 1;
#ifdef BLIS_KERNELS_ZEN4
//When n == 1, the B matrix becomes a vector.
//Reordering is avoided so that LPGEMV can process it efficiently.
if(n == 1)
if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
{
if(rs_b == 1)
{
@@ -186,6 +174,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
}
return;
}
#endif
#ifdef BLIS_ENABLE_OPENMP
_Pragma( "omp parallel num_threads(n_threads)" )
@@ -220,15 +209,9 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
&nc0, &n_sub_updated
);
// Compute the total number of iterations we'll need.
dim_t n_iter = ( nc0 + NR - 1 ) / NR;
for ( dim_t pc = 0; pc < k; pc += KC )
{
dim_t kc0 = bli_min( ( k - pc ), KC );
inc_t ps_p = kc0 * NR;
const float* b_temp = input_buf_addr + ( jc * cs_b ) + ( pc * rs_b );
// The offsets are calculated in such a way that they resemble
// the reorder buffer traversal in single threaded reordering.
@@ -265,34 +248,13 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
// + ( n_sub_updated * pc ) <traverse block 5>
// + ( NC' * kc0_updated) <traverse block 6>
float* p_temp = reorder_buf_addr + ( jc_cur_loop * k ) +
( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 );
dim_t jr, it;
// Iterate over every logical micropanel in the source matrix.
for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 )
{
dim_t panel_dim_i = bli_min( NR, nc0 - jr );
const float* b_use = b_temp + ( jr * cs_b );
float* p_use = p_temp;
PASTEMAC(s,packm_cxk)
(
conjc,
schema,
panel_dim_i,
NR,
kc0,
kc0,
kappa_cast,
( float* )b_use, cs_b, rs_b,
p_use, rs_p,
cntx
);
p_temp += ps_p;
}
( ( lpgemm_pack_f32 )lcntx->packb_fun_ptr )
(
reorder_buf_addr + ( jc_cur_loop * k ) +
( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ),
input_buf_addr + ( rs_b * pc ) + ( cs_b * jc ),
rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder
);
}
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );


@@ -41,6 +41,7 @@
#define LPGEMM_BLKSZ_MAP_ZEN4 \
XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(F32F32F32OF32, 192, 8064, 512, 6, 64, 1, 6, 64, 1) \
XMACRO(BF16BF16F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(BF16S4F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
@@ -50,12 +51,27 @@
#define LPGEMM_BLKSZ_MAP_ZEN \
XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
#define LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN \
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
// The STMACRO follows the format ID, MT, NT, KT, where MT, NT, KT are the SUP switch thresholds.
// ID = one of the AOCL_OPERATION_TYPE enum values.
#define LPGEMM_SUP_THRES_MAP_ZEN4 \
STMACRO(F32F32F32OF32, 682, 512, 240) \
#define LPGEMM_SUP_THRES_MAP_ZEN \
STMACRO(F32F32F32OF32, 512, 200, 240) \
#define LPGEMM_SUP_THRES_UPD_MAP_ZEN4_TO_ZEN \
STMACRO(F32F32F32OF32, 512, 200, 240) \
#define LPGEMM_ELTWISE_OPS_BLKSZ_MAP_ZEN4 \
XMACRO(BF16OF32, 144, 1024, 2048, 6, 64) \


@@ -55,6 +55,8 @@ static lpgemm_eltwise_ops_cntx_t
global_eltwise_ops_cntx_t_list[AOCL_ELTWISE_OPS_OPERATION_TYPE_LEN] \
__attribute__((aligned(64))); //Post-ops only utils without gemm.
static arch_t global_lpgemm_enable_arch = BLIS_ARCH_ERROR;
// This array stores function pointers to JIT generated kernels.
static void* global_jit_kernels[ LPGEMM_BF16_MR ]
[ ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1 ]
@@ -67,6 +69,25 @@ static void* global_jit_kernels[ LPGEMM_BF16_MR ]
static bli_pthread_once_t once_check_lpgemm_func_map_init = BLIS_PTHREAD_ONCE_INIT;
static void _lpgemm_init_enable_arch()
{
arch_t arch_id = bli_arch_query_id();
bool enbl_instr = bli_aocl_enable_instruction_query();
if ( ( enbl_instr == TRUE ) &&
( ( arch_id == BLIS_ARCH_ZEN3 ) ||
( arch_id == BLIS_ARCH_ZEN2 ) ||
( arch_id == BLIS_ARCH_ZEN ) ) )
{
global_lpgemm_enable_arch = BLIS_ARCH_ZEN3;
}
}
arch_t lpgemm_get_enabled_arch()
{
return global_lpgemm_enable_arch;
}
static void _lpgemm_util_cntx_init_func_map()
{
#define UMACRO(ID,FUNC_PTR) global_util_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR;
@@ -175,6 +196,13 @@ static void _lpgemm_cntx_init_func_map()
}
#endif
#endif
// If the arch is updated at runtime, it is expected to be honoured.
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
}
}
else if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
@@ -183,6 +211,12 @@ static void _lpgemm_cntx_init_func_map()
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI
#endif
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
}
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
@@ -263,6 +297,11 @@ static void _lpgemm_cntx_init_blksz_map()
if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
LPGEMM_BLKSZ_MAP_ZEN4
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN
}
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
@@ -276,6 +315,45 @@ static void _lpgemm_cntx_init_blksz_map()
#undef XMACRO
}
BLIS_INLINE void lpgemm_set_sup_thres_global_cntx
(
AOCL_OPERATION_TYPE op_type,
dim_t MT,
dim_t NT,
dim_t KT
)
{
global_cntx_t_list[op_type].sup_thres.MT = MT;
global_cntx_t_list[op_type].sup_thres.NT = NT;
global_cntx_t_list[op_type].sup_thres.KT = KT;
}
static void _lpgemm_cntx_init_sup_thres_map()
{
#define STMACRO(ID,MT,NT,KT) \
lpgemm_set_sup_thres_global_cntx(ID, MT, NT, KT); \
if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
{
LPGEMM_SUP_THRES_MAP_ZEN4
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_SUP_THRES_UPD_MAP_ZEN4_TO_ZEN
}
}
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
LPGEMM_SUP_THRES_MAP_ZEN
}
else
{
LPGEMM_SUP_THRES_MAP_ZEN
}
#undef STMACRO
}
BLIS_INLINE void lpgemm_set_block_sizes_global_eltwise_ops_cntx
(
AOCL_ELTWISE_OPS_OPERATION_TYPE op_type,
@@ -317,8 +395,10 @@ static void _lpgemm_eltwise_ops_cntx_init_blksz_map()
static void lpgemm_cntx_init_map()
{
_lpgemm_init_enable_arch();
_lpgemm_cntx_init_func_map();
_lpgemm_cntx_init_blksz_map();
_lpgemm_cntx_init_sup_thres_map();
_lpgemm_eltwise_ops_cntx_init_blksz_map();
_lpgemm_eltwise_ops_cntx_init_func_map();
_lpgemm_util_cntx_init_func_map();
@@ -375,6 +455,21 @@ dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type )
return global_cntx_t_list[op_type].blksz.MR;
}
dim_t lpgemm_get_sup_thres_MT_global_cntx( AOCL_OPERATION_TYPE op_type )
{
return global_cntx_t_list[op_type].sup_thres.MT;
}
dim_t lpgemm_get_sup_thres_NT_global_cntx( AOCL_OPERATION_TYPE op_type )
{
return global_cntx_t_list[op_type].sup_thres.NT;
}
dim_t lpgemm_get_sup_thres_KT_global_cntx( AOCL_OPERATION_TYPE op_type )
{
return global_cntx_t_list[op_type].sup_thres.KT;
}
void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs )
{
*rs = lcntx->pack_s.packa_rs;


@@ -61,6 +61,14 @@ dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type );
dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type );
dim_t lpgemm_get_sup_thres_MT_global_cntx( AOCL_OPERATION_TYPE op_type );
dim_t lpgemm_get_sup_thres_NT_global_cntx( AOCL_OPERATION_TYPE op_type );
dim_t lpgemm_get_sup_thres_KT_global_cntx( AOCL_OPERATION_TYPE op_type );
arch_t lpgemm_get_enabled_arch();
void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs );
void lpgemm_get_packb_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs );


@@ -44,7 +44,7 @@
// TODO: Add reference kernels for BF16/VNNI kernels for ISA combinations
// that are not supported.
// Genoa
// AVX512 + VNNI + BF16
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
@@ -54,6 +54,9 @@
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
@@ -65,12 +68,16 @@
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
PBMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
#define LPGEMM_UNPACKB_FUNC_MAP_AVX512_VNNI_BF16 \
UBMACRO(BF16BF16F32OF32, unpackb_nr64_bf16bf16f32of32)
@@ -93,6 +100,7 @@
UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \
// AVX512 + VNNI
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
@@ -102,6 +110,9 @@
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
@@ -113,18 +124,23 @@
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
PBSMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI \
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \
UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \
UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \
// AVX512
#define LPGEMM_KERN_FUNC_MAP_AVX512 \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
@@ -134,6 +150,9 @@
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
@@ -145,6 +164,7 @@
#define LPGEMM_PACKB_FUNC_MAP_AVX512 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, NULL) \
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
@@ -152,13 +172,16 @@
PBMACRO(BF16S4F32OF32, NULL) \
PBSMACRO(BF16S4F32OF32, NULL) \
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512 \
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \
UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \
UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \
// Milan
// AVX2
#define LPGEMM_KERN_FUNC_MAP_AVX2 \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, NULL) \
@@ -179,6 +202,7 @@
#define LPGEMM_PACKB_FUNC_MAP_AVX2 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, NULL) \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, NULL) \
KMACRO(BF16S4F32OF32, NULL) \
PBMACRO(S8S8S32OS32, NULL) \


@@ -75,25 +75,9 @@ void lpgemm_pack_a_f32f32f32of32
cntx_t* cntx
);
void lpgemm_pack_b_f32f32f32of32
(
const float* input_buf_addr_b,
float* reorder_buf_addr_b,
const dim_t n,
const dim_t k,
const dim_t rs_b,
const dim_t cs_b,
const dim_t ps_p,
const dim_t NR,
cntx_t* cntx
);
#ifdef BLIS_KERNELS_ZEN4
LPGEMV(float, float, float, f32f32f32of32)
{
cntx_t *cntx = bli_gks_query_cntx();
num_t dt = BLIS_FLOAT;
const float* a_use = (float*)a;
inc_t rs_a_use = rs_a;
inc_t cs_a_use = cs_a;
@@ -101,21 +85,20 @@ LPGEMV(float, float, float, f32f32f32of32)
float* b_use = (float*)b;
inc_t rs_b_use = rs_b;
inc_t cs_b_use = cs_b;
inc_t ps_b_use;
siz_t mem_a_size_req = 0;
mem_t mem_a = BLIS_MEM_INITIALIZER;
siz_t mem_b_size_req = 0;
mem_t mem_b = BLIS_MEM_INITIALIZER;
mem_t mem_a = BLIS_MEM_INITIALIZER;
siz_t mem_b_size_req = 0;
mem_t mem_b = BLIS_MEM_INITIALIZER;
float* pack_a_buffer_f32f32f32of32;
float* pack_b_buffer_f32f32f32of32;
// Query the context for various blocksizes.
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NR, cntx);
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NC, cntx);
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx);
const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx);
const dim_t NC = lcntx->blksz.NC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MC = lcntx->blksz.MC;
const dim_t NR = lcntx->blksz.NR;
// Strides are updated based on matrix packing/reordering.
float *c_use = NULL;
@@ -161,9 +144,9 @@ LPGEMV(float, float, float, f32f32f32of32)
// Compute the IC loop thread range for the current thread.
dim_t ic_start, ic_end;
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
( thread->n_threads ) : ( thread_ic.n_way );
thread_ic.work_id = thread->tid;
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
( thread->n_threads ) : ( thread_ic.n_way );
thread_ic.work_id = thread->tid;
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
@@ -219,9 +202,9 @@ LPGEMV(float, float, float, f32f32f32of32)
{
// Compute the JC loop thread range for the current thread.
dim_t jc_start, jc_end;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
( thread->n_threads ) : ( thread_jc.n_way );
thread_jc.work_id = thread->tid;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
( thread->n_threads ) : ( thread_jc.n_way );
thread_jc.work_id = thread->tid;
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
if ( mtag_a == PACK )
@@ -296,16 +279,13 @@ LPGEMV(float, float, float, f32f32f32of32)
// Set the strides for pack buffer.
rs_b_use = NR;
cs_b_use = 1;
ps_b_use = kc0;
lpgemm_pack_b_f32f32f32of32
(
( b + ( rs_b * pc ) + ( cs_b * jc ) ),
pack_b_buffer_f32f32f32of32 + ( n_sub_updated * pc ),
nc0 , kc0,
rs_b, cs_b, ( NR * ps_b_use ), NR,
cntx
);
( ( lpgemm_pack_f32 )lcntx->packb_fun_ptr )
(
pack_b_buffer_f32f32f32of32 + ( n_sub_updated * pc ),
b + ( rs_b * pc ) + ( cs_b * jc ),
rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use
);
}
b_use = pack_b_buffer_f32f32f32of32;
}
@@ -339,10 +319,10 @@ LPGEMV(float, float, float, f32f32f32of32)
} // jc loop
// Release pack buffers.
if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) )
{
bli_pba_release( rntm, &mem_b );
}
if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) )
{
bli_pba_release( rntm, &mem_b );
}
}
}
#endif
@@ -352,7 +332,9 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
#ifdef BLIS_KERNELS_ZEN4
// Handle using LPGEMV when m and/or n is equal to 1.
// The avx512 check will be removed when avx2 kernels are added in the future.
if ( ( ( m == 1 ) || ( n == 1 ) ) && (bli_cpuid_is_avx512_supported() == TRUE) )
if ( ( ( m == 1 ) || ( n == 1 ) ) &&
( bli_cpuid_is_avx512_supported() == TRUE ) &&
( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
{
lpgemv_rowvar_f32f32f32of32(m, n, k,
a, rs_a, cs_a, mtag_a,
@@ -371,14 +353,12 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
// Query the global cntx.
cntx_t* cntx = bli_gks_query_cntx();
num_t dt = BLIS_FLOAT;
// Query the context for various blocksizes.
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx );
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx );
const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx );
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx );
const dim_t NC = lcntx->blksz.NC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MC = lcntx->blksz.MC;
const dim_t NR = lcntx->blksz.NR;
const dim_t MR = lcntx->blksz.MR;
// Strides are updated based on matrix packing/reordering.
const float* a_use = NULL;
@@ -535,13 +515,12 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
if ( ( jc_packb_end > jc_packb_start ) &&
( jc_packb_start < ( jc + nc0 ) ) )
{
lpgemm_pack_b_f32f32f32of32
( ( lpgemm_pack_f32 )lcntx->packb_fun_ptr )
(
( b + ( rs_b * pc ) + ( cs_b * jc ) + ( cs_b * jc_packb_start ) ),
pack_b_buffer_f32f32f32of32 + ( jc_packb_start * kc0 ),
( jc_packb_end - jc_packb_start ), kc0,
rs_b, cs_b, ( NR * ps_b_use ), NR,
cntx
b + ( rs_b * pc ) + ( cs_b * jc ) +
( cs_b * jc_packb_start ),
rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use
);
}
@@ -740,58 +719,3 @@ void lpgemm_pack_a_f32f32f32of32
p_temp += ps_p;
}
}
void lpgemm_pack_b_f32f32f32of32
(
const float* input_buf_addr_b,
float* reorder_buf_addr_b,
const dim_t n,
const dim_t k,
const dim_t rs_b,
const dim_t cs_b,
const dim_t ps_p,
const dim_t NR,
cntx_t* cntx
)
{
float one_local = *PASTEMAC(s,1);
float* restrict kappa_cast = &one_local;
// Set the schema to "row stored column panels" to indicate packing to
// conventional row-stored column panels.
pack_t schema = BLIS_PACKED_COL_PANELS;
trans_t transc = BLIS_NO_TRANSPOSE;
conj_t conjc = bli_extract_conj( transc );
// Compute the total number of iterations we'll need.
dim_t n_iter = ( n + NR - 1 ) / NR;
inc_t rs_p = NR;
float* p_temp = reorder_buf_addr_b;
dim_t jr, it;
// Iterate over every logical micropanel in the source matrix.
for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 )
{
dim_t panel_dim_i = bli_min( NR, n - jr );
const float* b_use = input_buf_addr_b + ( jr * cs_b );
float* p_use = p_temp;
PASTEMAC(s,packm_cxk)
(
conjc,
schema,
panel_dim_i,
NR,
k,
k,
kappa_cast,
( float* )b_use, cs_b, rs_b,
p_use, rs_p,
cntx
);
p_temp += ps_p;
}
}


@@ -69,7 +69,7 @@ typedef enum
BF16BF16F32OF32 = 3, // bf16 - A, bf16 - B, float - C
S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C
S8S8S16OS16 = 5, // int8_t - A, int8_t - B, int16_t - C
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
BF16S4F32OF32 = 7 // Only used for reordering int4_t B matrix.
} AOCL_OPERATION_TYPE;
#define AOCL_OPERATION_TYPE_LEN 8
@@ -148,6 +148,13 @@ typedef struct
dim_t packb_cs;
} lpgemm_pack_strides_t;
typedef struct
{
dim_t MT;
dim_t NT;
dim_t KT;
} lpgemm_sup_thres_t;
typedef struct
{
lpgemm_block_size_t blksz;
@@ -157,6 +164,7 @@ typedef struct
void_fp unpackb_fun_ptr;
void_fp packsclb_fun_ptr;
lpgemm_pack_strides_t pack_s;
lpgemm_sup_thres_t sup_thres;
} lpgemm_cntx_t;
typedef struct


@@ -503,7 +503,6 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
}
else if ( ( *n_threads ) > 1 )
{
dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 );
dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 );
dim_t mr_blks = ( m + MR - 1 ) / MR;
@@ -558,22 +557,17 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
rntm_t* rntm_g
)
{
// Query the global cntx.
cntx_t* cntx = bli_gks_query_cntx();
num_t dt = BLIS_FLOAT;
// Query the context for SUP limits.
const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx );
const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx );
const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx );
const dim_t MT = lpgemm_get_sup_thres_MT_global_cntx( F32F32F32OF32 );
const dim_t NT = lpgemm_get_sup_thres_NT_global_cntx( F32F32F32OF32 );
const dim_t KT = lpgemm_get_sup_thres_KT_global_cntx( F32F32F32OF32 );
// Query the context for various blocksizes.
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx );
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx );
const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx );
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx );
dim_t NR = lpgemm_get_block_size_NR_global_cntx( F32F32F32OF32 );
dim_t MR = lpgemm_get_block_size_MR_global_cntx( F32F32F32OF32 );
dim_t MC = lpgemm_get_block_size_MC_global_cntx( F32F32F32OF32 );
dim_t NC = lpgemm_get_block_size_NC_global_cntx( F32F32F32OF32 );
dim_t KC = lpgemm_get_block_size_KC_global_cntx( F32F32F32OF32 );
const dim_t MT_2 = MT / 2;
@@ -640,7 +634,7 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) )
{
if (((k <= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
if (((k >= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2)))
{
bli_rntm_set_pack_b( 1, rntm_g );


@@ -31,8 +31,8 @@
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef BLIS_GEMM_F32_PACKA
#define BLIS_GEMM_F32_PACKA
#ifndef BLIS_GEMM_F32_PACKAB
#define BLIS_GEMM_F32_PACKAB
void packa_mr16_f32f32f32of32_col_major
(
@@ -45,6 +45,43 @@ void packa_mr16_f32f32f32of32_col_major
dim_t* rs_p,
dim_t* cs_p
);
#endif
typedef void (*lpgemm_pack_f32)
(
float*,
const float*,
const dim_t,
const dim_t,
const dim_t,
const dim_t,
dim_t*,
dim_t*
);
void packb_nr64_f32f32f32of32
(
float* pack_b_buffer,
const float* b,
const dim_t rs_b,
const dim_t cs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
);
void packb_nr16_f32f32f32of32
(
float* pack_b_buffer,
const float* b,
const dim_t rs_b,
const dim_t cs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
);
#endif //BLIS_GEMM_F32_PACKAB


@@ -1313,7 +1313,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
} \
} \
\
if ( global_dscale_out == 'y' || global_can_dscale == 'y') \
if ( ( global_dscale_out == 'y' ) || ( global_can_dscale == 'y' ) ) \
{ \
post_ops->seq_vector[cur_op_index] = SCALE; \
cur_op_index++; \


@@ -530,16 +530,16 @@ POST_OPS_DOWNSCALE_5x16F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -1096,14 +1096,14 @@ POST_OPS_DOWNSCALE_4x16F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -1574,12 +1574,12 @@ POST_OPS_DOWNSCALE_3x16F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -1959,10 +1959,10 @@ POST_OPS_DOWNSCALE_2x16F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -2262,8 +2262,8 @@ POST_OPS_DOWNSCALE_1x16F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -2666,16 +2666,16 @@ POST_OPS_DOWNSCALE_5x8F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -3083,14 +3083,14 @@ POST_OPS_DOWNSCALE_4x8F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -3444,12 +3444,12 @@ POST_OPS_DOWNSCALE_3x8F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -3750,10 +3750,10 @@ POST_OPS_DOWNSCALE_2x8F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -3996,8 +3996,8 @@ POST_OPS_DOWNSCALE_1x8F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -6017,13 +6017,13 @@ POST_OPS_DOWNSCALE_5x2F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
//c[0, 0-3]
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
@@ -6429,13 +6429,13 @@ POST_OPS_DOWNSCALE_4x2F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
//c[0, 0-3]
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
@@ -6785,13 +6785,13 @@ POST_OPS_DOWNSCALE_3x2F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
//c[0, 0-3]
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
@@ -7090,13 +7090,13 @@ POST_OPS_DOWNSCALE_2x2F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 4) );
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
//c[0, 0-3]
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
@@ -7335,13 +7335,13 @@ POST_OPS_DOWNSCALE_1x2F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
}
//c[0, 0-3]
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
@@ -7698,12 +7698,12 @@ POST_OPS_DOWNSCALE_5x1F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 4) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
}
//c[0, 0-3]
@@ -8109,12 +8109,12 @@ POST_OPS_DOWNSCALE_4x1F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 4) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
}
//c[0, 0-3]
@@ -8464,12 +8464,12 @@ POST_OPS_DOWNSCALE_3x1F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 4) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
}
//c[0, 0-3]
@@ -8769,12 +8769,12 @@ POST_OPS_DOWNSCALE_2x1F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 4) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
}
//c[0, 0-3]
@@ -9013,12 +9013,12 @@ POST_OPS_DOWNSCALE_1x1F:
{
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_j + ( 0 * 4) );
}
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
}
//c[0, 0-3]


@@ -689,18 +689,18 @@ POST_OPS_DOWNSCALE_6x16F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 5 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
@@ -1299,18 +1299,18 @@ POST_OPS_DOWNSCALE_6x8F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 5 );
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 ) );
}
//c[0, 0-7]
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);


@@ -0,0 +1,367 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include <string.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
void packb_nr16_f32f32f32of32_row_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
);
void packb_nr16_f32f32f32of32_col_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
);
void packb_nr16_f32f32f32of32
(
float* pack_b_buffer,
const float* b,
const dim_t rs_b,
const dim_t cs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
)
{
if( cs_b == 1 )
{
packb_nr16_f32f32f32of32_row_major( pack_b_buffer, b,
rs_b, NC, KC, rs_p, cs_p );
}
else
{
packb_nr16_f32f32f32of32_col_major( pack_b_buffer, b,
cs_b, NC, KC, rs_p, cs_p );
}
}
void packb_nr16_f32f32f32of32_row_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
)
{
dim_t NR = 16;
__m256 a0;
__m256 b0;
dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
dim_t n_partial_pieces = NC % NR;
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
{
for ( dim_t kr = 0; kr < KC; kr += 1 )
{
a0 = _mm256_loadu_ps( b + ( jc + 0 ) + ( ldb * kr ) );
b0 = _mm256_loadu_ps( b + ( jc + 8 ) + ( ldb * kr ) );
//store to pack_b buffer
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) +
( ( kr * NR ) + 0 ), a0 );
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) +
( ( kr * NR ) + 8 ), b0 );
}
}
if( n_partial_pieces > 0 )
{
for ( dim_t kr = 0; kr < KC; kr += 1 )
{
// No point in vectorizing the fringe case since n_fringe is expected
// to be laid out contiguously in the pack buffer as nr_fringe*KC
// instead of 8*KC + 4*KC + ..., etc.
memcpy( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
( ( kr * NR ) + 0 ),
b + ( n_full_pieces_loop_limit + 0 ) + ( ldb * kr ),
n_partial_pieces * ( sizeof( float ) ) );
// Zero out padding data.
memset( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
( ( kr * NR ) + n_partial_pieces ),
0, ( NR - n_partial_pieces ) * sizeof( float ) );
}
}
*rs_b = NR;
*cs_b = 1;
}
#define LOAD_PS_8x8() \
a_reg[0] = _mm256_loadu_ps( b + ( ldb * ( jr + 0 ) ) + ( kr ) ); \
a_reg[1] = _mm256_loadu_ps( b + ( ldb * ( jr + 1 ) ) + ( kr ) ); \
a_reg[2] = _mm256_loadu_ps( b + ( ldb * ( jr + 2 ) ) + ( kr ) ); \
a_reg[3] = _mm256_loadu_ps( b + ( ldb * ( jr + 3 ) ) + ( kr ) ); \
a_reg[4] = _mm256_loadu_ps( b + ( ldb * ( jr + 4 ) ) + ( kr ) ); \
a_reg[5] = _mm256_loadu_ps( b + ( ldb * ( jr + 5 ) ) + ( kr ) ); \
a_reg[6] = _mm256_loadu_ps( b + ( ldb * ( jr + 6 ) ) + ( kr ) ); \
a_reg[7] = _mm256_loadu_ps( b + ( ldb * ( jr + 7 ) ) + ( kr ) ); \
#define K_FRINGE_MEMCPY_LOAD_PS_8x8() \
memcpy( buf0, b + ( ldb * ( jr + 0 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[0] = _mm256_loadu_ps( buf0 ); \
memcpy( buf1, b + ( ldb * ( jr + 1 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[1] = _mm256_loadu_ps( buf1 ); \
memcpy( buf2, b + ( ldb * ( jr + 2 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[2] = _mm256_loadu_ps( buf2 ); \
memcpy( buf3, b + ( ldb * ( jr + 3 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[3] = _mm256_loadu_ps( buf3 ); \
memcpy( buf4, b + ( ldb * ( jr + 4 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[4] = _mm256_loadu_ps( buf4 ); \
memcpy( buf5, b + ( ldb * ( jr + 5 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[5] = _mm256_loadu_ps( buf5 ); \
memcpy( buf6, b + ( ldb * ( jr + 6 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[6] = _mm256_loadu_ps( buf6 ); \
memcpy( buf7, b + ( ldb * ( jr + 7 ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[7] = _mm256_loadu_ps( buf7 ); \
#define N_FRINGE_LOAD_PS_8x8() \
for ( int i = 0; i < jr_elems; ++i ) \
{ \
a_reg[i] = _mm256_loadu_ps( b + ( ldb * ( jr + i ) ) + ( kr ) ); \
} \
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i ) \
{ \
a_reg[i] = _mm256_setzero_ps(); \
} \
#define KN_FRINGE_MEMCPY_LOAD_PS_8x8() \
for ( int i = 0; i < jr_elems; ++i ) \
{ \
memcpy( buf0, b + ( ldb * ( jr + i ) ) + ( kr ), \
k_partial_pieces * sizeof( float ) ); \
a_reg[i] = _mm256_loadu_ps( buf0 ); \
} \
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i ) \
{ \
a_reg[i] = _mm256_setzero_ps(); \
} \
#define UNPACK_PS_8x8() \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
b_reg[0] = _mm256_unpacklo_ps( a_reg[0], a_reg[1] ); \
b_reg[1] = _mm256_unpackhi_ps( a_reg[0], a_reg[1] ); \
b_reg[2] = _mm256_unpacklo_ps( a_reg[2], a_reg[3] ); \
b_reg[3] = _mm256_unpackhi_ps( a_reg[2], a_reg[3] ); \
b_reg[4] = _mm256_unpacklo_ps( a_reg[4], a_reg[5] ); \
b_reg[5] = _mm256_unpackhi_ps( a_reg[4], a_reg[5] ); \
b_reg[6] = _mm256_unpacklo_ps( a_reg[6], a_reg[7] ); \
b_reg[7] = _mm256_unpackhi_ps( a_reg[6], a_reg[7] ); \
#define UNPACK_PD_8x8() \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
a_reg[0] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[0], \
( __m256d )b_reg[2] ); \
a_reg[1] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[0], \
( __m256d )b_reg[2] ); \
a_reg[2] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[4], \
( __m256d )b_reg[6] ); \
a_reg[3] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[4], \
( __m256d )b_reg[6] ); \
a_reg[4] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[1], \
( __m256d )b_reg[3] ); \
a_reg[5] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[1], \
( __m256d )b_reg[3] ); \
a_reg[6] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[5], \
( __m256d )b_reg[7] ); \
a_reg[7] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[5], \
( __m256d )b_reg[7] ); \
#define PERMUTE_R1_8x8() \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
b_reg[0] = _mm256_permute2f128_ps( a_reg[0], a_reg[2], 0x20 ); /* Row 0 */ \
b_reg[1] = _mm256_permute2f128_ps( a_reg[0], a_reg[2], 0x31 ); /* Row 4 */ \
b_reg[2] = _mm256_permute2f128_ps( a_reg[4], a_reg[6], 0x20 ); /* Row 2 */ \
b_reg[3] = _mm256_permute2f128_ps( a_reg[4], a_reg[6], 0x31 ); /* Row 6 */ \
b_reg[4] = _mm256_permute2f128_ps( a_reg[1], a_reg[3], 0x20 ); /* Row 1 */ \
b_reg[5] = _mm256_permute2f128_ps( a_reg[1], a_reg[3], 0x31 ); /* Row 5 */ \
b_reg[6] = _mm256_permute2f128_ps( a_reg[5], a_reg[7], 0x20 ); /* Row 3 */ \
b_reg[7] = _mm256_permute2f128_ps( a_reg[5], a_reg[7], 0x31 ); /* Row 7 */ \
#define STORE_PS_8x8() \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 0 ) * NR ) + jr_offset, \
b_reg[0] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 1 ) * NR ) + jr_offset, \
b_reg[4] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 2 ) * NR ) + jr_offset, \
b_reg[2] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 3 ) * NR ) + jr_offset, \
b_reg[6] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 4 ) * NR ) + jr_offset, \
b_reg[1] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 5 ) * NR ) + jr_offset, \
b_reg[5] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 6 ) * NR ) + jr_offset, \
b_reg[3] ); \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 7 ) * NR ) + jr_offset, \
b_reg[7] ); \
#define K_FRINGE_SAFE_STORE_PS_8x8() \
a_reg[0] = b_reg[0]; \
a_reg[1] = b_reg[4]; \
a_reg[2] = b_reg[2]; \
a_reg[3] = b_reg[6]; \
a_reg[4] = b_reg[1]; \
a_reg[5] = b_reg[5]; \
a_reg[6] = b_reg[3]; \
a_reg[7] = b_reg[7]; \
for (int i = 0; i < k_partial_pieces; ++i ) \
{ \
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + i ) * NR ) + jr_offset, \
a_reg[i] ); \
} \
void packb_nr16_f32f32f32of32_col_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
)
{
const dim_t NR = 16;
const dim_t n_sub_blk_wdth = 8;
const dim_t k_reg_size = 8;
float buf0[8] = { 0 };
float buf1[8] = { 0 };
float buf2[8] = { 0 };
float buf3[8] = { 0 };
float buf4[8] = { 0 };
float buf5[8] = { 0 };
float buf6[8] = { 0 };
float buf7[8] = { 0 };
__m256 a_reg[8];
__m256 b_reg[8];
dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
dim_t n_partial_pieces = NC % NR;
dim_t k_full_pieces_loop_limit = ( KC / k_reg_size ) * k_reg_size;
dim_t k_partial_pieces = KC % k_reg_size;
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
{
for ( dim_t jr = jc; jr < jc + NR; jr += n_sub_blk_wdth )
{
dim_t jr_offset = jr % NR;
for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
{
LOAD_PS_8x8();
UNPACK_PS_8x8();
UNPACK_PD_8x8();
PERMUTE_R1_8x8();
STORE_PS_8x8();
}
if ( k_partial_pieces > 0 )
{
dim_t kr = k_full_pieces_loop_limit;
K_FRINGE_MEMCPY_LOAD_PS_8x8();
UNPACK_PS_8x8();
UNPACK_PD_8x8();
PERMUTE_R1_8x8();
K_FRINGE_SAFE_STORE_PS_8x8();
}
}
}
if( n_partial_pieces > 0 )
{
dim_t jc = n_full_pieces_loop_limit;
for ( dim_t jr = n_full_pieces_loop_limit; jr < NC; jr += n_sub_blk_wdth )
{
dim_t jr_offset = jr % NR;
dim_t jr_elems = ( ( NC - jr ) >= n_sub_blk_wdth ) ? n_sub_blk_wdth : ( NC - jr );
for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
{
N_FRINGE_LOAD_PS_8x8();
UNPACK_PS_8x8();
UNPACK_PD_8x8();
PERMUTE_R1_8x8();
STORE_PS_8x8();
}
if ( k_partial_pieces > 0 )
{
dim_t kr = k_full_pieces_loop_limit;
KN_FRINGE_MEMCPY_LOAD_PS_8x8();
UNPACK_PS_8x8();
UNPACK_PD_8x8();
PERMUTE_R1_8x8();
K_FRINGE_SAFE_STORE_PS_8x8();
}
}
}
*rs_b = NR;
*cs_b = 1;
}
#endif
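
A minimal usage sketch for the new NR=16 pack kernel above (the ceil-to-NR
buffer sizing is an assumption mirroring the reorder-size logic earlier in
this commit; the driver function below is hypothetical):

#include <stdlib.h>
#include "blis.h"

// Hypothetical driver: pack a KC x NC row-major block of B into the
// NR=16 panel layout produced by packb_nr16_f32f32f32of32.
void pack_b_block_example( const float* b, dim_t NC, dim_t KC )
{
    const dim_t NR = 16;
    const dim_t rs_b = NC;   // row-major source block
    const dim_t cs_b = 1;
    // Packed width is padded up to a multiple of NR; the fringe panel is zero-filled.
    dim_t n_padded = ( ( NC + NR - 1 ) / NR ) * NR;
    float* pack_b = ( float* )malloc( sizeof( float ) * KC * n_padded );
    dim_t rs_p = 0, cs_p = 0;

    packb_nr16_f32f32f32of32( pack_b, b, rs_b, cs_b, NC, KC, &rs_p, &cs_p );
    // On return rs_p == NR and cs_p == 1; each NR-wide panel occupies NR * KC floats.

    free( pack_b );
}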


@@ -322,7 +322,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64)
post_ops_list, post_ops_attr
);
// No leftover fringe after this podint.
// No leftover fringe after this point.
}
return;
}


@@ -125,13 +125,6 @@ void unpackb_nr32_bf16bf16f32of32_row_major
dim_t k_full_pieces = k_full_pieces_blks * 2;
dim_t k_partial_pieces = KC % 2;
// When KC is not a multiple of 2, it is padded to a multiple of 2 in the packed buffer.
dim_t KC_updated = KC;
if ( k_partial_pieces > 0 )
{
KC_updated += ( 2 - k_partial_pieces );
}
__m512i a0, c0;
__m512i a01;


@@ -839,14 +839,14 @@ POST_OPS_DOWNSCALE_5x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -903,8 +903,8 @@ POST_OPS_DOWNSCALE_5x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
//c[4, 0-15]
F32_SCL_MULRND(zmm24, selector1, zero_point0);
@@ -1745,14 +1745,14 @@ POST_OPS_DOWNSCALE_4x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -2485,12 +2485,12 @@ POST_OPS_DOWNSCALE_3x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -3066,10 +3066,10 @@ POST_OPS_DOWNSCALE_2x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -3484,8 +3484,8 @@ POST_OPS_DOWNSCALE_1x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -4205,14 +4205,14 @@ POST_OPS_DOWNSCALE_5x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -4257,8 +4257,8 @@ POST_OPS_DOWNSCALE_5x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
//c[4, 0-15]
F32_SCL_MULRND(zmm24, selector1, zero_point0);
@@ -4955,14 +4955,14 @@ POST_OPS_DOWNSCALE_4x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -5569,12 +5569,12 @@ POST_OPS_DOWNSCALE_3x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -6058,10 +6058,10 @@ POST_OPS_DOWNSCALE_2x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -6422,8 +6422,8 @@ POST_OPS_DOWNSCALE_1x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -6993,16 +6993,16 @@ POST_OPS_DOWNSCALE_5x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -7576,14 +7576,14 @@ POST_OPS_DOWNSCALE_4x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -8068,12 +8068,12 @@ POST_OPS_DOWNSCALE_3x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -8464,10 +8464,10 @@ POST_OPS_DOWNSCALE_2x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -8771,8 +8771,8 @@ POST_OPS_DOWNSCALE_1x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);

View File

@@ -986,6 +986,7 @@ POST_OPS_DOWNSCALE_6x64F:
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
}
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
@@ -1108,14 +1109,14 @@ POST_OPS_DOWNSCALE_6x64F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1 );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1 ) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
@@ -2174,14 +2175,14 @@ POST_OPS_DOWNSCALE_6x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);
@@ -2228,10 +2229,10 @@ POST_OPS_DOWNSCALE_6x48F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 5);
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5) );
}
//c[4, 0-15]
F32_SCL_MULRND(zmm24, selector1, zero_point0);
@@ -3013,18 +3014,18 @@ POST_OPS_DOWNSCALE_6x32F:
}
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
{
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 0 );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 1);
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 2 );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 3 );
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 4 );
zero_point5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
post_ops_attr.post_op_c_i + 5 );
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 0 ) );
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 1) );
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
zero_point5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 5 ) );
}
//c[0, 0-15]
F32_SCL_MULRND(zmm8, selector1, zero_point0);

View File

@@ -0,0 +1,506 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
void packb_nr64_f32f32f32of32_row_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
);
void packb_nr64_f32f32f32of32_col_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
);
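// Packs an NC x KC panel of B into pack_b_buffer in NR = 64 wide panels.
// Dispatches on how B is stored: unit column stride ( cs_b == 1 ) means
// row major storage, anything else is packed via the transposing
// column major kernel.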
void packb_nr64_f32f32f32of32
(
float* pack_b_buffer,
const float* b,
const dim_t rs_b,
const dim_t cs_b,
const dim_t NC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
)
{
if( cs_b == 1 )
{
packb_nr64_f32f32f32of32_row_major( pack_b_buffer, b,
rs_b, NC, KC, rs_p, cs_p );
}
else
{
packb_nr64_f32f32f32of32_col_major( pack_b_buffer, b,
cs_b, NC, KC, rs_p, cs_p );
}
}
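// Row stored B: each KC x 64 panel is copied row by row, 64 contiguous
// floats at a time, into a packed panel with row stride NR. The n
// remainder is handled with masked loads below.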
void packb_nr64_f32f32f32of32_row_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
)
{
dim_t NR = 64;
__m512 a0;
__m512 b0;
__m512 c0;
__m512 d0;
dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
dim_t n_partial_pieces = NC % NR;
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
{
for ( dim_t kr = 0; kr < KC; kr += 1 )
{
a0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 0 ) );
b0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 16 ) );
c0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 32 ) );
d0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 48 ) );
//store to pack_b buffer
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
( ( kr * NR ) + 0 ), a0 );
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
( ( kr * NR ) + 16 ), b0 );
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
( ( kr * NR ) + 32 ), c0 );
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
( ( kr * NR ) + 48 ), d0 );
}
}
if( n_partial_pieces > 0 )
{
dim_t n0_partial_rem = n_partial_pieces % 16;
dim_t n0_48 = n_partial_pieces / 48;
dim_t n0_32 = n_partial_pieces / 32;
dim_t n0_16 = n_partial_pieces / 16;
__mmask16 lmask_0 = _cvtu32_mask16( 0x0 );
__mmask16 lmask_1 = _cvtu32_mask16( 0x0 );
__mmask16 lmask_2 = _cvtu32_mask16( 0x0 );
__mmask16 lmask_3 = _cvtu32_mask16( 0x0 );
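// The n remainder ( < 64 ) is covered by up to four 16 wide masked loads:
// full masks for the complete 16 column chunks and a partial mask with
// n0_partial_rem bits set for the last chunk.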
if ( n0_48 > 0 )
{
lmask_0 = _cvtu32_mask16( 0xFFFF );
lmask_1 = _cvtu32_mask16( 0xFFFF );
lmask_2 = _cvtu32_mask16( 0xFFFF );
lmask_3 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
}
else if ( n0_32 > 0 )
{
lmask_0 = _cvtu32_mask16( 0xFFFF );
lmask_1 = _cvtu32_mask16( 0xFFFF );
lmask_2 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
}
else if ( n0_16 > 0 )
{
lmask_0 = _cvtu32_mask16( 0xFFFF );
lmask_1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
}
else
{
lmask_0 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
}
for ( dim_t kr = 0; kr < KC; kr += 1 )
{
a0 = _mm512_maskz_loadu_ps( lmask_0, b + ( ldb * kr ) +
( n_full_pieces_loop_limit + 0 ) );
b0 = _mm512_maskz_loadu_ps( lmask_1, b + ( ldb * kr ) +
( n_full_pieces_loop_limit + 16 ) );
c0 = _mm512_maskz_loadu_ps( lmask_2, b + ( ldb * kr ) +
( n_full_pieces_loop_limit + 32 ) );
d0 = _mm512_maskz_loadu_ps( lmask_3, b + ( ldb * kr ) +
( n_full_pieces_loop_limit + 48 ) );
//store to pack_b buffer
_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
( ( kr * NR ) + 0 ), a0 );
_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
( ( kr * NR ) + 16 ), b0 );
_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
( ( kr * NR ) + 32 ), c0 );
_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
( ( kr * NR ) + 48 ), d0 );
}
}
*rs_b = NR;
*cs_b = 1;
}
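// Helper macros for the column major kernel: load a 16 ( n ) x 16 ( k )
// block of column stored B, transpose it in registers ( unpacklo/hi ps
// and pd followed by two rounds of 64 bit lane permutes ) and store the
// resulting rows into the packed buffer; masked variants handle the n
// and k fringes.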
#define MASK_LOAD_PS_16x16(msk) \
a_reg[0] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 0 ) ) + ( kr ) ); \
a_reg[1] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 1 ) ) + ( kr ) ); \
a_reg[2] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 2 ) ) + ( kr ) ); \
a_reg[3] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 3 ) ) + ( kr ) ); \
a_reg[4] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 4 ) ) + ( kr ) ); \
a_reg[5] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 5 ) ) + ( kr ) ); \
a_reg[6] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 6 ) ) + ( kr ) ); \
a_reg[7] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 7 ) ) + ( kr ) ); \
a_reg[8] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 8 ) ) + ( kr ) ); \
a_reg[9] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 9 ) ) + ( kr ) ); \
a_reg[10] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 10 ) ) + ( kr ) ); \
a_reg[11] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 11 ) ) + ( kr ) ); \
a_reg[12] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 12 ) ) + ( kr ) ); \
a_reg[13] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 13 ) ) + ( kr ) ); \
a_reg[14] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 14 ) ) + ( kr ) ); \
a_reg[15] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 15 ) ) + ( kr ) ); \
#define MASK_ARR_LOAD_PS_16x16(msk_n_arr) \
a_reg[0] = _mm512_maskz_loadu_ps( msk_n_arr[0], b + ( ldb * ( jr + 0 ) ) + ( kr ) ); \
a_reg[1] = _mm512_maskz_loadu_ps( msk_n_arr[1], b + ( ldb * ( jr + 1 ) ) + ( kr ) ); \
a_reg[2] = _mm512_maskz_loadu_ps( msk_n_arr[2], b + ( ldb * ( jr + 2 ) ) + ( kr ) ); \
a_reg[3] = _mm512_maskz_loadu_ps( msk_n_arr[3], b + ( ldb * ( jr + 3 ) ) + ( kr ) ); \
a_reg[4] = _mm512_maskz_loadu_ps( msk_n_arr[4], b + ( ldb * ( jr + 4 ) ) + ( kr ) ); \
a_reg[5] = _mm512_maskz_loadu_ps( msk_n_arr[5], b + ( ldb * ( jr + 5 ) ) + ( kr ) ); \
a_reg[6] = _mm512_maskz_loadu_ps( msk_n_arr[6], b + ( ldb * ( jr + 6 ) ) + ( kr ) ); \
a_reg[7] = _mm512_maskz_loadu_ps( msk_n_arr[7], b + ( ldb * ( jr + 7 ) ) + ( kr ) ); \
a_reg[8] = _mm512_maskz_loadu_ps( msk_n_arr[8], b + ( ldb * ( jr + 8 ) ) + ( kr ) ); \
a_reg[9] = _mm512_maskz_loadu_ps( msk_n_arr[9], b + ( ldb * ( jr + 9 ) ) + ( kr ) ); \
a_reg[10] = _mm512_maskz_loadu_ps( msk_n_arr[10], b + ( ldb * ( jr + 10 ) ) + ( kr ) ); \
a_reg[11] = _mm512_maskz_loadu_ps( msk_n_arr[11], b + ( ldb * ( jr + 11 ) ) + ( kr ) ); \
a_reg[12] = _mm512_maskz_loadu_ps( msk_n_arr[12], b + ( ldb * ( jr + 12 ) ) + ( kr ) ); \
a_reg[13] = _mm512_maskz_loadu_ps( msk_n_arr[13], b + ( ldb * ( jr + 13 ) ) + ( kr ) ); \
a_reg[14] = _mm512_maskz_loadu_ps( msk_n_arr[14], b + ( ldb * ( jr + 14 ) ) + ( kr ) ); \
a_reg[15] = _mm512_maskz_loadu_ps( msk_n_arr[15], b + ( ldb * ( jr + 15 ) ) + ( kr ) ); \
#define UNPACK_PS_16x16() \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
b_reg[0] = _mm512_unpacklo_ps( a_reg[0], a_reg[1] ); \
b_reg[1] = _mm512_unpackhi_ps( a_reg[0], a_reg[1] ); \
b_reg[2] = _mm512_unpacklo_ps( a_reg[2], a_reg[3] ); \
b_reg[3] = _mm512_unpackhi_ps( a_reg[2], a_reg[3] ); \
b_reg[4] = _mm512_unpacklo_ps( a_reg[4], a_reg[5] ); \
b_reg[5] = _mm512_unpackhi_ps( a_reg[4], a_reg[5] ); \
b_reg[6] = _mm512_unpacklo_ps( a_reg[6], a_reg[7] ); \
b_reg[7] = _mm512_unpackhi_ps( a_reg[6], a_reg[7] ); \
b_reg[8] = _mm512_unpacklo_ps( a_reg[8], a_reg[9] ); \
b_reg[9] = _mm512_unpackhi_ps( a_reg[8], a_reg[9] ); \
b_reg[10] = _mm512_unpacklo_ps( a_reg[10], a_reg[11] ); \
b_reg[11] = _mm512_unpackhi_ps( a_reg[10], a_reg[11] ); \
b_reg[12] = _mm512_unpacklo_ps( a_reg[12], a_reg[13] ); \
b_reg[13] = _mm512_unpackhi_ps( a_reg[12], a_reg[13] ); \
b_reg[14] = _mm512_unpacklo_ps( a_reg[14], a_reg[15] ); \
b_reg[15] = _mm512_unpackhi_ps( a_reg[14], a_reg[15] ); \
#define UNPACK_PD_16x16() \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
a_reg[0] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[0], ( __m512d )b_reg[2] ); \
a_reg[1] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[0], ( __m512d )b_reg[2] ); \
a_reg[2] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[4], ( __m512d )b_reg[6] ); \
a_reg[3] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[4], ( __m512d )b_reg[6] ); \
a_reg[4] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[8], ( __m512d )b_reg[10] ); \
a_reg[5] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[8], ( __m512d )b_reg[10] ); \
a_reg[6] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[12], ( __m512d )b_reg[14] ); \
a_reg[7] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[12], ( __m512d )b_reg[14] ); \
a_reg[8] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[1], ( __m512d )b_reg[3] ); \
a_reg[9] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[1], ( __m512d )b_reg[3] ); \
a_reg[10] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[5], ( __m512d )b_reg[7] ); \
a_reg[11] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[5], ( __m512d )b_reg[7] ); \
a_reg[12] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[9], ( __m512d )b_reg[11] ); \
a_reg[13] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[9], ( __m512d )b_reg[11] ); \
a_reg[14] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[13], ( __m512d )b_reg[15] ); \
a_reg[15] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[13], ( __m512d )b_reg[15] ); \
#define PERMUTE_R1_16x16(selector1, selector2) \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
b_reg[0] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[0], selector1, ( __m512d )a_reg[2] ); \
b_reg[1] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[0], selector2, ( __m512d )a_reg[2] ); \
b_reg[2] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[4], selector1, ( __m512d )a_reg[6] ); \
b_reg[3] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[4], selector2, ( __m512d )a_reg[6] ); \
b_reg[4] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[8], selector1, ( __m512d )a_reg[10] ); \
b_reg[5] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[8], selector2, ( __m512d )a_reg[10] ); \
b_reg[6] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[12], selector1, ( __m512d )a_reg[14] ); \
b_reg[7] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[12], selector2, ( __m512d )a_reg[14] ); \
b_reg[8] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[1], selector1, ( __m512d )a_reg[3] ); \
b_reg[9] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[1], selector2, ( __m512d )a_reg[3] ); \
b_reg[10] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[5], selector1, ( __m512d )a_reg[7] ); \
b_reg[11] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[5], selector2, ( __m512d )a_reg[7] ); \
b_reg[12] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[9], selector1, ( __m512d )a_reg[11] ); \
b_reg[13] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[9], selector2, ( __m512d )a_reg[11] ); \
b_reg[14] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[13], selector1, ( __m512d )a_reg[15] ); \
b_reg[15] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[13], selector2, ( __m512d )a_reg[15] ); \
#define PERMUTE_R2_16x16(selector1_1, selector2_1) \
/* Even indices contain lo parts, odd indices contain hi parts. */ \
a_reg[0] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[0], selector1_1, \
( __m512d )b_reg[2] ); /* Row 0 */ \
a_reg[1] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[0], selector2_1, \
( __m512d )b_reg[2] ); /* Row 4 */ \
a_reg[2] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[4], selector1_1, \
( __m512d )b_reg[6] ); /* Row 2 */ \
a_reg[3] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[4], selector2_1, \
( __m512d )b_reg[6] ); /* Row 6 */ \
a_reg[4] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[8], selector1_1, \
( __m512d )b_reg[10] ); /* Row 1 */ \
a_reg[5] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[8], selector2_1, \
( __m512d )b_reg[10] ); /* Row 5 */ \
a_reg[6] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[12], selector1_1, \
( __m512d )b_reg[14] ); /* Row 3 */ \
a_reg[7] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[12], selector2_1, \
( __m512d )b_reg[14] ); /* Row 7 */ \
a_reg[8] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[1], selector1_1, \
( __m512d )b_reg[3] ); /* Row 8 */ \
a_reg[9] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[1], selector2_1, \
( __m512d )b_reg[3] ); /* Row 12 */ \
a_reg[10] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[5], selector1_1, \
( __m512d )b_reg[7] ); /* Row 10 */ \
a_reg[11] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[5], selector2_1, \
( __m512d )b_reg[7] ); /* Row 14 */ \
a_reg[12] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[9], selector1_1, \
( __m512d )b_reg[11] ); /* Row 9 */ \
a_reg[13] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[9], selector2_1, \
( __m512d )b_reg[11] ); /* Row 13 */ \
a_reg[14] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[13], selector1_1, \
( __m512d )b_reg[15] ); /* Row 11 */ \
a_reg[15] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[13], selector2_1, \
( __m512d )b_reg[15] ); /* Row 15 */ \
#define STORE_PS_16x16() \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 0 ) * NR ) + jr_offset, a_reg[0] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 1 ) * NR ) + jr_offset, a_reg[4] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 2 ) * NR ) + jr_offset, a_reg[2] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 3 ) * NR ) + jr_offset, a_reg[6] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 4 ) * NR ) + jr_offset, a_reg[1] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 5 ) * NR ) + jr_offset, a_reg[5] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 6 ) * NR ) + jr_offset, a_reg[3] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 7 ) * NR ) + jr_offset, a_reg[7] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 8 ) * NR ) + jr_offset, a_reg[8] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 9 ) * NR ) + jr_offset, a_reg[12] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 10 ) * NR ) + jr_offset, a_reg[10] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 11 ) * NR ) + jr_offset, a_reg[14] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 12 ) * NR ) + jr_offset, a_reg[9] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 13 ) * NR ) + jr_offset, a_reg[13] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 14 ) * NR ) + jr_offset, a_reg[11] ); \
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 15 ) * NR ) + jr_offset, a_reg[15] ); \
#define MASK_STORE_PS_16x16(msk_arr) \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 0 ) * NR ) + jr_offset, \
msk_arr[0], a_reg[0] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 1 ) * NR ) + jr_offset, \
msk_arr[1], a_reg[4] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 2 ) * NR ) + jr_offset, \
msk_arr[2], a_reg[2] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 3 ) * NR ) + jr_offset, \
msk_arr[3], a_reg[6] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 4 ) * NR ) + jr_offset, \
msk_arr[4], a_reg[1] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 5 ) * NR ) + jr_offset, \
msk_arr[5], a_reg[5] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 6 ) * NR ) + jr_offset, \
msk_arr[6], a_reg[3] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 7 ) * NR ) + jr_offset, \
msk_arr[7], a_reg[7] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 8 ) * NR ) + jr_offset, \
msk_arr[8], a_reg[8] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 9 ) * NR ) + jr_offset, \
msk_arr[9], a_reg[12] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 10 ) * NR ) + jr_offset, \
msk_arr[10], a_reg[10] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 11 ) * NR ) + jr_offset, \
msk_arr[11], a_reg[14] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 12 ) * NR ) + jr_offset, \
msk_arr[12], a_reg[9] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 13 ) * NR ) + jr_offset, \
msk_arr[13], a_reg[13] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 14 ) * NR ) + jr_offset, \
msk_arr[14], a_reg[11] ); \
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 15 ) * NR ) + jr_offset, \
msk_arr[15], a_reg[15] ); \
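// Column stored B: 16x16 blocks are loaded column wise and transposed in
// registers so the packed buffer ends up in the same row major, NR = 64
// panel layout produced by the row major kernel.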
void packb_nr64_f32f32f32of32_col_major
(
float* pack_b_buffer,
const float* b,
const dim_t ldb,
const dim_t NC,
const dim_t KC,
dim_t* rs_b,
dim_t* cs_b
)
{
const dim_t NR = 64;
const dim_t n_sub_blk_wdth = 16;
const dim_t k_reg_size = 16;
__m512 a_reg[16];
__m512 b_reg[16];
dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
dim_t n_partial_pieces = NC % NR;
dim_t k_full_pieces_loop_limit = ( KC / k_reg_size ) * k_reg_size;
dim_t k_partial_pieces = KC % k_reg_size;
// First permute sequences.
__m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB );
__m512i selector2 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF );
// Second permute sequences.
__m512i selector1_1 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB );
__m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF );
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
{
for ( dim_t jr = jc; jr < jc + NR; jr += n_sub_blk_wdth )
{
dim_t jr_offset = jr % NR;
__mmask16 msk = _cvtu32_mask16( 0xFFFF );
for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
{
MASK_LOAD_PS_16x16(msk);
UNPACK_PS_16x16();
UNPACK_PD_16x16();
PERMUTE_R1_16x16(selector1, selector2);
PERMUTE_R2_16x16(selector1_1, selector2_1);
STORE_PS_16x16();
}
if ( k_partial_pieces > 0 )
{
msk = _cvtu32_mask16( 0xFFFF >> ( k_reg_size - k_partial_pieces ) );
dim_t kr = k_full_pieces_loop_limit;
MASK_LOAD_PS_16x16(msk);
UNPACK_PS_16x16();
UNPACK_PD_16x16();
PERMUTE_R1_16x16(selector1, selector2);
PERMUTE_R2_16x16(selector1_1, selector2_1);
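// Only the first k_partial_pieces transposed rows are valid; build store
// masks so the remaining rows of this k fringe are not written.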
__mmask16 msk_arr[16];
for ( int i = 0; i < k_partial_pieces; ++i )
{
msk_arr[i] = _cvtu32_mask16(0xFFFF);
}
for ( int i = k_partial_pieces; i < 16; ++i )
{
msk_arr[i] = _cvtu32_mask16(0x0);
}
MASK_STORE_PS_16x16(msk_arr);
}
}
}
if( n_partial_pieces > 0 )
{
dim_t jc = n_full_pieces_loop_limit;
for ( dim_t jr = n_full_pieces_loop_limit; jr < NC; jr += n_sub_blk_wdth )
{
dim_t jr_offset = jr % NR;
__mmask16 msk_n_arr[16];
dim_t jr_elems = ( ( NC - jr ) >= n_sub_blk_wdth ) ? n_sub_blk_wdth : ( NC - jr );
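// Enable loads only for the jr_elems valid columns of this sub block;
// the remaining lanes use an all zero mask so B is never read past NC.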
for ( int i = 0; i < jr_elems; ++i )
{
msk_n_arr[i] = _cvtu32_mask16( 0xFFFF );
}
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i )
{
msk_n_arr[i] = _cvtu32_mask16( 0x0 );
}
for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
{
MASK_ARR_LOAD_PS_16x16(msk_n_arr);
UNPACK_PS_16x16();
UNPACK_PD_16x16();
PERMUTE_R1_16x16(selector1, selector2);
PERMUTE_R2_16x16(selector1_1, selector2_1);
STORE_PS_16x16();
}
if ( k_partial_pieces > 0 )
{
for ( int i = 0; i < jr_elems; ++i )
{
msk_n_arr[i] = _cvtu32_mask16( 0xFFFF &
( 0xFFFF >> ( k_reg_size - k_partial_pieces ) ) );
}
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i )
{
msk_n_arr[i] = _cvtu32_mask16( 0x0 );
}
dim_t kr = k_full_pieces_loop_limit;
MASK_ARR_LOAD_PS_16x16(msk_n_arr);
UNPACK_PS_16x16();
UNPACK_PD_16x16();
PERMUTE_R1_16x16(selector1, selector2);
PERMUTE_R2_16x16(selector1_1, selector2_1);
__mmask16 msk_arr[16];
for (int i = 0; i < k_partial_pieces; ++i)
{
msk_arr[i] = _cvtu32_mask16(0xFFFF);
}
for (int i = k_partial_pieces; i < 16; ++i)
{
msk_arr[i] = _cvtu32_mask16(0x0);
}
MASK_STORE_PS_16x16(msk_arr);
}
}
}
*rs_b = NR;
*cs_b = 1;
}
#endif