mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Adding support for AOCL_ENABLE_INSTRUCTIONS for f32 LPGEMM API.
-Currently lpgemm sets the context (block sizes and micro-kernels) based on the ISA of the machine it is being executed on. However this approach does not give the flexibility to select a different context at runtime. In order to enable runtime selection of context, the context initialization is modified to read the AOCL_ENABLE_INSTRUCTIONS env variable and set the context based on the same. As part of this commit, only f32 context selection is enabled. -Bug fixes in scale ops in f32 micro-kernels and GEMV path selection. -Added vectorized f32 packing kernels for NR=16(AVX2) and NR=64(AVX512). This is only for B matrix and helps remove dependency of f32 lpgemm api on the BLIS packing framework. AMD Internal: [CPUPL-5959] Change-Id: I4b459aaf33c54423952f89905ba43cf119ce20f6
This commit is contained in:
committed by
sireesha.sanga
parent
0d5c09d042
commit
880a971dc5
@@ -58,11 +58,6 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32)
|
||||
// Initialize lpgemm context.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// Query the global cntx.
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
num_t dt = BLIS_FLOAT;
|
||||
|
||||
AOCL_MATRIX_TYPE input_mat_type;
|
||||
bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );
|
||||
|
||||
@@ -71,15 +66,18 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32)
|
||||
return 0; // A reorder not supported.
|
||||
}
|
||||
|
||||
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||
const dim_t NR = lpgemm_get_block_size_NR_global_cntx( F32F32F32OF32 );
|
||||
|
||||
// Extra space since packing does width in multiples of NR.
|
||||
dim_t n_reorder;
|
||||
if(n == 1)
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
|
||||
{
|
||||
//When n == 1, LPGEMV doesn't expect B to be reordered.
|
||||
n_reorder = 1;
|
||||
}else
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
n_reorder = ( ( n + NR - 1 ) / NR ) * NR;
|
||||
}
|
||||
@@ -134,11 +132,6 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
// Initialize lpgemm context.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// Query the global cntx.
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
num_t dt = BLIS_FLOAT;
|
||||
|
||||
AOCL_MATRIX_TYPE input_mat_type;
|
||||
bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );
|
||||
|
||||
@@ -147,20 +140,14 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
return; // A reorder not supported.
|
||||
}
|
||||
|
||||
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx );
|
||||
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx );
|
||||
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||
// Query the context for various blocksizes.
|
||||
lpgemm_cntx_t* lcntx = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
inc_t rs_p = NR;
|
||||
|
||||
float one_local = *PASTEMAC(s,1);
|
||||
float* restrict kappa_cast = &one_local;
|
||||
|
||||
// Set the schema to "row stored column panels" to indicate packing to
|
||||
// conventional row-stored column panels.
|
||||
pack_t schema = BLIS_PACKED_COL_PANELS;
|
||||
trans_t transc = BLIS_NO_TRANSPOSE;
|
||||
conj_t conjc = bli_extract_conj( transc );
|
||||
dim_t rs_b_reorder = 0;
|
||||
dim_t cs_b_reorder = 0;
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
@@ -170,9 +157,10 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
dim_t n_threads = bli_rntm_num_threads( &rntm_g );
|
||||
n_threads = ( n_threads > 0 ) ? n_threads : 1;
|
||||
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
//When n == 1, B marix becomes a vector.
|
||||
//Reordering is avoided so that LPGEMV can process it efficiently.
|
||||
if(n == 1)
|
||||
if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
|
||||
{
|
||||
if(rs_b == 1)
|
||||
{
|
||||
@@ -186,6 +174,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
@@ -220,15 +209,9 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
&nc0, &n_sub_updated
|
||||
);
|
||||
|
||||
// Compute the total number of iterations we'll need.
|
||||
dim_t n_iter = ( nc0 + NR - 1 ) / NR;
|
||||
|
||||
for ( dim_t pc = 0; pc < k; pc += KC )
|
||||
{
|
||||
dim_t kc0 = bli_min( ( k - pc ), KC );
|
||||
inc_t ps_p = kc0 * NR;
|
||||
|
||||
const float* b_temp = input_buf_addr + ( jc * cs_b ) + ( pc * rs_b );
|
||||
|
||||
// The offsets are calculated in such a way that it resembles
|
||||
// the reorder buffer traversal in single threaded reordering.
|
||||
@@ -265,34 +248,13 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
|
||||
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
|
||||
// + ( n_sub_updated * pc ) <traverse block 5>
|
||||
// + ( NC' * kc0_updated) <traverse block 6>
|
||||
float* p_temp = reorder_buf_addr + ( jc_cur_loop * k ) +
|
||||
( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 );
|
||||
|
||||
dim_t jr, it;
|
||||
// Iterate over every logical micropanel in the source matrix.
|
||||
for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 )
|
||||
{
|
||||
dim_t panel_dim_i = bli_min( NR, nc0 - jr );
|
||||
|
||||
const float* b_use = b_temp + ( jr * cs_b );
|
||||
float* p_use = p_temp;
|
||||
|
||||
PASTEMAC(s,packm_cxk)
|
||||
(
|
||||
conjc,
|
||||
schema,
|
||||
panel_dim_i,
|
||||
NR,
|
||||
kc0,
|
||||
kc0,
|
||||
kappa_cast,
|
||||
( float* )b_use, cs_b, rs_b,
|
||||
p_use, rs_p,
|
||||
cntx
|
||||
);
|
||||
|
||||
p_temp += ps_p;
|
||||
}
|
||||
( ( lpgemm_pack_f32 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
reorder_buf_addr + ( jc_cur_loop * k ) +
|
||||
( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ),
|
||||
input_buf_addr + ( rs_b * pc ) + ( cs_b * jc ),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder
|
||||
);
|
||||
}
|
||||
|
||||
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
|
||||
|
||||
@@ -41,6 +41,7 @@
|
||||
#define LPGEMM_BLKSZ_MAP_ZEN4 \
|
||||
XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(F32F32F32OF32, 192, 8064, 512, 6, 64, 1, 6, 64, 1) \
|
||||
XMACRO(BF16BF16F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(BF16S4F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
@@ -50,12 +51,27 @@
|
||||
#define LPGEMM_BLKSZ_MAP_ZEN \
|
||||
XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
|
||||
XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
|
||||
#define LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN \
|
||||
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
|
||||
|
||||
// The STMACRO follows the format MT, NT, KT which are SUP switch thresholds.
|
||||
// ID = One of the AOCL_OPERATION_TYPE enum.
|
||||
#define LPGEMM_SUP_THRES_MAP_ZEN4 \
|
||||
STMACRO(F32F32F32OF32, 682, 512, 240) \
|
||||
|
||||
#define LPGEMM_SUP_THRES_MAP_ZEN \
|
||||
STMACRO(F32F32F32OF32, 512, 200, 240) \
|
||||
|
||||
#define LPGEMM_SUP_THRES_UPD_MAP_ZEN4_TO_ZEN \
|
||||
STMACRO(F32F32F32OF32, 512, 200, 240) \
|
||||
|
||||
#define LPGEMM_ELTWISE_OPS_BLKSZ_MAP_ZEN4 \
|
||||
XMACRO(BF16OF32, 144, 1024, 2048, 6, 64) \
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ static lpgemm_eltwise_ops_cntx_t
|
||||
global_eltwise_ops_cntx_t_list[AOCL_ELTWISE_OPS_OPERATION_TYPE_LEN] \
|
||||
__attribute__((aligned(64))); //Post-ops only utils without gemm.
|
||||
|
||||
static arch_t global_lpgemm_enable_arch = BLIS_ARCH_ERROR;
|
||||
|
||||
// This array is to store function pointers to jit generated kernels.
|
||||
static void* global_jit_kernels[ LPGEMM_BF16_MR ]
|
||||
[ ( LPGEMM_BF16_NR / NUM_F32_ELEMS_PER_ZMM ) + 1 ]
|
||||
@@ -67,6 +69,25 @@ static void* global_jit_kernels[ LPGEMM_BF16_MR ]
|
||||
|
||||
static bli_pthread_once_t once_check_lpgemm_func_map_init = BLIS_PTHREAD_ONCE_INIT;
|
||||
|
||||
static void _lpgemm_init_enable_arch()
|
||||
{
|
||||
arch_t arch_id = bli_arch_query_id();
|
||||
bool enbl_instr = bli_aocl_enable_instruction_query();
|
||||
|
||||
if ( ( enbl_instr == TRUE ) &&
|
||||
( ( arch_id == BLIS_ARCH_ZEN3 ) ||
|
||||
( arch_id == BLIS_ARCH_ZEN2 ) ||
|
||||
( arch_id == BLIS_ARCH_ZEN ) ) )
|
||||
{
|
||||
global_lpgemm_enable_arch = BLIS_ARCH_ZEN3;
|
||||
}
|
||||
}
|
||||
|
||||
arch_t lpgemm_get_enabled_arch()
|
||||
{
|
||||
return global_lpgemm_enable_arch;
|
||||
}
|
||||
|
||||
static void _lpgemm_util_cntx_init_func_map()
|
||||
{
|
||||
#define UMACRO(ID,FUNC_PTR) global_util_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR;
|
||||
@@ -175,6 +196,13 @@ static void _lpgemm_cntx_init_func_map()
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// If arch is updated at runtime, it is expeceted to be honoured.
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
|
||||
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
|
||||
}
|
||||
}
|
||||
else if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
|
||||
{
|
||||
@@ -183,6 +211,12 @@ static void _lpgemm_cntx_init_func_map()
|
||||
LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI
|
||||
LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI
|
||||
#endif
|
||||
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2
|
||||
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
|
||||
}
|
||||
}
|
||||
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
|
||||
{
|
||||
@@ -263,6 +297,11 @@ static void _lpgemm_cntx_init_blksz_map()
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
|
||||
{
|
||||
LPGEMM_BLKSZ_MAP_ZEN4
|
||||
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
LPGEMM_BLKSZ_UPD_MAP_ZEN4_TO_ZEN
|
||||
}
|
||||
}
|
||||
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
|
||||
{
|
||||
@@ -276,6 +315,45 @@ static void _lpgemm_cntx_init_blksz_map()
|
||||
#undef XMACRO
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_set_sup_thres_global_cntx
|
||||
(
|
||||
AOCL_OPERATION_TYPE op_type,
|
||||
dim_t MT,
|
||||
dim_t NT,
|
||||
dim_t KT
|
||||
)
|
||||
{
|
||||
global_cntx_t_list[op_type].sup_thres.MT = MT;
|
||||
global_cntx_t_list[op_type].sup_thres.NT = NT;
|
||||
global_cntx_t_list[op_type].sup_thres.KT = KT;
|
||||
}
|
||||
|
||||
static void _lpgemm_cntx_init_sup_thres_map()
|
||||
{
|
||||
#define STMACRO(ID,MT,NT,KT) \
|
||||
lpgemm_set_sup_thres_global_cntx(ID, MT, NT, KT); \
|
||||
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == TRUE )
|
||||
{
|
||||
LPGEMM_SUP_THRES_MAP_ZEN4
|
||||
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
LPGEMM_SUP_THRES_UPD_MAP_ZEN4_TO_ZEN
|
||||
}
|
||||
}
|
||||
else if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
|
||||
{
|
||||
LPGEMM_SUP_THRES_MAP_ZEN
|
||||
}
|
||||
else
|
||||
{
|
||||
LPGEMM_SUP_THRES_MAP_ZEN
|
||||
}
|
||||
|
||||
#undef STMACRO
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_set_block_sizes_global_eltwise_ops_cntx
|
||||
(
|
||||
AOCL_ELTWISE_OPS_OPERATION_TYPE op_type,
|
||||
@@ -317,8 +395,10 @@ static void _lpgemm_eltwise_ops_cntx_init_blksz_map()
|
||||
|
||||
static void lpgemm_cntx_init_map()
|
||||
{
|
||||
_lpgemm_init_enable_arch();
|
||||
_lpgemm_cntx_init_func_map();
|
||||
_lpgemm_cntx_init_blksz_map();
|
||||
_lpgemm_cntx_init_sup_thres_map();
|
||||
_lpgemm_eltwise_ops_cntx_init_blksz_map();
|
||||
_lpgemm_eltwise_ops_cntx_init_func_map();
|
||||
_lpgemm_util_cntx_init_func_map();
|
||||
@@ -375,6 +455,21 @@ dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type )
|
||||
return global_cntx_t_list[op_type].blksz.MR;
|
||||
}
|
||||
|
||||
dim_t lpgemm_get_sup_thres_MT_global_cntx( AOCL_OPERATION_TYPE op_type )
|
||||
{
|
||||
return global_cntx_t_list[op_type].sup_thres.MT;
|
||||
}
|
||||
|
||||
dim_t lpgemm_get_sup_thres_NT_global_cntx( AOCL_OPERATION_TYPE op_type )
|
||||
{
|
||||
return global_cntx_t_list[op_type].sup_thres.NT;
|
||||
}
|
||||
|
||||
dim_t lpgemm_get_sup_thres_KT_global_cntx( AOCL_OPERATION_TYPE op_type )
|
||||
{
|
||||
return global_cntx_t_list[op_type].sup_thres.KT;
|
||||
}
|
||||
|
||||
void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs )
|
||||
{
|
||||
*rs = lcntx->pack_s.packa_rs;
|
||||
|
||||
@@ -61,6 +61,14 @@ dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type );
|
||||
|
||||
dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type );
|
||||
|
||||
dim_t lpgemm_get_sup_thres_MT_global_cntx( AOCL_OPERATION_TYPE op_type );
|
||||
|
||||
dim_t lpgemm_get_sup_thres_NT_global_cntx( AOCL_OPERATION_TYPE op_type );
|
||||
|
||||
dim_t lpgemm_get_sup_thres_KT_global_cntx( AOCL_OPERATION_TYPE op_type );
|
||||
|
||||
arch_t lpgemm_get_enabled_arch();
|
||||
|
||||
void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs );
|
||||
|
||||
void lpgemm_get_packb_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs );
|
||||
|
||||
@@ -44,7 +44,7 @@
|
||||
// TODO: Add reference kernels for BF16/VNNI kernels for ISA combinations
|
||||
// that is not supported.
|
||||
|
||||
// Genoa
|
||||
// AVX512 + VNNI + BF16
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
|
||||
@@ -54,6 +54,9 @@
|
||||
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
@@ -65,12 +68,16 @@
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
|
||||
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
|
||||
PBMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
|
||||
#define LPGEMM_UNPACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
UBMACRO(BF16BF16F32OF32, unpackb_nr64_bf16bf16f32of32)
|
||||
|
||||
@@ -93,6 +100,7 @@
|
||||
UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \
|
||||
|
||||
|
||||
// AVX512 + VNNI
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
|
||||
@@ -102,6 +110,9 @@
|
||||
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
@@ -113,18 +124,23 @@
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
|
||||
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
|
||||
PBSMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
|
||||
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI \
|
||||
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \
|
||||
UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \
|
||||
UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \
|
||||
|
||||
|
||||
// AVX512
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX512 \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
|
||||
@@ -134,6 +150,9 @@
|
||||
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_TO_AVX2 \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
@@ -145,6 +164,7 @@
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, NULL) \
|
||||
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
@@ -152,13 +172,16 @@
|
||||
PBMACRO(BF16S4F32OF32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, NULL) \
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_TO_AVX2 \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
|
||||
#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512 \
|
||||
UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \
|
||||
UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \
|
||||
UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \
|
||||
|
||||
// Milan
|
||||
|
||||
// AVX2
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX2 \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, NULL) \
|
||||
@@ -179,6 +202,7 @@
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX2 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, NULL) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, NULL) \
|
||||
KMACRO(BF16S4F32OF32, NULL) \
|
||||
PBMACRO(S8S8S32OS32, NULL) \
|
||||
|
||||
@@ -75,25 +75,9 @@ void lpgemm_pack_a_f32f32f32of32
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
void lpgemm_pack_b_f32f32f32of32
|
||||
(
|
||||
const float* input_buf_addr_b,
|
||||
float* reorder_buf_addr_b,
|
||||
const dim_t n,
|
||||
const dim_t k,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t ps_p,
|
||||
const dim_t NR,
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
LPGEMV(float, float, float, f32f32f32of32)
|
||||
{
|
||||
cntx_t *cntx = bli_gks_query_cntx();
|
||||
num_t dt = BLIS_FLOAT;
|
||||
|
||||
const float* a_use = (float*)a;
|
||||
inc_t rs_a_use = rs_a;
|
||||
inc_t cs_a_use = cs_a;
|
||||
@@ -101,21 +85,20 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
float* b_use = (float*)b;
|
||||
inc_t rs_b_use = rs_b;
|
||||
inc_t cs_b_use = cs_b;
|
||||
inc_t ps_b_use;
|
||||
|
||||
siz_t mem_a_size_req = 0;
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_b_size_req = 0;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_b_size_req = 0;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
|
||||
float* pack_a_buffer_f32f32f32of32;
|
||||
float* pack_b_buffer_f32f32f32of32;
|
||||
|
||||
// Query the context for various blocksizes.
|
||||
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NR, cntx);
|
||||
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_NC, cntx);
|
||||
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx);
|
||||
const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt(dt, BLIS_KC, cntx);
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
float *c_use = NULL;
|
||||
@@ -161,9 +144,9 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
|
||||
// Compute the IC loop thread range for the current thread.
|
||||
dim_t ic_start, ic_end;
|
||||
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_ic.n_way );
|
||||
thread_ic.work_id = thread->tid;
|
||||
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_ic.n_way );
|
||||
thread_ic.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
||||
|
||||
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
|
||||
@@ -219,9 +202,9 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
{
|
||||
// Compute the JC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_jc.n_way );
|
||||
thread_jc.work_id = thread->tid;
|
||||
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_jc.n_way );
|
||||
thread_jc.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
|
||||
|
||||
if ( mtag_a == PACK )
|
||||
@@ -296,16 +279,13 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
// Set the strides for pack buffer.
|
||||
rs_b_use = NR;
|
||||
cs_b_use = 1;
|
||||
ps_b_use = kc0;
|
||||
|
||||
lpgemm_pack_b_f32f32f32of32
|
||||
(
|
||||
( b + ( rs_b * pc ) + ( cs_b * jc ) ),
|
||||
pack_b_buffer_f32f32f32of32 + ( n_sub_updated * pc ),
|
||||
nc0 , kc0,
|
||||
rs_b, cs_b, ( NR * ps_b_use ), NR,
|
||||
cntx
|
||||
);
|
||||
( ( lpgemm_pack_f32 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
pack_b_buffer_f32f32f32of32 + ( n_sub_updated * pc ),
|
||||
b + ( rs_b * pc ) + ( cs_b * jc ),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use
|
||||
);
|
||||
}
|
||||
b_use = pack_b_buffer_f32f32f32of32;
|
||||
}
|
||||
@@ -339,10 +319,10 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
} // jc loop
|
||||
|
||||
// Release pack buffers.
|
||||
if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_b );
|
||||
}
|
||||
if ( ( mtag_b == PACK ) && ( bli_mem_is_alloc( &mem_b ) ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_b );
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -352,7 +332,9 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
// Handle using LPGEMV when m or/and n equal to 1
|
||||
// The avx512 check will be removed when avx2 kernels added in future
|
||||
if ( ( ( m == 1 ) || ( n == 1 ) ) && (bli_cpuid_is_avx512_supported() == TRUE) )
|
||||
if ( ( ( m == 1 ) || ( n == 1 ) ) &&
|
||||
( bli_cpuid_is_avx512_supported() == TRUE ) &&
|
||||
( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
|
||||
{
|
||||
lpgemv_rowvar_f32f32f32of32(m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
@@ -371,14 +353,12 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
// Query the global cntx.
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
num_t dt = BLIS_FLOAT;
|
||||
|
||||
// Query the context for various blocksizes.
|
||||
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||
const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx );
|
||||
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx );
|
||||
const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx );
|
||||
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx );
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
const dim_t MR = lcntx->blksz.MR;
|
||||
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
const float* a_use = NULL;
|
||||
@@ -535,13 +515,12 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
if ( ( jc_packb_end > jc_packb_start ) &&
|
||||
( jc_packb_start < ( jc + nc0 ) ) )
|
||||
{
|
||||
lpgemm_pack_b_f32f32f32of32
|
||||
( ( lpgemm_pack_f32 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
( b + ( rs_b * pc ) + ( cs_b * jc ) + ( cs_b * jc_packb_start ) ),
|
||||
pack_b_buffer_f32f32f32of32 + ( jc_packb_start * kc0 ),
|
||||
( jc_packb_end - jc_packb_start ), kc0,
|
||||
rs_b, cs_b, ( NR * ps_b_use ), NR,
|
||||
cntx
|
||||
b + ( rs_b * pc ) + ( cs_b * jc ) +
|
||||
( cs_b * jc_packb_start ),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use
|
||||
);
|
||||
}
|
||||
|
||||
@@ -740,58 +719,3 @@ void lpgemm_pack_a_f32f32f32of32
|
||||
p_temp += ps_p;
|
||||
}
|
||||
}
|
||||
|
||||
void lpgemm_pack_b_f32f32f32of32
|
||||
(
|
||||
const float* input_buf_addr_b,
|
||||
float* reorder_buf_addr_b,
|
||||
const dim_t n,
|
||||
const dim_t k,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t ps_p,
|
||||
const dim_t NR,
|
||||
cntx_t* cntx
|
||||
)
|
||||
{
|
||||
float one_local = *PASTEMAC(s,1);
|
||||
float* restrict kappa_cast = &one_local;
|
||||
|
||||
// Set the schema to "row stored column panels" to indicate packing to
|
||||
// conventional row-stored column panels.
|
||||
pack_t schema = BLIS_PACKED_COL_PANELS;
|
||||
trans_t transc = BLIS_NO_TRANSPOSE;
|
||||
conj_t conjc = bli_extract_conj( transc );
|
||||
// Compute the total number of iterations we'll need.
|
||||
dim_t n_iter = ( n + NR - 1 ) / NR;
|
||||
|
||||
inc_t rs_p = NR;
|
||||
|
||||
float* p_temp = reorder_buf_addr_b;
|
||||
|
||||
dim_t jr, it;
|
||||
// Iterate over every logical micropanel in the source matrix.
|
||||
for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 )
|
||||
{
|
||||
dim_t panel_dim_i = bli_min( NR, n - jr );
|
||||
|
||||
const float* b_use = input_buf_addr_b + ( jr * cs_b );
|
||||
float* p_use = p_temp;
|
||||
|
||||
PASTEMAC(s,packm_cxk)
|
||||
(
|
||||
conjc,
|
||||
schema,
|
||||
panel_dim_i,
|
||||
NR,
|
||||
k,
|
||||
k,
|
||||
kappa_cast,
|
||||
( float* )b_use, cs_b, rs_b,
|
||||
p_use, rs_p,
|
||||
cntx
|
||||
);
|
||||
|
||||
p_temp += ps_p;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ typedef enum
|
||||
BF16BF16F32OF32 = 3, // bf16 - A, bf16 - B, float - C
|
||||
S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C
|
||||
S8S8S16OS16 = 5, // int8_t - A, int8_t - B, int16_t - C
|
||||
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
|
||||
U8S4S32OS32 = 6, // Only used for reordering int4_t B matrix.
|
||||
BF16S4F32OF32 = 7 // Only used for reordering int4_t B matrix.
|
||||
} AOCL_OPERATION_TYPE;
|
||||
#define AOCL_OPERATION_TYPE_LEN 8
|
||||
@@ -148,6 +148,13 @@ typedef struct
|
||||
dim_t packb_cs;
|
||||
} lpgemm_pack_strides_t;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
dim_t MT;
|
||||
dim_t NT;
|
||||
dim_t KT;
|
||||
} lpgemm_sup_thres_t;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
lpgemm_block_size_t blksz;
|
||||
@@ -157,6 +164,7 @@ typedef struct
|
||||
void_fp unpackb_fun_ptr;
|
||||
void_fp packsclb_fun_ptr;
|
||||
lpgemm_pack_strides_t pack_s;
|
||||
lpgemm_sup_thres_t sup_thres;
|
||||
} lpgemm_cntx_t;
|
||||
|
||||
typedef struct
|
||||
|
||||
@@ -503,7 +503,6 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
}
|
||||
else if ( ( *n_threads ) > 1 )
|
||||
{
|
||||
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 );
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 );
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
@@ -558,22 +557,17 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
// Query the global cntx.
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
num_t dt = BLIS_FLOAT;
|
||||
|
||||
// Query the context for SUP limits.
|
||||
const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx );
|
||||
const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx );
|
||||
const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx );
|
||||
const dim_t MT = lpgemm_get_sup_thres_MT_global_cntx( F32F32F32OF32 );
|
||||
const dim_t NT = lpgemm_get_sup_thres_NT_global_cntx( F32F32F32OF32 );
|
||||
const dim_t KT = lpgemm_get_sup_thres_KT_global_cntx( F32F32F32OF32 );
|
||||
|
||||
// Query the context for various blocksizes.
|
||||
const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||
const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx );
|
||||
const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx );
|
||||
const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx );
|
||||
const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx );
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( F32F32F32OF32 );
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( F32F32F32OF32 );
|
||||
dim_t MC = lpgemm_get_block_size_MC_global_cntx( F32F32F32OF32 );
|
||||
dim_t NC = lpgemm_get_block_size_NC_global_cntx( F32F32F32OF32 );
|
||||
dim_t KC = lpgemm_get_block_size_KC_global_cntx( F32F32F32OF32 );
|
||||
|
||||
const dim_t MT_2 = MT / 2;
|
||||
|
||||
@@ -640,7 +634,7 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
|
||||
if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) )
|
||||
{
|
||||
if (((k <= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
|
||||
if (((k >= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
|
||||
((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2)))
|
||||
{
|
||||
bli_rntm_set_pack_b( 1, rntm_g );
|
||||
|
||||
@@ -31,8 +31,8 @@
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
#ifndef BLIS_GEMM_F32_PACKA
|
||||
#define BLIS_GEMM_F32_PACKA
|
||||
#ifndef BLIS_GEMM_F32_PACKAB
|
||||
#define BLIS_GEMM_F32_PACKAB
|
||||
|
||||
void packa_mr16_f32f32f32of32_col_major
|
||||
(
|
||||
@@ -45,6 +45,43 @@ void packa_mr16_f32f32f32of32_col_major
|
||||
dim_t* rs_p,
|
||||
dim_t* cs_p
|
||||
);
|
||||
#endif
|
||||
|
||||
typedef void (*lpgemm_pack_f32)
|
||||
(
|
||||
float*,
|
||||
const float*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
dim_t*,
|
||||
dim_t*
|
||||
);
|
||||
|
||||
void packb_nr64_f32f32f32of32
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_p,
|
||||
dim_t* cs_p
|
||||
);
|
||||
|
||||
void packb_nr16_f32f32f32of32
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_p,
|
||||
dim_t* cs_p
|
||||
);
|
||||
|
||||
#endif //BLIS_GEMM_F32_PACKAB
|
||||
|
||||
|
||||
|
||||
@@ -1313,7 +1313,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( global_dscale_out == 'y' || global_can_dscale == 'y') \
|
||||
if ( ( global_dscale_out == 'y' ) || ( global_can_dscale == 'y' ) ) \
|
||||
{ \
|
||||
post_ops->seq_vector[cur_op_index] = SCALE; \
|
||||
cur_op_index++; \
|
||||
|
||||
@@ -530,16 +530,16 @@ POST_OPS_DOWNSCALE_5x16F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -1096,14 +1096,14 @@ POST_OPS_DOWNSCALE_4x16F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -1574,12 +1574,12 @@ POST_OPS_DOWNSCALE_3x16F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -1959,10 +1959,10 @@ POST_OPS_DOWNSCALE_2x16F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -2262,8 +2262,8 @@ POST_OPS_DOWNSCALE_1x16F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
|
||||
@@ -2666,16 +2666,16 @@ POST_OPS_DOWNSCALE_5x8F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -3083,14 +3083,14 @@ POST_OPS_DOWNSCALE_4x8F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -3444,12 +3444,12 @@ POST_OPS_DOWNSCALE_3x8F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -3750,10 +3750,10 @@ POST_OPS_DOWNSCALE_2x8F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -3996,8 +3996,8 @@ POST_OPS_DOWNSCALE_1x8F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -6017,13 +6017,13 @@ POST_OPS_DOWNSCALE_5x2F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
|
||||
@@ -6429,13 +6429,13 @@ POST_OPS_DOWNSCALE_4x2F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
|
||||
@@ -6785,13 +6785,13 @@ POST_OPS_DOWNSCALE_3x2F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
|
||||
@@ -7090,13 +7090,13 @@ POST_OPS_DOWNSCALE_2x2F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4) );
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
|
||||
@@ -7335,13 +7335,13 @@ POST_OPS_DOWNSCALE_1x2F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
selector0 = ( __m128 )_mm_load_sd( (const double*)( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)(float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) );
|
||||
zero_point0 = ( __m128 )_mm_load_sd( (const double*)( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 8 ) ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
F32_SCL_MULRND_SSE(xmm4, selector0, zero_point0);
|
||||
@@ -7698,12 +7698,12 @@ POST_OPS_DOWNSCALE_5x1F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
|
||||
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
@@ -8109,12 +8109,12 @@ POST_OPS_DOWNSCALE_4x1F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
|
||||
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
@@ -8464,12 +8464,12 @@ POST_OPS_DOWNSCALE_3x1F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
|
||||
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
@@ -8769,12 +8769,12 @@ POST_OPS_DOWNSCALE_2x1F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
|
||||
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
@@ -9013,12 +9013,12 @@ POST_OPS_DOWNSCALE_1x1F:
|
||||
{
|
||||
if( post_ops_list_temp->scale_factor_len > 1 )
|
||||
{
|
||||
selector0 = _mm_loadu_ps( ( float* )post_ops_list_temp->scale_factor +
|
||||
selector0 = ( __m128 )_mm_load_ss( ( float* )post_ops_list_temp->scale_factor +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4) );
|
||||
}
|
||||
if( *( (dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm_loadu_ps( (float* )post_ops_list_temp->op_args1 +
|
||||
zero_point0 = ( __m128 )_mm_load_ss( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 4 ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
|
||||
@@ -689,18 +689,18 @@ POST_OPS_DOWNSCALE_6x16F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 5 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
@@ -1299,18 +1299,18 @@ POST_OPS_DOWNSCALE_6x8F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 5 );
|
||||
zero_point0 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point5 = _mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
}
|
||||
//c[0, 0-7]
|
||||
F32_SCL_MULRND_AVX2(ymm4, selector1, zero_point0);
|
||||
|
||||
367
kernels/zen/lpgemm/f32f32f32/lpgemm_pack_b_f32_avx2.c
Normal file
367
kernels/zen/lpgemm/f32f32f32/lpgemm_pack_b_f32_avx2.c
Normal file
@@ -0,0 +1,367 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binarsy form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
void packb_nr16_f32f32f32of32_row_major
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t ldb,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_b,
|
||||
dim_t* cs_b
|
||||
);
|
||||
|
||||
void packb_nr16_f32f32f32of32_col_major
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t ldb,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_b,
|
||||
dim_t* cs_b
|
||||
);
|
||||
|
||||
void packb_nr16_f32f32f32of32
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_p,
|
||||
dim_t* cs_p
|
||||
)
|
||||
{
|
||||
if( cs_b == 1 )
|
||||
{
|
||||
packb_nr16_f32f32f32of32_row_major( pack_b_buffer, b,
|
||||
rs_b, NC, KC, rs_p, cs_p );
|
||||
}
|
||||
else
|
||||
{
|
||||
packb_nr16_f32f32f32of32_col_major( pack_b_buffer, b,
|
||||
cs_b, NC, KC, rs_p, cs_p );
|
||||
}
|
||||
}
|
||||
|
||||
void packb_nr16_f32f32f32of32_row_major
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t ldb,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_b,
|
||||
dim_t* cs_b
|
||||
)
|
||||
{
|
||||
dim_t NR = 16;
|
||||
|
||||
__m256 a0;
|
||||
__m256 b0;
|
||||
|
||||
dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
|
||||
dim_t n_partial_pieces = NC % NR;
|
||||
|
||||
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
|
||||
{
|
||||
for ( dim_t kr = 0; kr < KC; kr += 1 )
|
||||
{
|
||||
a0 = _mm256_loadu_ps( b + ( jc + 0 ) + ( ldb * kr ) );
|
||||
b0 = _mm256_loadu_ps( b + ( jc + 8 ) + ( ldb * kr ) );
|
||||
|
||||
//store to pack_b buffer
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) +
|
||||
( ( kr * NR ) + 0 ), a0 );
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) +
|
||||
( ( kr * NR ) + 8 ), b0 );
|
||||
}
|
||||
}
|
||||
if( n_partial_pieces > 0 )
|
||||
{
|
||||
for ( dim_t kr = 0; kr < KC; kr += 1 )
|
||||
{
|
||||
// No point in vectorizing fringe case since n_fringe is expected
|
||||
// to be laid out contiguously in pack buffer as nr_fringe*KC
|
||||
// instead of 8*KC + 4*KC + .., etc.
|
||||
memcpy( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
|
||||
( ( kr * NR ) + 0 ),
|
||||
b + ( n_full_pieces_loop_limit + 0 ) + ( ldb * kr ),
|
||||
n_partial_pieces * ( sizeof( float ) ) );
|
||||
|
||||
// Zero out padding data.
|
||||
memset( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
|
||||
( ( kr * NR ) + n_partial_pieces ),
|
||||
0, ( NR - n_partial_pieces ) * sizeof( float ) );
|
||||
}
|
||||
}
|
||||
|
||||
*rs_b = NR;
|
||||
*cs_b = 1;
|
||||
}
|
||||
|
||||
#define LOAD_PS_8x8() \
|
||||
a_reg[0] = _mm256_loadu_ps( b + ( ldb * ( jr + 0 ) ) + ( kr ) ); \
|
||||
a_reg[1] = _mm256_loadu_ps( b + ( ldb * ( jr + 1 ) ) + ( kr ) ); \
|
||||
a_reg[2] = _mm256_loadu_ps( b + ( ldb * ( jr + 2 ) ) + ( kr ) ); \
|
||||
a_reg[3] = _mm256_loadu_ps( b + ( ldb * ( jr + 3 ) ) + ( kr ) ); \
|
||||
a_reg[4] = _mm256_loadu_ps( b + ( ldb * ( jr + 4 ) ) + ( kr ) ); \
|
||||
a_reg[5] = _mm256_loadu_ps( b + ( ldb * ( jr + 5 ) ) + ( kr ) ); \
|
||||
a_reg[6] = _mm256_loadu_ps( b + ( ldb * ( jr + 6 ) ) + ( kr ) ); \
|
||||
a_reg[7] = _mm256_loadu_ps( b + ( ldb * ( jr + 7 ) ) + ( kr ) ); \
|
||||
|
||||
#define K_FRINGE_MEMCPY_LOAD_PS_8x8() \
|
||||
memcpy( buf0, b + ( ldb * ( jr + 0 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[0] = _mm256_loadu_ps( buf0 ); \
|
||||
memcpy( buf1, b + ( ldb * ( jr + 1 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[1] = _mm256_loadu_ps( buf1 ); \
|
||||
memcpy( buf2, b + ( ldb * ( jr + 2 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[2] = _mm256_loadu_ps( buf2 ); \
|
||||
memcpy( buf3, b + ( ldb * ( jr + 3 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[3] = _mm256_loadu_ps( buf3 ); \
|
||||
memcpy( buf4, b + ( ldb * ( jr + 4 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[4] = _mm256_loadu_ps( buf4 ); \
|
||||
memcpy( buf5, b + ( ldb * ( jr + 5 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[5] = _mm256_loadu_ps( buf5 ); \
|
||||
memcpy( buf6, b + ( ldb * ( jr + 6 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[6] = _mm256_loadu_ps( buf6 ); \
|
||||
memcpy( buf7, b + ( ldb * ( jr + 7 ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[7] = _mm256_loadu_ps( buf7 ); \
|
||||
|
||||
#define N_FRINGE_LOAD_PS_8x8() \
|
||||
for ( int i = 0; i < jr_elems; ++i ) \
|
||||
{ \
|
||||
a_reg[i] = _mm256_loadu_ps( b + ( ldb * ( jr + i ) ) + ( kr ) ); \
|
||||
} \
|
||||
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i ) \
|
||||
{ \
|
||||
a_reg[i] = _mm256_setzero_ps(); \
|
||||
} \
|
||||
|
||||
#define KN_FRINGE_MEMCPY_LOAD_PS_8x8() \
|
||||
for ( int i = 0; i < jr_elems; ++i ) \
|
||||
{ \
|
||||
memcpy( buf0, b + ( ldb * ( jr + i ) ) + ( kr ), \
|
||||
k_partial_pieces * sizeof( float ) ); \
|
||||
a_reg[i] = _mm256_loadu_ps( buf0 ); \
|
||||
} \
|
||||
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i ) \
|
||||
{ \
|
||||
a_reg[i] = _mm256_setzero_ps(); \
|
||||
} \
|
||||
|
||||
#define UNPACK_PS_8x8() \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
b_reg[0] = _mm256_unpacklo_ps( a_reg[0], a_reg[1] ); \
|
||||
b_reg[1] = _mm256_unpackhi_ps( a_reg[0], a_reg[1] ); \
|
||||
b_reg[2] = _mm256_unpacklo_ps( a_reg[2], a_reg[3] ); \
|
||||
b_reg[3] = _mm256_unpackhi_ps( a_reg[2], a_reg[3] ); \
|
||||
b_reg[4] = _mm256_unpacklo_ps( a_reg[4], a_reg[5] ); \
|
||||
b_reg[5] = _mm256_unpackhi_ps( a_reg[4], a_reg[5] ); \
|
||||
b_reg[6] = _mm256_unpacklo_ps( a_reg[6], a_reg[7] ); \
|
||||
b_reg[7] = _mm256_unpackhi_ps( a_reg[6], a_reg[7] ); \
|
||||
|
||||
#define UNPACK_PD_8x8() \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
a_reg[0] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[0], \
|
||||
( __m256d )b_reg[2] ); \
|
||||
a_reg[1] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[0], \
|
||||
( __m256d )b_reg[2] ); \
|
||||
a_reg[2] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[4], \
|
||||
( __m256d )b_reg[6] ); \
|
||||
a_reg[3] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[4], \
|
||||
( __m256d )b_reg[6] ); \
|
||||
a_reg[4] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[1], \
|
||||
( __m256d )b_reg[3] ); \
|
||||
a_reg[5] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[1], \
|
||||
( __m256d )b_reg[3] ); \
|
||||
a_reg[6] = ( __m256 )_mm256_unpacklo_pd( ( __m256d )b_reg[5], \
|
||||
( __m256d )b_reg[7] ); \
|
||||
a_reg[7] = ( __m256 )_mm256_unpackhi_pd( ( __m256d )b_reg[5], \
|
||||
( __m256d )b_reg[7] ); \
|
||||
|
||||
#define PERMUTE_R1_8x8() \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
b_reg[0] = _mm256_permute2f128_ps( a_reg[0], a_reg[2], 0x20 ); /* Row 0 */ \
|
||||
b_reg[1] = _mm256_permute2f128_ps( a_reg[0], a_reg[2], 0x31 ); /* Row 4 */ \
|
||||
b_reg[2] = _mm256_permute2f128_ps( a_reg[4], a_reg[6], 0x20 ); /* Row 2 */ \
|
||||
b_reg[3] = _mm256_permute2f128_ps( a_reg[4], a_reg[6], 0x31 ); /* Row 6 */ \
|
||||
b_reg[4] = _mm256_permute2f128_ps( a_reg[1], a_reg[3], 0x20 ); /* Row 1 */ \
|
||||
b_reg[5] = _mm256_permute2f128_ps( a_reg[1], a_reg[3], 0x31 ); /* Row 5 */ \
|
||||
b_reg[6] = _mm256_permute2f128_ps( a_reg[5], a_reg[7], 0x20 ); /* Row 3 */ \
|
||||
b_reg[7] = _mm256_permute2f128_ps( a_reg[5], a_reg[7], 0x31 ); /* Row 7 */ \
|
||||
|
||||
#define STORE_PS_8x8() \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 0 ) * NR ) + jr_offset, \
|
||||
b_reg[0] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 1 ) * NR ) + jr_offset, \
|
||||
b_reg[4] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 2 ) * NR ) + jr_offset, \
|
||||
b_reg[2] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 3 ) * NR ) + jr_offset, \
|
||||
b_reg[6] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 4 ) * NR ) + jr_offset, \
|
||||
b_reg[1] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 5 ) * NR ) + jr_offset, \
|
||||
b_reg[5] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 6 ) * NR ) + jr_offset, \
|
||||
b_reg[3] ); \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 7 ) * NR ) + jr_offset, \
|
||||
b_reg[7] ); \
|
||||
|
||||
#define K_FRINGE_SAFE_STORE_PS_8x8() \
|
||||
a_reg[0] = b_reg[0]; \
|
||||
a_reg[1] = b_reg[4]; \
|
||||
a_reg[2] = b_reg[2]; \
|
||||
a_reg[3] = b_reg[6]; \
|
||||
a_reg[4] = b_reg[1]; \
|
||||
a_reg[5] = b_reg[5]; \
|
||||
a_reg[6] = b_reg[3]; \
|
||||
a_reg[7] = b_reg[7]; \
|
||||
for (int i = 0; i < k_partial_pieces; ++i ) \
|
||||
{ \
|
||||
_mm256_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + i ) * NR ) + jr_offset, \
|
||||
a_reg[i] ); \
|
||||
} \
|
||||
|
||||
// Packs a column-major f32 B panel into the LPGEMM packed layout with
// NR = 16: element (k, j) of the packed panel starting at column jc lands
// at pack_b_buffer[ jc*KC + k*NR + (j % NR) ], i.e. row-major NR-wide
// panels. Since the source is column-major, this is a transpose; it is
// performed 8x8 at a time with AVX2 unpack/unpackpd/permute shuffles
// (LOAD_PS_8x8 / UNPACK_PS_8x8 / UNPACK_PD_8x8 / PERMUTE_R1_8x8 macros,
// defined earlier in this file).
//
// pack_b_buffer : output packed buffer.
// b             : source B matrix, column major; ldb is the column stride.
// NC, KC        : width and depth of the panel to pack.
// rs_b, cs_b    : outputs; receive the packed panel strides (NR and 1).
void packb_nr16_f32f32f32of32_col_major
     (
       float*       pack_b_buffer,
       const float* b,
       const dim_t  ldb,
       const dim_t  NC,
       const dim_t  KC,
       dim_t*       rs_b,
       dim_t*       cs_b
     )
{
	const dim_t NR = 16;
	// Each AVX2 transpose step handles an 8-column sub-block of the
	// NR=16 panel, 8 k-rows at a time.
	const dim_t n_sub_blk_wdth = 8;
	const dim_t k_reg_size = 8;

	// Zero-initialized scratch rows, presumably consumed by the
	// K_FRINGE/KN_FRINGE memcpy-based load macros so partial k (or n)
	// loads read zero padding — macros not visible here, TODO confirm.
	float buf0[8] = { 0 };
	float buf1[8] = { 0 };
	float buf2[8] = { 0 };
	float buf3[8] = { 0 };
	float buf4[8] = { 0 };
	float buf5[8] = { 0 };
	float buf6[8] = { 0 };
	float buf7[8] = { 0 };

	__m256 a_reg[8];
	__m256 b_reg[8];

	dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
	dim_t n_partial_pieces = NC % NR;

	dim_t k_full_pieces_loop_limit = ( KC / k_reg_size ) * k_reg_size;
	dim_t k_partial_pieces = KC % k_reg_size;

	// Full NR-wide column panels.
	for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
	{
		// Two 8-wide sub-blocks per NR=16 panel.
		for ( dim_t jr = jc; jr < jc + NR; jr += n_sub_blk_wdth )
		{
			dim_t jr_offset = jr % NR;
			// Full 8-deep k tiles: vector load + 8x8 transpose + store.
			for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
			{
				LOAD_PS_8x8();
				UNPACK_PS_8x8();
				UNPACK_PD_8x8();
				PERMUTE_R1_8x8();
				STORE_PS_8x8();
			}
			// k fringe (< 8 remaining rows): scalar-safe load and store.
			if ( k_partial_pieces > 0 )
			{
				dim_t kr = k_full_pieces_loop_limit;

				K_FRINGE_MEMCPY_LOAD_PS_8x8();
				UNPACK_PS_8x8();
				UNPACK_PD_8x8();
				PERMUTE_R1_8x8();
				K_FRINGE_SAFE_STORE_PS_8x8();
			}
		}
	}

	// n fringe (< NR remaining columns), processed in 8-wide sub-blocks
	// with masked/partial loads.
	if( n_partial_pieces > 0 )
	{
		dim_t jc = n_full_pieces_loop_limit;
		for ( dim_t jr = n_full_pieces_loop_limit; jr < NC; jr += n_sub_blk_wdth )
		{
			dim_t jr_offset = jr % NR;
			// Number of valid columns in this sub-block (<= 8);
			// consumed by the fringe load macros below.
			dim_t jr_elems = ( ( NC - jr ) >= n_sub_blk_wdth ) ? n_sub_blk_wdth : ( NC - jr );

			for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
			{
				N_FRINGE_LOAD_PS_8x8();
				UNPACK_PS_8x8();
				UNPACK_PD_8x8();
				PERMUTE_R1_8x8();
				STORE_PS_8x8();
			}
			// Combined k and n fringe.
			if ( k_partial_pieces > 0 )
			{
				dim_t kr = k_full_pieces_loop_limit;

				KN_FRINGE_MEMCPY_LOAD_PS_8x8();
				UNPACK_PS_8x8();
				UNPACK_PD_8x8();
				PERMUTE_R1_8x8();
				K_FRINGE_SAFE_STORE_PS_8x8();
			}
		}
	}

	// Report the packed layout strides to the caller.
	*rs_b = NR;
	*cs_b = 1;
}
|
||||
|
||||
#endif
|
||||
@@ -322,7 +322,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64)
|
||||
post_ops_list, post_ops_attr
|
||||
);
|
||||
|
||||
// No leftover fringe after this podint.
|
||||
// No leftover fringe after this point.
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -125,13 +125,6 @@ void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = KC % 2;
|
||||
|
||||
// KC when not multiple of 2 will have padding to make it multiple of 2 in packed buffer.
|
||||
dim_t KC_updated = KC;
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
KC_updated += ( 2 - k_partial_pieces );
|
||||
}
|
||||
|
||||
__m512i a0, c0;
|
||||
__m512i a01;
|
||||
|
||||
|
||||
@@ -839,14 +839,14 @@ POST_OPS_DOWNSCALE_5x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -903,8 +903,8 @@ POST_OPS_DOWNSCALE_5x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
//c[4, 0-15]
|
||||
F32_SCL_MULRND(zmm24, selector1, zero_point0);
|
||||
@@ -1745,14 +1745,14 @@ POST_OPS_DOWNSCALE_4x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -2485,12 +2485,12 @@ POST_OPS_DOWNSCALE_3x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -3066,10 +3066,10 @@ POST_OPS_DOWNSCALE_2x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -3484,8 +3484,8 @@ POST_OPS_DOWNSCALE_1x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -4205,14 +4205,14 @@ POST_OPS_DOWNSCALE_5x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -4257,8 +4257,8 @@ POST_OPS_DOWNSCALE_5x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
//c[4, 0-15]
|
||||
F32_SCL_MULRND(zmm24, selector1, zero_point0);
|
||||
@@ -4955,14 +4955,14 @@ POST_OPS_DOWNSCALE_4x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -5569,12 +5569,12 @@ POST_OPS_DOWNSCALE_3x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -6058,10 +6058,10 @@ POST_OPS_DOWNSCALE_2x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -6422,8 +6422,8 @@ POST_OPS_DOWNSCALE_1x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -6993,16 +6993,16 @@ POST_OPS_DOWNSCALE_5x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -7576,14 +7576,14 @@ POST_OPS_DOWNSCALE_4x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -8068,12 +8068,12 @@ POST_OPS_DOWNSCALE_3x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -8464,10 +8464,10 @@ POST_OPS_DOWNSCALE_2x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -8771,8 +8771,8 @@ POST_OPS_DOWNSCALE_1x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
|
||||
@@ -986,6 +986,7 @@ POST_OPS_DOWNSCALE_6x64F:
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) );
|
||||
}
|
||||
|
||||
if( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
|
||||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
|
||||
{
|
||||
@@ -1108,14 +1109,14 @@ POST_OPS_DOWNSCALE_6x64F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1 );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1 ) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
|
||||
//c[0, 0-15]
|
||||
@@ -2174,14 +2175,14 @@ POST_OPS_DOWNSCALE_6x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
@@ -2228,10 +2229,10 @@ POST_OPS_DOWNSCALE_6x48F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 5);
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5) );
|
||||
}
|
||||
//c[4, 0-15]
|
||||
F32_SCL_MULRND(zmm24, selector1, zero_point0);
|
||||
@@ -3013,18 +3014,18 @@ POST_OPS_DOWNSCALE_6x32F:
|
||||
}
|
||||
if( *( ( dim_t* )post_ops_list_temp->op_args3 ) > 1 )
|
||||
{
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 0 );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 1);
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 2 );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 3 );
|
||||
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 4 );
|
||||
zero_point5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 ) +
|
||||
post_ops_attr.post_op_c_i + 5 );
|
||||
zero_point0 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 0 ) );
|
||||
zero_point1 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 1) );
|
||||
zero_point2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
zero_point5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 5 ) );
|
||||
}
|
||||
//c[0, 0-15]
|
||||
F32_SCL_MULRND(zmm8, selector1, zero_point0);
|
||||
|
||||
506
kernels/zen4/lpgemm/f32f32f32/lpgemm_pack_b_f32_amd512vnni.c
Normal file
506
kernels/zen4/lpgemm/f32f32f32/lpgemm_pack_b_f32_amd512vnni.c
Normal file
@@ -0,0 +1,506 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
// Forward declarations for the two NR=64 f32 B-packing paths; the generic
// entry point below dispatches between them based on the storage order of B.
void packb_nr64_f32f32f32of32_row_major
     (
       float*       pack_b_buffer,
       const float* b,
       const dim_t  ldb,
       const dim_t  NC,
       const dim_t  KC,
       dim_t*       rs_b,
       dim_t*       cs_b
     );

void packb_nr64_f32f32f32of32_col_major
     (
       float*       pack_b_buffer,
       const float* b,
       const dim_t  ldb,
       const dim_t  NC,
       const dim_t  KC,
       dim_t*       rs_b,
       dim_t*       cs_b
     );
|
||||
|
||||
void packb_nr64_f32f32f32of32
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t rs_b,
|
||||
const dim_t cs_b,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_p,
|
||||
dim_t* cs_p
|
||||
)
|
||||
{
|
||||
if( cs_b == 1 )
|
||||
{
|
||||
packb_nr64_f32f32f32of32_row_major( pack_b_buffer, b,
|
||||
rs_b, NC, KC, rs_p, cs_p );
|
||||
}
|
||||
else
|
||||
{
|
||||
packb_nr64_f32f32f32of32_col_major( pack_b_buffer, b,
|
||||
cs_b, NC, KC, rs_p, cs_p );
|
||||
}
|
||||
}
|
||||
|
||||
// Packs a row-major f32 B panel into the LPGEMM packed layout with
// NR = 64: element (k, j) of the packed panel starting at column jc lands
// at pack_b_buffer[ jc*KC + k*NR + (j % NR) ]. Since the source is already
// row major no transpose is needed — each k row of a panel is copied with
// four 16-float AVX-512 loads/stores. The n fringe (< 64 remaining
// columns) uses masked loads so no reads run past the end of a source row;
// the masked-out lanes are stored as zeros (the packed buffer is assumed
// to be sized for a full NR-wide last panel — TODO confirm at call sites).
//
// pack_b_buffer : output packed buffer.
// b             : source B matrix, row major; ldb is the row stride.
// NC, KC        : width and depth of the panel to pack.
// rs_b, cs_b    : outputs; receive the packed panel strides (NR and 1).
void packb_nr64_f32f32f32of32_row_major
     (
       float*       pack_b_buffer,
       const float* b,
       const dim_t  ldb,
       const dim_t  NC,
       const dim_t  KC,
       dim_t*       rs_b,
       dim_t*       cs_b
     )
{
	dim_t NR = 64;

	__m512 a0;
	__m512 b0;
	__m512 c0;
	__m512 d0;

	dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
	dim_t n_partial_pieces = NC % NR;

	// Full NR=64 wide panels: straight 4x16-float copy per k row.
	for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
	{
		for ( dim_t kr = 0; kr < KC; kr += 1 )
		{
			a0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 0 ) );
			b0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 16 ) );
			c0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 32 ) );
			d0 = _mm512_loadu_ps( b + ( ldb * kr ) + ( jc + 48 ) );

			//store to pack_b buffer
			_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
					( ( kr * NR ) + 0 ), a0 );
			_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
					( ( kr * NR ) + 16 ), b0 );
			_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
					( ( kr * NR ) + 32 ), c0 );
			_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) +
					( ( kr * NR ) + 48 ), d0 );
		}
	}
	if( n_partial_pieces > 0 )
	{
		// Columns remaining beyond the last full 16-lane group (0..15).
		dim_t n0_partial_rem = n_partial_pieces % 16;

		// How many full 16-lane groups remain: n0_48/n0_32/n0_16 are
		// nonzero when at least 48/32/16 columns remain, respectively
		// (n_partial_pieces < 64, so each is 0 or 1).
		dim_t n0_48 = n_partial_pieces / 48;
		dim_t n0_32 = n_partial_pieces / 32;
		dim_t n0_16 = n_partial_pieces / 16;

		// Per-16-lane-group load masks; default all-zero (lane unused).
		__mmask16 lmask_0 = _cvtu32_mask16( 0x0 );
		__mmask16 lmask_1 = _cvtu32_mask16( 0x0 );
		__mmask16 lmask_2 = _cvtu32_mask16( 0x0 );
		__mmask16 lmask_3 = _cvtu32_mask16( 0x0 );

		// Full mask for each complete 16-lane group, then a partial
		// mask covering the final n0_partial_rem lanes.
		if ( n0_48 > 0 )
		{
			lmask_0 = _cvtu32_mask16( 0xFFFF );
			lmask_1 = _cvtu32_mask16( 0xFFFF );
			lmask_2 = _cvtu32_mask16( 0xFFFF );
			lmask_3 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
		}
		else if ( n0_32 > 0 )
		{
			lmask_0 = _cvtu32_mask16( 0xFFFF );
			lmask_1 = _cvtu32_mask16( 0xFFFF );
			lmask_2 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
		}
		else if ( n0_16 > 0 )
		{
			lmask_0 = _cvtu32_mask16( 0xFFFF );
			lmask_1 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
		}
		else
		{
			lmask_0 = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
		}

		// Masked copy of the fringe panel; masked-out lanes load as
		// zero and are stored into the NR-padded packed panel.
		for ( dim_t kr = 0; kr < KC; kr += 1 )
		{
			a0 = _mm512_maskz_loadu_ps( lmask_0, b + ( ldb * kr ) +
					( n_full_pieces_loop_limit + 0 ) );
			b0 = _mm512_maskz_loadu_ps( lmask_1, b + ( ldb * kr ) +
					( n_full_pieces_loop_limit + 16 ) );
			c0 = _mm512_maskz_loadu_ps( lmask_2, b + ( ldb * kr ) +
					( n_full_pieces_loop_limit + 32 ) );
			d0 = _mm512_maskz_loadu_ps( lmask_3, b + ( ldb * kr ) +
					( n_full_pieces_loop_limit + 48 ) );

			//store to pack_b buffer
			_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
					( ( kr * NR ) + 0 ), a0 );
			_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
					( ( kr * NR ) + 16 ), b0 );
			_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
					( ( kr * NR ) + 32 ), c0 );
			_mm512_storeu_ps( pack_b_buffer + ( n_full_pieces_loop_limit * KC ) +
					( ( kr * NR ) + 48 ), d0 );
		}
	}

	// Report the packed layout strides to the caller.
	*rs_b = NR;
	*cs_b = 1;
}
|
||||
|
||||
// Masked load of a 16x16 f32 tile from column-major B into a_reg: one
// masked 16-float load per source column (jr+0 .. jr+15), all columns
// sharing the same k mask `msk`. Relies on surrounding scope providing:
// a_reg[], b, ldb, jr, kr.
#define MASK_LOAD_PS_16x16(msk) \
	a_reg[0] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 0 ) ) + ( kr ) ); \
	a_reg[1] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 1 ) ) + ( kr ) ); \
	a_reg[2] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 2 ) ) + ( kr ) ); \
	a_reg[3] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 3 ) ) + ( kr ) ); \
	a_reg[4] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 4 ) ) + ( kr ) ); \
	a_reg[5] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 5 ) ) + ( kr ) ); \
	a_reg[6] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 6 ) ) + ( kr ) ); \
	a_reg[7] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 7 ) ) + ( kr ) ); \
	a_reg[8] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 8 ) ) + ( kr ) ); \
	a_reg[9] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 9 ) ) + ( kr ) ); \
	a_reg[10] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 10 ) ) + ( kr ) ); \
	a_reg[11] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 11 ) ) + ( kr ) ); \
	a_reg[12] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 12 ) ) + ( kr ) ); \
	a_reg[13] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 13 ) ) + ( kr ) ); \
	a_reg[14] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 14 ) ) + ( kr ) ); \
	a_reg[15] = _mm512_maskz_loadu_ps( ( msk ), b + ( ldb * ( jr + 15 ) ) + ( kr ) ); \
|
||||
|
||||
// Like MASK_LOAD_PS_16x16, but each of the 16 source columns uses its own
// mask from msk_n_arr[0..15] — used when the valid column count varies per
// lane (combined k/n fringe). Relies on surrounding scope providing:
// a_reg[], b, ldb, jr, kr.
#define MASK_ARR_LOAD_PS_16x16(msk_n_arr) \
	a_reg[0] = _mm512_maskz_loadu_ps( msk_n_arr[0], b + ( ldb * ( jr + 0 ) ) + ( kr ) ); \
	a_reg[1] = _mm512_maskz_loadu_ps( msk_n_arr[1], b + ( ldb * ( jr + 1 ) ) + ( kr ) ); \
	a_reg[2] = _mm512_maskz_loadu_ps( msk_n_arr[2], b + ( ldb * ( jr + 2 ) ) + ( kr ) ); \
	a_reg[3] = _mm512_maskz_loadu_ps( msk_n_arr[3], b + ( ldb * ( jr + 3 ) ) + ( kr ) ); \
	a_reg[4] = _mm512_maskz_loadu_ps( msk_n_arr[4], b + ( ldb * ( jr + 4 ) ) + ( kr ) ); \
	a_reg[5] = _mm512_maskz_loadu_ps( msk_n_arr[5], b + ( ldb * ( jr + 5 ) ) + ( kr ) ); \
	a_reg[6] = _mm512_maskz_loadu_ps( msk_n_arr[6], b + ( ldb * ( jr + 6 ) ) + ( kr ) ); \
	a_reg[7] = _mm512_maskz_loadu_ps( msk_n_arr[7], b + ( ldb * ( jr + 7 ) ) + ( kr ) ); \
	a_reg[8] = _mm512_maskz_loadu_ps( msk_n_arr[8], b + ( ldb * ( jr + 8 ) ) + ( kr ) ); \
	a_reg[9] = _mm512_maskz_loadu_ps( msk_n_arr[9], b + ( ldb * ( jr + 9 ) ) + ( kr ) ); \
	a_reg[10] = _mm512_maskz_loadu_ps( msk_n_arr[10], b + ( ldb * ( jr + 10 ) ) + ( kr ) ); \
	a_reg[11] = _mm512_maskz_loadu_ps( msk_n_arr[11], b + ( ldb * ( jr + 11 ) ) + ( kr ) ); \
	a_reg[12] = _mm512_maskz_loadu_ps( msk_n_arr[12], b + ( ldb * ( jr + 12 ) ) + ( kr ) ); \
	a_reg[13] = _mm512_maskz_loadu_ps( msk_n_arr[13], b + ( ldb * ( jr + 13 ) ) + ( kr ) ); \
	a_reg[14] = _mm512_maskz_loadu_ps( msk_n_arr[14], b + ( ldb * ( jr + 14 ) ) + ( kr ) ); \
	a_reg[15] = _mm512_maskz_loadu_ps( msk_n_arr[15], b + ( ldb * ( jr + 15 ) ) + ( kr ) ); \
|
||||
|
||||
#define UNPACK_PS_16x16() \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
b_reg[0] = _mm512_unpacklo_ps( a_reg[0], a_reg[1] ); \
|
||||
b_reg[1] = _mm512_unpackhi_ps( a_reg[0], a_reg[1] ); \
|
||||
b_reg[2] = _mm512_unpacklo_ps( a_reg[2], a_reg[3] ); \
|
||||
b_reg[3] = _mm512_unpackhi_ps( a_reg[2], a_reg[3] ); \
|
||||
b_reg[4] = _mm512_unpacklo_ps( a_reg[4], a_reg[5] ); \
|
||||
b_reg[5] = _mm512_unpackhi_ps( a_reg[4], a_reg[5] ); \
|
||||
b_reg[6] = _mm512_unpacklo_ps( a_reg[6], a_reg[7] ); \
|
||||
b_reg[7] = _mm512_unpackhi_ps( a_reg[6], a_reg[7] ); \
|
||||
b_reg[8] = _mm512_unpacklo_ps( a_reg[8], a_reg[9] ); \
|
||||
b_reg[9] = _mm512_unpackhi_ps( a_reg[8], a_reg[9] ); \
|
||||
b_reg[10] = _mm512_unpacklo_ps( a_reg[10], a_reg[11] ); \
|
||||
b_reg[11] = _mm512_unpackhi_ps( a_reg[10], a_reg[11] ); \
|
||||
b_reg[12] = _mm512_unpacklo_ps( a_reg[12], a_reg[13] ); \
|
||||
b_reg[13] = _mm512_unpackhi_ps( a_reg[12], a_reg[13] ); \
|
||||
b_reg[14] = _mm512_unpacklo_ps( a_reg[14], a_reg[15] ); \
|
||||
b_reg[15] = _mm512_unpackhi_ps( a_reg[14], a_reg[15] ); \
|
||||
|
||||
#define UNPACK_PD_16x16() \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
a_reg[0] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[0], ( __m512d )b_reg[2] ); \
|
||||
a_reg[1] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[0], ( __m512d )b_reg[2] ); \
|
||||
a_reg[2] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[4], ( __m512d )b_reg[6] ); \
|
||||
a_reg[3] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[4], ( __m512d )b_reg[6] ); \
|
||||
a_reg[4] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[8], ( __m512d )b_reg[10] ); \
|
||||
a_reg[5] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[8], ( __m512d )b_reg[10] ); \
|
||||
a_reg[6] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[12], ( __m512d )b_reg[14] ); \
|
||||
a_reg[7] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[12], ( __m512d )b_reg[14] ); \
|
||||
a_reg[8] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[1], ( __m512d )b_reg[3] ); \
|
||||
a_reg[9] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[1], ( __m512d )b_reg[3] ); \
|
||||
a_reg[10] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[5], ( __m512d )b_reg[7] ); \
|
||||
a_reg[11] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[5], ( __m512d )b_reg[7] ); \
|
||||
a_reg[12] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[9], ( __m512d )b_reg[11] ); \
|
||||
a_reg[13] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[9], ( __m512d )b_reg[11] ); \
|
||||
a_reg[14] = ( __m512 )_mm512_unpacklo_pd( ( __m512d )b_reg[13], ( __m512d )b_reg[15] ); \
|
||||
a_reg[15] = ( __m512 )_mm512_unpackhi_pd( ( __m512d )b_reg[13], ( __m512d )b_reg[15] ); \
|
||||
|
||||
#define PERMUTE_R1_16x16(selector1, selector2) \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
b_reg[0] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[0], selector1, ( __m512d )a_reg[2] ); \
|
||||
b_reg[1] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[0], selector2, ( __m512d )a_reg[2] ); \
|
||||
b_reg[2] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[4], selector1, ( __m512d )a_reg[6] ); \
|
||||
b_reg[3] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[4], selector2, ( __m512d )a_reg[6] ); \
|
||||
b_reg[4] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[8], selector1, ( __m512d )a_reg[10] ); \
|
||||
b_reg[5] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[8], selector2, ( __m512d )a_reg[10] ); \
|
||||
b_reg[6] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[12], selector1, ( __m512d )a_reg[14] ); \
|
||||
b_reg[7] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[12], selector2, ( __m512d )a_reg[14] ); \
|
||||
b_reg[8] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[1], selector1, ( __m512d )a_reg[3] ); \
|
||||
b_reg[9] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[1], selector2, ( __m512d )a_reg[3] ); \
|
||||
b_reg[10] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[5], selector1, ( __m512d )a_reg[7] ); \
|
||||
b_reg[11] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[5], selector2, ( __m512d )a_reg[7] ); \
|
||||
b_reg[12] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[9], selector1, ( __m512d )a_reg[11] ); \
|
||||
b_reg[13] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[9], selector2, ( __m512d )a_reg[11] ); \
|
||||
b_reg[14] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[13], selector1, ( __m512d )a_reg[15] ); \
|
||||
b_reg[15] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )a_reg[13], selector2, ( __m512d )a_reg[15] ); \
|
||||
|
||||
#define PERMUTE_R2_16x16(selector1_1, selector2_1) \
|
||||
/* Even indices contains lo parts, odd indices contains hi parts. */ \
|
||||
a_reg[0] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[0], selector1_1, \
|
||||
( __m512d )b_reg[2] ); /* Row 0 */ \
|
||||
a_reg[1] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[0], selector2_1, \
|
||||
( __m512d )b_reg[2] ); /* Row 4 */ \
|
||||
a_reg[2] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[4], selector1_1, \
|
||||
( __m512d )b_reg[6] ); /* Row 2 */ \
|
||||
a_reg[3] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[4], selector2_1, \
|
||||
( __m512d )b_reg[6] ); /* Row 6 */ \
|
||||
a_reg[4] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[8], selector1_1, \
|
||||
( __m512d )b_reg[10] ); /* Row 1 */ \
|
||||
a_reg[5] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[8], selector2_1, \
|
||||
( __m512d )b_reg[10] ); /* Row 5 */ \
|
||||
a_reg[6] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[12], selector1_1, \
|
||||
( __m512d )b_reg[14] ); /* Row 3 */ \
|
||||
a_reg[7] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[12], selector2_1, \
|
||||
( __m512d )b_reg[14] ); /* Row 7 */ \
|
||||
a_reg[8] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[1], selector1_1, \
|
||||
( __m512d )b_reg[3] ); /* Row 8 */ \
|
||||
a_reg[9] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[1], selector2_1, \
|
||||
( __m512d )b_reg[3] ); /* Row 12 */ \
|
||||
a_reg[10] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[5], selector1_1, \
|
||||
( __m512d )b_reg[7] ); /* Row 10 */ \
|
||||
a_reg[11] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[5], selector2_1, \
|
||||
( __m512d )b_reg[7] ); /* Row 14 */ \
|
||||
a_reg[12] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[9], selector1_1, \
|
||||
( __m512d )b_reg[11] ); /* Row 9 */ \
|
||||
a_reg[13] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[9], selector2_1, \
|
||||
( __m512d )b_reg[11] ); /* Row 13 */ \
|
||||
a_reg[14] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[13], selector1_1, \
|
||||
( __m512d )b_reg[15] ); /* Row 11 */ \
|
||||
a_reg[15] = ( __m512 )_mm512_permutex2var_pd( ( __m512d )b_reg[13], selector2_1, \
|
||||
( __m512d )b_reg[15] ); /* Row 15 */ \
|
||||
|
||||
#define STORE_PS_16x16() \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 0 ) * NR ) + jr_offset, a_reg[0] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 1 ) * NR ) + jr_offset, a_reg[4] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 2 ) * NR ) + jr_offset, a_reg[2] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 3 ) * NR ) + jr_offset, a_reg[6] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 4 ) * NR ) + jr_offset, a_reg[1] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 5 ) * NR ) + jr_offset, a_reg[5] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 6 ) * NR ) + jr_offset, a_reg[3] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 7 ) * NR ) + jr_offset, a_reg[7] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 8 ) * NR ) + jr_offset, a_reg[8] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 9 ) * NR ) + jr_offset, a_reg[12] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 10 ) * NR ) + jr_offset, a_reg[10] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 11 ) * NR ) + jr_offset, a_reg[14] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 12 ) * NR ) + jr_offset, a_reg[9] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 13 ) * NR ) + jr_offset, a_reg[13] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 14 ) * NR ) + jr_offset, a_reg[11] ); \
|
||||
_mm512_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 15 ) * NR ) + jr_offset, a_reg[15] ); \
|
||||
|
||||
#define MASK_STORE_PS_16x16(msk_arr) \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 0 ) * NR ) + jr_offset, \
|
||||
msk_arr[0], a_reg[0] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 1 ) * NR ) + jr_offset, \
|
||||
msk_arr[1], a_reg[4] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 2 ) * NR ) + jr_offset, \
|
||||
msk_arr[2], a_reg[2] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 3 ) * NR ) + jr_offset, \
|
||||
msk_arr[3], a_reg[6] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 4 ) * NR ) + jr_offset, \
|
||||
msk_arr[4], a_reg[1] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 5 ) * NR ) + jr_offset, \
|
||||
msk_arr[5], a_reg[5] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 6 ) * NR ) + jr_offset, \
|
||||
msk_arr[6], a_reg[3] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 7 ) * NR ) + jr_offset, \
|
||||
msk_arr[7], a_reg[7] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 8 ) * NR ) + jr_offset, \
|
||||
msk_arr[8], a_reg[8] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 9 ) * NR ) + jr_offset, \
|
||||
msk_arr[9], a_reg[12] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 10 ) * NR ) + jr_offset, \
|
||||
msk_arr[10], a_reg[10] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 11 ) * NR ) + jr_offset, \
|
||||
msk_arr[11], a_reg[14] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 12 ) * NR ) + jr_offset, \
|
||||
msk_arr[12], a_reg[9] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 13 ) * NR ) + jr_offset, \
|
||||
msk_arr[13], a_reg[13] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 14 ) * NR ) + jr_offset, \
|
||||
msk_arr[14], a_reg[11] ); \
|
||||
_mm512_mask_storeu_ps( pack_b_buffer + ( jc * KC ) + ( ( kr + 15 ) * NR ) + jr_offset, \
|
||||
msk_arr[15], a_reg[15] ); \
|
||||
|
||||
void packb_nr64_f32f32f32of32_col_major
|
||||
(
|
||||
float* pack_b_buffer,
|
||||
const float* b,
|
||||
const dim_t ldb,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_b,
|
||||
dim_t* cs_b
|
||||
)
|
||||
{
|
||||
const dim_t NR = 64;
|
||||
const dim_t n_sub_blk_wdth = 16;
|
||||
const dim_t k_reg_size = 16;
|
||||
|
||||
__m512 a_reg[16];
|
||||
__m512 b_reg[16];
|
||||
|
||||
dim_t n_full_pieces_loop_limit = ( NC / NR ) * NR;
|
||||
dim_t n_partial_pieces = NC % NR;
|
||||
|
||||
dim_t k_full_pieces_loop_limit = ( KC / k_reg_size ) * k_reg_size;
|
||||
dim_t k_partial_pieces = KC % k_reg_size;
|
||||
|
||||
// First permute sequences.
|
||||
__m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB );
|
||||
__m512i selector2 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF );
|
||||
|
||||
// Second permute sequences.
|
||||
__m512i selector1_1 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB );
|
||||
__m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF );
|
||||
|
||||
for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR )
|
||||
{
|
||||
for ( dim_t jr = jc; jr < jc + NR; jr += n_sub_blk_wdth )
|
||||
{
|
||||
dim_t jr_offset = jr % NR;
|
||||
__mmask16 msk = _cvtu32_mask16( 0xFFFF );
|
||||
for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
|
||||
{
|
||||
MASK_LOAD_PS_16x16(msk);
|
||||
UNPACK_PS_16x16();
|
||||
UNPACK_PD_16x16();
|
||||
PERMUTE_R1_16x16(selector1, selector2);
|
||||
PERMUTE_R2_16x16(selector1_1, selector2_1);
|
||||
STORE_PS_16x16();
|
||||
}
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
msk = _cvtu32_mask16( 0xFFFF >> ( k_reg_size - k_partial_pieces ) );
|
||||
dim_t kr = k_full_pieces_loop_limit;
|
||||
|
||||
MASK_LOAD_PS_16x16(msk);
|
||||
UNPACK_PS_16x16();
|
||||
UNPACK_PD_16x16();
|
||||
PERMUTE_R1_16x16(selector1, selector2);
|
||||
PERMUTE_R2_16x16(selector1_1, selector2_1);
|
||||
|
||||
__mmask16 msk_arr[16];
|
||||
for ( int i = 0; i < k_partial_pieces; ++i )
|
||||
{
|
||||
msk_arr[i] = _cvtu32_mask16(0xFFFF);
|
||||
}
|
||||
for ( int i = k_partial_pieces; i < 16; ++i )
|
||||
{
|
||||
msk_arr[i] = _cvtu32_mask16(0x0);
|
||||
}
|
||||
MASK_STORE_PS_16x16(msk_arr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if( n_partial_pieces > 0 )
|
||||
{
|
||||
dim_t jc = n_full_pieces_loop_limit;
|
||||
for ( dim_t jr = n_full_pieces_loop_limit; jr < NC; jr += n_sub_blk_wdth )
|
||||
{
|
||||
dim_t jr_offset = jr % NR;
|
||||
__mmask16 msk_n_arr[16];
|
||||
|
||||
dim_t jr_elems = ( ( NC - jr ) >= n_sub_blk_wdth ) ? n_sub_blk_wdth : ( NC - jr );
|
||||
for ( int i = 0; i < jr_elems; ++i )
|
||||
{
|
||||
msk_n_arr[i] = _cvtu32_mask16( 0xFFFF );
|
||||
}
|
||||
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i )
|
||||
{
|
||||
msk_n_arr[i] = _cvtu32_mask16( 0x0 );
|
||||
}
|
||||
|
||||
for ( dim_t kr = 0; kr < k_full_pieces_loop_limit; kr += k_reg_size )
|
||||
{
|
||||
MASK_ARR_LOAD_PS_16x16(msk_n_arr);
|
||||
UNPACK_PS_16x16();
|
||||
UNPACK_PD_16x16();
|
||||
PERMUTE_R1_16x16(selector1, selector2);
|
||||
PERMUTE_R2_16x16(selector1_1, selector2_1);
|
||||
STORE_PS_16x16();
|
||||
}
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
for ( int i = 0; i < jr_elems; ++i )
|
||||
{
|
||||
msk_n_arr[i] = _cvtu32_mask16( 0xFFFF &
|
||||
( 0xFFFF >> ( k_reg_size - k_partial_pieces ) ) );
|
||||
}
|
||||
for ( int i = jr_elems; i < n_sub_blk_wdth; ++i )
|
||||
{
|
||||
msk_n_arr[i] = _cvtu32_mask16( 0x0 );
|
||||
}
|
||||
dim_t kr = k_full_pieces_loop_limit;
|
||||
|
||||
MASK_ARR_LOAD_PS_16x16(msk_n_arr);
|
||||
UNPACK_PS_16x16();
|
||||
UNPACK_PD_16x16();
|
||||
PERMUTE_R1_16x16(selector1, selector2);
|
||||
PERMUTE_R2_16x16(selector1_1, selector2_1);
|
||||
|
||||
__mmask16 msk_arr[16];
|
||||
for (int i = 0; i < k_partial_pieces; ++i)
|
||||
{
|
||||
msk_arr[i] = _cvtu32_mask16(0xFFFF);
|
||||
}
|
||||
for (int i = k_partial_pieces; i < 16; ++i)
|
||||
{
|
||||
msk_arr[i] = _cvtu32_mask16(0x0);
|
||||
}
|
||||
MASK_STORE_PS_16x16(msk_arr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*rs_b = NR;
|
||||
*cs_b = 1;
|
||||
}
|
||||
|
||||
#endif