mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
New A packing kernels for the F32 API in LPGEMM.
- New packing kernels for the A matrix, based on both AVX512 and AVX2 ISAs and covering both row- and column-major storage, are added as part of this change. The dependency on the Haswell A packing kernels is removed by this. - Tiny-GEMM thresholds are further tuned for the BF16 and F32 APIs. AMD-Internal: [SWLCSG-3380, SWLCSG-3415] Change-Id: I7330defacbacc9d07037ce1baf4a441f941e59be
This commit is contained in:
@@ -56,11 +56,19 @@ static inline bool is_tiny_input_bf16obf16
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
const dim_t MR = lcntx->blksz.MR;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
dim_t mnk = m * n * k;
|
||||
const dim_t mnk_magic_num = 36 * 128 * 128;
|
||||
const dim_t m_thresh = 6 * MR;
|
||||
const dim_t n_thresh = 2 * NR;
|
||||
const dim_t k_thresh = 1024;
|
||||
|
||||
// Need to explicitly check for MC, NC boundaries for safety.
|
||||
if ( ( k < 256 ) && ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
|
||||
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
|
||||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
|
||||
if ( ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
|
||||
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
|
||||
( mnk < mnk_magic_num ) ) )
|
||||
{
|
||||
is_tiny = TRUE;
|
||||
}
|
||||
|
||||
@@ -56,11 +56,19 @@ static inline bool is_tiny_input_bf16of32
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
const dim_t MR = lcntx->blksz.MR;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
dim_t mnk = m * n * k;
|
||||
const dim_t mnk_magic_num = 36 * 128 * 128;
|
||||
const dim_t m_thresh = 6 * MR;
|
||||
const dim_t n_thresh = 2 * NR;
|
||||
const dim_t k_thresh = 1024;
|
||||
|
||||
// Need to explicitly check for MC, NC boundaries for safety.
|
||||
if ( ( k < 256 ) && ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
|
||||
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
|
||||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
|
||||
if ( ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
|
||||
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
|
||||
( mnk < mnk_magic_num ) ) )
|
||||
{
|
||||
is_tiny = TRUE;
|
||||
}
|
||||
|
||||
@@ -55,11 +55,20 @@ static inline bool is_tiny_input_f32
|
||||
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
const dim_t MR = lcntx->blksz.MR;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
dim_t mnk = m * n * k;
|
||||
const dim_t mnk_magic_num = 12 * 64 * 496;
|
||||
const dim_t m_thresh = 6 * MR;
|
||||
const dim_t n_thresh = 2 * NR;
|
||||
const dim_t k_thresh = 480;
|
||||
|
||||
// Need to explicitly check for MC, NC boundaries for safety.
|
||||
if ( ( k < 128 ) && ( m <= MC ) && ( n < NC ) &&
|
||||
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
|
||||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
|
||||
if ( ( k < KC ) && ( m <= MC ) && ( n < NC ) &&
|
||||
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
|
||||
( mnk < mnk_magic_num ) ) )
|
||||
{
|
||||
is_tiny = TRUE;
|
||||
}
|
||||
|
||||
@@ -245,6 +245,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
|
||||
LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
|
||||
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
|
||||
}
|
||||
}
|
||||
@@ -260,6 +261,7 @@ static void _lpgemm_cntx_init_func_map()
|
||||
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2
|
||||
LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
|
||||
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,10 +59,14 @@
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx512) \
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
@@ -110,10 +114,14 @@
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx512) \
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
|
||||
|
||||
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI \
|
||||
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
|
||||
|
||||
@@ -147,10 +155,14 @@
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx512) \
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_TO_AVX2 \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512 \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
@@ -160,6 +172,9 @@
|
||||
PBMACRO(BF16S4F32OF32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, NULL)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
|
||||
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512 \
|
||||
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
|
||||
|
||||
@@ -180,6 +195,7 @@
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX2 \
|
||||
PAMACRO(U8S8S32OS32, NULL) \
|
||||
PAMACRO(BF16BF16F32OF32, NULL) \
|
||||
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
|
||||
KMACRO(BF16S4F32OF32, NULL) \
|
||||
PAMACRO(S8S8S32OS32, NULL) \
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -62,19 +62,6 @@ typedef void (*lpgemm_rowvar_f32)
|
||||
lpgemm_post_op_attr
|
||||
);
|
||||
|
||||
void lpgemm_pack_a_f32f32f32of32
|
||||
(
|
||||
const float* input_buf_addr_a,
|
||||
float* reorder_buf_addr_a,
|
||||
const dim_t m,
|
||||
const dim_t k,
|
||||
const dim_t rs_a,
|
||||
const dim_t cs_a,
|
||||
const dim_t ps_p,
|
||||
const dim_t MR,
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
LPGEMV(float, float, float, f32f32f32of32)
|
||||
{
|
||||
@@ -352,8 +339,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
// Query the global cntx.
|
||||
cntx_t* cntx = bli_gks_query_cntx();
|
||||
|
||||
// Query the context for various blocksizes.
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
@@ -385,6 +370,7 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
auxinfo_t aux;
|
||||
|
||||
// Check if packing of A is required.
|
||||
// TODO: mtag_a for tranpose needs to be honored.
|
||||
bool should_pack_A = bli_rntm_pack_a( rntm );
|
||||
|
||||
// Pack buffer for A.
|
||||
@@ -594,13 +580,13 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
cs_a_use = MR;
|
||||
ps_a_use = MR * kc0;
|
||||
|
||||
lpgemm_pack_a_f32f32f32of32
|
||||
( ( lpgemm_pack_f32 )lcntx->packa_fun_ptr )
|
||||
(
|
||||
( a + ( rs_a * ic ) + ( pc * cs_a) ),
|
||||
pack_a_buffer_f32f32f32of32,
|
||||
( a + ( rs_a * ic ) + ( pc * cs_a) ),
|
||||
rs_a, cs_a,
|
||||
mc0, kc0,
|
||||
rs_a, cs_a, ps_a_use, MR,
|
||||
cntx
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer_f32f32f32of32;
|
||||
@@ -671,58 +657,3 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Packs an m x k block of the A matrix into "row panel" layout
// (BLIS_PACKED_ROW_PANELS) of MR rows each, by delegating each
// micropanel to the framework's packm_cxk kernel.
//
// input_buf_addr_a   : source block of A (read-only).
// reorder_buf_addr_a : destination pack buffer; panels are laid out
//                      ps_p elements apart.
// m, k               : dimensions of the block being packed.
// rs_a, cs_a         : row/column strides of the source block.
// ps_p               : panel stride in the pack buffer (elements between
//                      consecutive MR-row panels).
// MR                 : micropanel row dimension.
// cntx               : BLIS context, forwarded to packm_cxk.
//
// NOTE(review): packing is done with kappa = 1 (no scaling) and no
// transpose/conjugation; the trailing panel (m % MR rows) is handled by
// packm_cxk via panel_dim_i < MR — confirm zero-fill semantics against
// the packm_cxk contract.
void lpgemm_pack_a_f32f32f32of32
    (
      const float* input_buf_addr_a,
      float*       reorder_buf_addr_a,
      const dim_t  m,
      const dim_t  k,
      const dim_t  rs_a,
      const dim_t  cs_a,
      const dim_t  ps_p,
      const dim_t  MR,
      cntx_t*      cntx
    )
{
    // kappa = 1.0f: pack without scaling.
    float one_local = *PASTEMAC(s,1);
    float* restrict kappa_cast = &one_local;

    // Set the schema to "column stored row panels" to indicate packing to
    // conventional column-stored row panels.
    pack_t schema = BLIS_PACKED_ROW_PANELS;
    trans_t transc = BLIS_NO_TRANSPOSE;
    conj_t conjc = bli_extract_conj( transc );
    // Compute the total number of micropanel iterations we'll need
    // (ceiling of m / MR).
    dim_t m_iter = ( m + MR - 1 ) / MR;

    // Within a packed panel, consecutive k-columns are MR elements apart.
    inc_t cs_p = MR;

    float* p_temp = reorder_buf_addr_a;

    dim_t ir, it;
    // Iterate over every logical micropanel in the source matrix.
    for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 )
    {
        // Rows actually available in this panel (< MR only on the last one).
        dim_t panel_dim_i = bli_min( MR, m - ir );

        const float* a_use = input_buf_addr_a + ( ir * rs_a );
        float* p_use = p_temp;

        PASTEMAC(s,packm_cxk)
        (
          conjc,
          schema,
          panel_dim_i,
          MR,
          k,
          k,
          kappa_cast,
          ( float* )a_use, rs_a, cs_a,
          p_use, cs_p,
          cntx
        );

        // Advance to the next panel slot in the pack buffer.
        p_temp += ps_p;
    }
}
|
||||
|
||||
@@ -250,7 +250,7 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
);
|
||||
get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
&rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer_s8s8s32os32;
|
||||
|
||||
@@ -231,7 +231,7 @@ LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32)
|
||||
);
|
||||
get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
&rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer;
|
||||
|
||||
@@ -46,6 +46,30 @@ void packa_mr16_f32f32f32of32_col_major
|
||||
dim_t* cs_p
|
||||
);
|
||||
|
||||
void packa_mr6_f32f32f32of32_avx512
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t rs,
|
||||
const dim_t cs,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
);
|
||||
|
||||
void packa_mr6_f32f32f32of32_avx2
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t rs,
|
||||
const dim_t cs,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
);
|
||||
|
||||
typedef void (*lpgemm_pack_f32)
|
||||
(
|
||||
float*,
|
||||
|
||||
@@ -39,14 +39,20 @@
|
||||
// for different schemas used to pack A fringe cases.
|
||||
BLIS_INLINE void get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
dim_t* rs,
|
||||
dim_t* cs,
|
||||
const dim_t rs,
|
||||
const dim_t cs,
|
||||
dim_t* rs_use,
|
||||
dim_t* cs_use,
|
||||
dim_t MR,
|
||||
dim_t m_fringe
|
||||
)
|
||||
{
|
||||
( *rs ) = 4;
|
||||
( *cs ) = ( ( *cs ) / MR ) * m_fringe;
|
||||
// Only applicable for row major packing.
|
||||
if ( ( rs != 1 ) && ( cs == 1 ) && ( ( *cs_use ) > MR ))
|
||||
{
|
||||
( *rs_use ) = 4;
|
||||
( *cs_use ) = ( ( *cs_use ) / MR ) * m_fringe;
|
||||
}
|
||||
}
|
||||
|
||||
typedef void (*packa_s32)
|
||||
|
||||
602
kernels/zen/lpgemm/f32f32f32/lpgemm_pack_a_f32_avx2.c
Normal file
602
kernels/zen/lpgemm/f32f32f32/lpgemm_pack_a_f32_avx2.c
Normal file
@@ -0,0 +1,602 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <string.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
/* Transposes a 6x8 float tile (rows a01,b0,c01,d0,e01,f0 loaded from 8
 * consecutive k-elements of 6 A-rows) into packed MR=6 column groups and
 * stores them at pack_a_buf + (ic * KC) + (kr * MR).
 *
 * Relies on caller-scope variables: a0..f0, a01, c01, e01 (__m256);
 * a0_128..f0_128 (__m128); pack_a_buf, ic, kr, KC, MR.
 *
 * The unpacklo/unpackhi ps/pd cascade interleaves rows 0-3; rows 4-5
 * (e01/f0) are interleaved separately and spliced into each 6-element
 * group via the storel_pd/storeh_pd 2-element stores. The first half
 * writes k-offsets 0-3 (buffer offsets 0-23), the extractf128 half
 * writes k-offsets 4-7 (buffer offsets 24-47).
 *
 * NOTE(review): multi-statement macro without do { } while(0); safe only
 * because every call site uses it as a full statement. */
#define F32_ROW_MAJOR_K_PACK_LOOP_AVX2() \
    a0 = _mm256_unpacklo_ps( a01, b0 ); \
    b0 = _mm256_unpackhi_ps( a01, b0 ); \
\
    c0 = _mm256_unpacklo_ps( c01, d0 ); \
    d0 = _mm256_unpackhi_ps( c01, d0 ); \
\
    e0 = _mm256_unpacklo_ps( e01, f0 ); \
    f0 = _mm256_unpackhi_ps( e01, f0 ); \
\
    a01 = _mm256_castpd_ps( _mm256_unpacklo_pd( _mm256_castps_pd( a0 ), \
                    _mm256_castps_pd( c0 ) ) ); \
    a0 = _mm256_castpd_ps( _mm256_unpackhi_pd( _mm256_castps_pd( a0 ), \
                    _mm256_castps_pd( c0 ) ) ); \
\
    c01 = _mm256_castpd_ps( _mm256_unpacklo_pd( _mm256_castps_pd( b0 ), \
                    _mm256_castps_pd( d0 ) ) ); \
    c0 = _mm256_castpd_ps( _mm256_unpackhi_pd( _mm256_castps_pd( b0 ), \
                    _mm256_castps_pd( d0 ) ) ); \
\
    a0_128 = _mm256_castps256_ps128( a01 ); \
    b0_128 = _mm256_castps256_ps128( a0 ); \
    c0_128 = _mm256_castps256_ps128( c01 ); \
    d0_128 = _mm256_castps256_ps128( c0 ); \
    e0_128 = _mm256_castps256_ps128( e0 ); \
    f0_128 = _mm256_castps256_ps128( f0 ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 0, a0_128 ); \
    _mm_storel_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 4 ), \
                    _mm_castps_pd( e0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 6, b0_128 ); \
    _mm_storeh_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 10 ), \
                    _mm_castps_pd( e0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 12, c0_128 ); \
    _mm_storel_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 16 ), \
                    _mm_castps_pd( f0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 18, d0_128 ); \
    _mm_storeh_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 22 ), \
                    _mm_castps_pd( f0_128 ) ); \
\
    a0_128 = _mm256_extractf128_ps( a01, 0x1 ); \
    b0_128 = _mm256_extractf128_ps( a0, 0x1 ); \
    c0_128 = _mm256_extractf128_ps( c01, 0x1 ); \
    d0_128 = _mm256_extractf128_ps( c0, 0x1 ); \
    e0_128 = _mm256_extractf128_ps( e0, 0x1 ); \
    f0_128 = _mm256_extractf128_ps( f0, 0x1 ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 24, a0_128 ); \
    _mm_storel_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 28 ), \
                    _mm_castps_pd( e0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 30, b0_128 ); \
    _mm_storeh_pd( ( double*)(pack_a_buf + ( ic * KC ) + ( kr * MR ) + 34 ), \
                    _mm_castps_pd( e0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 36, c0_128 ); \
    _mm_storel_pd( ( double*)(pack_a_buf + ( ic * KC ) + ( kr * MR ) + 40 ), \
                    _mm_castps_pd( f0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 42, d0_128 ); \
    _mm_storeh_pd( ( double*)(pack_a_buf + ( ic * KC ) + ( kr * MR ) + 46 ), \
                    _mm_castps_pd( f0_128 ) );
|
||||
|
||||
|
||||
/* SSE variant of the row-major k-pack transpose: handles a 6x4 float
 * tile (4 remaining k-elements of 6 A-rows) and stores it as four MR=6
 * groups at pack_a_buf + (ic * KC) + (kr * MR), offsets 0-23.
 *
 * Relies on caller-scope variables: a01_128, b0_128, c01_128, d0_128,
 * e01_128, f0_128 plus scratch a0_128..f0_128 (__m128); pack_a_buf, ic,
 * kr, KC, MR. Same interleave structure as the AVX2 macro, restricted
 * to one 128-bit lane. */
#define F32_ROW_MAJOR_K_PACK_LOOP_SSE() \
    a0_128 = _mm_unpacklo_ps( a01_128, b0_128 ); \
    b0_128 = _mm_unpackhi_ps( a01_128, b0_128 ); \
\
    c0_128 = _mm_unpacklo_ps( c01_128, d0_128 ); \
    d0_128 = _mm_unpackhi_ps( c01_128, d0_128 ); \
\
    e0_128 = _mm_unpacklo_ps( e01_128, f0_128 ); \
    f0_128 = _mm_unpackhi_ps( e01_128, f0_128 ); \
\
    a01_128 = _mm_castpd_ps( _mm_unpacklo_pd( _mm_castps_pd( a0_128 ), \
                    _mm_castps_pd( c0_128 ) ) ); \
    a0_128 = _mm_castpd_ps( _mm_unpackhi_pd( _mm_castps_pd( a0_128 ), \
                    _mm_castps_pd( c0_128 ) ) ); \
\
    c01_128 = _mm_castpd_ps( _mm_unpacklo_pd( _mm_castps_pd( b0_128 ), \
                    _mm_castps_pd( d0_128 ) ) ); \
    c0_128 = _mm_castpd_ps( _mm_unpackhi_pd( _mm_castps_pd( b0_128 ), \
                    _mm_castps_pd( d0_128 ) ) ); \
\
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 0, a01_128 ); \
    _mm_storel_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 4 ), \
                    _mm_castps_pd( e0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 6, a0_128 ); \
    _mm_storeh_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 10 ), \
                    _mm_castps_pd( e0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 12, c01_128 ); \
    _mm_storel_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 16 ), \
                    _mm_castps_pd( f0_128 ) ); \
    _mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 18, c0_128 ); \
    _mm_storeh_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 22 ), \
                    _mm_castps_pd( f0_128 ) );
|
||||
|
||||
// Row Major Packing in blocks of MRxKC.
//
// Packs an MC x KC block of a row-major A (leading dimension lda) into
// pack_a_buf using MR = 6 row micropanels, with the k dimension processed
// in chunks of KR_NDIM = 8 (AVX2 path), then 4 (SSE path), then scalar
// remainder. On return the packed strides are reported through
// *rs_a = 1, *cs_a = 6 (element stride between the 6 packed rows).
//
// NOTE(review): the AVX2/SSE fast paths load full 8-/4-wide vectors from
// each source row — assumes each row has at least kr+8 (resp. kr+4)
// readable elements, which holds because those paths are only taken when
// that many k-elements remain.
void packa_f32f32f32of32_row_major_avx2
    (
      float*       pack_a_buf,
      const float* a,
      const dim_t  lda,
      const dim_t  MC,
      const dim_t  KC,
      dim_t*       rs_a,
      dim_t*       cs_a
    )
{
    const dim_t MR = 6;
    const dim_t KR_NDIM = 8;

    // Split MC into full MR-row panels and a fringe of MC % MR rows.
    dim_t m_full_pieces = MC / MR;
    dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
    dim_t m_partial_pieces = MC % MR;

    // Split KC into full 8-wide chunks and a remainder of KC % 8.
    dim_t kr_full_pieces = KC / KR_NDIM;
    dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
    dim_t kr_partial_pieces = KC % KR_NDIM;

    // Registers consumed/clobbered by the K_PACK_LOOP macros.
    __m256 a0;
    __m256 b0;
    __m256 c0;
    __m256 d0;
    __m256 e0;
    __m256 f0;
    __m256 a01;
    __m256 c01;
    __m256 e01;
    __m128 a0_128;
    __m128 b0_128;
    __m128 c0_128;
    __m128 d0_128;
    __m128 e0_128;
    __m128 f0_128;

    // Full MR-row panels.
    for ( dim_t ic = 0; ic < m_full_pieces_loop_limit; ic += MR )
    {
        // 8 k-elements at a time from each of the 6 rows.
        for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
        {
            a01 = _mm256_loadu_ps( a + ( lda * ( ic + 0 ) ) + kr );
            b0 = _mm256_loadu_ps( a + ( lda * ( ic + 1 ) ) + kr );
            c01 = _mm256_loadu_ps( a + ( lda * ( ic + 2 ) ) + kr );
            d0 = _mm256_loadu_ps( a + ( lda * ( ic + 3 ) ) + kr );
            e01 = _mm256_loadu_ps( a + ( lda * ( ic + 4 ) ) + kr );
            f0 = _mm256_loadu_ps( a + ( lda * ( ic + 5 ) ) + kr );

            F32_ROW_MAJOR_K_PACK_LOOP_AVX2();
        }
        if ( kr_partial_pieces > 0 )
        {
            // Handle the k remainder: first a 4-wide SSE chunk, then scalars.
            dim_t kr_partial_pieces_4 = ( kr_partial_pieces / 4 ) * 4;
            dim_t kr_partial_pieces_rem = kr_partial_pieces - kr_partial_pieces_4;

            dim_t kr = kr_full_pieces_loop_limit;
            if ( kr_partial_pieces_4 > 0 )
            {
                __m128 a01_128 = _mm_loadu_ps( a + ( lda * ( ic + 0 ) ) + kr );
                b0_128 = _mm_loadu_ps( a + ( lda * ( ic + 1 ) ) + kr );
                __m128 c01_128 = _mm_loadu_ps( a + ( lda * ( ic + 2 ) ) + kr );
                d0_128 = _mm_loadu_ps( a + ( lda * ( ic + 3 ) ) + kr );
                __m128 e01_128 = _mm_loadu_ps( a + ( lda * ( ic + 4 ) ) + kr );
                f0_128 = _mm_loadu_ps( a + ( lda * ( ic + 5 ) ) + kr );

                F32_ROW_MAJOR_K_PACK_LOOP_SSE();
            }
            kr += kr_partial_pieces_4;
            if ( kr_partial_pieces_rem > 0 )
            {
                // Scalar copy of the final 1-3 k-elements of each row.
                for ( int ii = 0; ii < kr_partial_pieces_rem; ++ii )
                {
                    *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 0 ) =
                        *( a + ( lda * ( ic + 0 ) ) + ( kr + ii ) );
                    *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 1 ) =
                        *( a + ( lda * ( ic + 1 ) ) + ( kr + ii ) );
                    *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 2 ) =
                        *( a + ( lda * ( ic + 2 ) ) + ( kr + ii ) );
                    *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 3 ) =
                        *( a + ( lda * ( ic + 3 ) ) + ( kr + ii ) );
                    *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 4 ) =
                        *( a + ( lda * ( ic + 4 ) ) + ( kr + ii ) );
                    *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 5 ) =
                        *( a + ( lda * ( ic + 5 ) ) + ( kr + ii ) );
                }
            }
        }
    }
    // Fringe panel: fewer than MR rows; missing rows are zero-filled so the
    // same MR-wide pack layout (and kernels) can be used.
    if ( m_partial_pieces > 0 )
    {
        dim_t ic = m_full_pieces_loop_limit;
        __m256 temp_a_reg[6];
        for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
        {
            for ( int ii = 0; ii < m_partial_pieces; ++ii )
            {
                temp_a_reg[ii] = _mm256_loadu_ps( a + ( lda * ( ic + ii ) ) + kr );
            }
            // Zero the rows beyond the fringe.
            for ( int ii = m_partial_pieces; ii < MR; ++ii )
            {
                temp_a_reg[ii] = _mm256_setzero_ps();
            }
            a01 = temp_a_reg[0];
            b0 = temp_a_reg[1];
            c01 = temp_a_reg[2];
            d0 = temp_a_reg[3];
            e01 = temp_a_reg[4];
            f0 = temp_a_reg[5];

            F32_ROW_MAJOR_K_PACK_LOOP_AVX2();
        }
        if ( kr_partial_pieces > 0 )
        {
            dim_t kr_partial_pieces_4 = ( kr_partial_pieces / 4 ) * 4;
            dim_t kr_partial_pieces_rem = kr_partial_pieces - kr_partial_pieces_4;

            dim_t kr = kr_full_pieces_loop_limit;
            if ( kr_partial_pieces_4 > 0 )
            {
                __m128 temp_a_reg_128[6] = {0};
                for ( int ii = 0; ii < m_partial_pieces; ++ii )
                {
                    temp_a_reg_128[ii] = _mm_loadu_ps( a + ( lda * ( ic + ii ) ) + kr );
                }
                for ( int ii = m_partial_pieces; ii < MR; ++ii )
                {
                    temp_a_reg_128[ii] = _mm_setzero_ps();
                }
                __m128 a01_128 = temp_a_reg_128[0];
                b0_128 = temp_a_reg_128[1];
                __m128 c01_128 = temp_a_reg_128[2];
                d0_128 = temp_a_reg_128[3];
                __m128 e01_128 = temp_a_reg_128[4];
                f0_128 = temp_a_reg_128[5];

                F32_ROW_MAJOR_K_PACK_LOOP_SSE();
            }
            kr += kr_partial_pieces_4;
            if ( kr_partial_pieces_rem > 0 )
            {
                // Scalar remainder; only the valid fringe rows are written
                // (the zero-padded positions were already handled above for
                // the vector chunks and stay untouched here).
                for ( int ii = 0; ii < kr_partial_pieces_rem; ++ii )
                {
                    for ( int jj = 0; jj < m_partial_pieces; ++jj )
                    {
                        *( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + jj ) =
                            *( a + ( lda * ( ic + jj ) ) + ( kr + ii ) );
                    }
                }
            }
        }
    }

    // Strides of the packed buffer: unit stride along k within a panel,
    // 6 packed rows interleaved.
    *rs_a = 1;
    *cs_a = 6;
}
|
||||
|
||||
/* Stores 8 packed column groups (one per k-element kr+0 .. kr+7) for the
 * column-major packing path. Each group is MR = 6 floats: a 4-float
 * vector (a0..h0) via storeu_ps plus the trailing 2 floats (a0_2e..h0_2e,
 * held in the low half of an __m128) via a 64-bit store_sd.
 *
 * Relies on caller-scope variables: a0..h0, a0_2e..h0_2e (__m128);
 * pack_a_buf, ic, kr, KC, MR. */
#define F32_COL_MAJOR_K_PACK_STORE_SSE() \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ), \
      a0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) + 4 ) ), \
      _mm_castps_pd( a0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 1 ) ), \
      b0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 1 ) + 4 ) ), \
      _mm_castps_pd( b0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 2 ) ), \
      c0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 2 ) + 4 ) ), \
      _mm_castps_pd( c0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 3 ) ), \
      d0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 3 ) + 4 ) ), \
      _mm_castps_pd( d0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 4 ) ), \
      e0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 4 ) + 4 ) ), \
      _mm_castps_pd( e0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 5 ) ), \
      f0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 5 ) + 4 ) ), \
      _mm_castps_pd( f0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 6 ) ), \
      g0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 6 ) + 4 ) ), \
      _mm_castps_pd( g0_2e ) \
    ); \
    _mm_storeu_ps \
    ( \
      pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 7 ) ), \
      h0 \
    ); \
    _mm_store_sd \
    ( \
      ( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 7 ) + 4 ) ), \
      _mm_castps_pd( h0_2e ) \
    );
|
||||
|
||||
/* Element-wise copy of the m-fringe (< MR) column slice from src to dst,
 * decomposed as 4 + 2 + 1 copies driven by caller-scope m_partial_4,
 * m_partial_2, m_partial_1 (which sum to m_partial_pieces). cp_st tracks
 * how many elements have been copied; it is reset to 0 on entry.
 *
 * NOTE(review): src is evaluated on every element access — call sites
 * pass simple pointer expressions, so the repeated evaluation is
 * side-effect free; keep it that way. */
#define F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, dst, src) \
    cp_st = 0; \
    if ( m_partial_4 > 0 ) \
    { \
        ( dst )[cp_st + 0] = ( src )[cp_st + 0]; \
        ( dst )[cp_st + 1] = ( src )[cp_st + 1]; \
        ( dst )[cp_st + 2] = ( src )[cp_st + 2]; \
        ( dst )[cp_st + 3] = ( src )[cp_st + 3]; \
        cp_st += 4; \
    } \
    if ( m_partial_2 > 0 ) \
    { \
        ( dst )[cp_st + 0] = ( src )[cp_st + 0]; \
        ( dst )[cp_st + 1] = ( src )[cp_st + 1]; \
        cp_st += 2; \
    } \
    if ( m_partial_1 > 0 ) \
    { \
        ( dst )[cp_st + 0] = ( src )[cp_st + 0]; \
        cp_st += 1; \
    }
|
||||
|
||||
void packa_f32f32f32of32_col_major_avx2
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t lda,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
)
|
||||
{
|
||||
const dim_t MR = 6;
|
||||
const dim_t KR_NDIM = 8;
|
||||
|
||||
dim_t m_full_pieces = MC / MR;
|
||||
dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
|
||||
dim_t m_partial_pieces = MC % MR;
|
||||
|
||||
dim_t kr_full_pieces = KC / KR_NDIM;
|
||||
dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
|
||||
dim_t kr_partial_pieces = KC % KR_NDIM;
|
||||
|
||||
__m128 a0;
|
||||
__m128 b0;
|
||||
__m128 c0;
|
||||
__m128 d0;
|
||||
__m128 e0;
|
||||
__m128 f0;
|
||||
__m128 g0;
|
||||
__m128 h0;
|
||||
__m128 a0_2e;
|
||||
__m128 b0_2e;
|
||||
__m128 c0_2e;
|
||||
__m128 d0_2e;
|
||||
__m128 e0_2e;
|
||||
__m128 f0_2e;
|
||||
__m128 g0_2e;
|
||||
__m128 h0_2e;
|
||||
|
||||
for ( dim_t ic = 0; ic < m_full_pieces_loop_limit; ic += MR )
|
||||
{
|
||||
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
|
||||
{
|
||||
// First 4 elements.
|
||||
a0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 0 ) ) );
|
||||
b0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 1 ) ) );
|
||||
c0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 2 ) ) );
|
||||
d0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 3 ) ) );
|
||||
e0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 4 ) ) );
|
||||
f0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 5 ) ) );
|
||||
g0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 6 ) ) );
|
||||
h0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 7 ) ) );
|
||||
|
||||
// Last 2 elements.
|
||||
a0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 0 ) ) ) ) );
|
||||
b0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 1 ) ) ) ) );
|
||||
c0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 2 ) ) ) ) );
|
||||
d0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 3 ) ) ) ) );
|
||||
e0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 4 ) ) ) ) );
|
||||
f0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 5 ) ) ) ) );
|
||||
g0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 6 ) ) ) ) );
|
||||
h0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 7 ) ) ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_STORE_SSE();
|
||||
}
|
||||
if ( kr_partial_pieces )
|
||||
{
|
||||
for ( dim_t kr = kr_full_pieces_loop_limit; kr < KC; ++kr )
|
||||
{
|
||||
a0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 0 ) ) );
|
||||
a0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
|
||||
( lda * ( kr + 0 ) ) ) ) );
|
||||
|
||||
_mm_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ),
|
||||
a0
|
||||
);
|
||||
_mm_store_sd
|
||||
(
|
||||
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + \
|
||||
( MR * 0 ) + 4 ) ),
|
||||
_mm_castps_pd( a0_2e )
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
if ( m_partial_pieces > 0 )
|
||||
{
|
||||
dim_t ic = m_full_pieces_loop_limit;
|
||||
dim_t m_partial_4 = ( m_partial_pieces / 4 ) * 4;
|
||||
dim_t m_partial_2 = ( ( m_partial_pieces - m_partial_4 ) / 2 ) * 2;
|
||||
dim_t m_partial_1 = m_partial_pieces - ( m_partial_4 + m_partial_2 );
|
||||
dim_t cp_st = 0;
|
||||
|
||||
float temp_pack_a_buf[6] = { 0 };
|
||||
|
||||
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
|
||||
{
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 0 ) ) );
|
||||
a0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
a0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 1 ) ) );
|
||||
b0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
b0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 2 ) ) );
|
||||
c0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
c0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 3 ) ) );
|
||||
d0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
d0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 4 ) ) );
|
||||
e0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
e0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 5 ) ) );
|
||||
f0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
f0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 6 ) ) );
|
||||
g0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
g0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 7 ) ) );
|
||||
h0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
h0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
F32_COL_MAJOR_K_PACK_STORE_SSE();
|
||||
}
|
||||
if ( kr_partial_pieces )
|
||||
{
|
||||
for ( dim_t kr = kr_full_pieces_loop_limit; kr < KC; ++kr )
|
||||
{
|
||||
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
|
||||
a + ic + ( lda * ( kr + 0 ) ) );
|
||||
a0 = _mm_loadu_ps( temp_pack_a_buf );
|
||||
a0_2e = _mm_castpd_ps( _mm_load_sd(
|
||||
( double* )( temp_pack_a_buf + 4 ) ) );
|
||||
|
||||
_mm_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ),
|
||||
a0
|
||||
);
|
||||
_mm_store_sd
|
||||
(
|
||||
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + \
|
||||
( MR * 0 ) + 4 ) ),
|
||||
_mm_castps_pd( a0_2e )
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void packa_mr6_f32f32f32of32_avx2
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t rs,
|
||||
const dim_t cs,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
)
|
||||
{
|
||||
if( cs == 1 )
|
||||
{
|
||||
packa_f32f32f32of32_row_major_avx2
|
||||
( pack_a_buf, a, rs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
else
|
||||
{
|
||||
packa_f32f32f32of32_col_major_avx2
|
||||
( pack_a_buf, a, cs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -759,4 +759,408 @@ void packa_mr16_f32f32f32of32_col_major
|
||||
*rs_p = KC;
|
||||
*cs_p = 1;
|
||||
}
|
||||
|
||||
// Transposes one MR(=6) x 16 row-major tile, held in a0..f0 (one __m512
// per row), into the packed MR-interleaved layout and stores all
// 6*16 = 96 floats to pack_a_buf at offset (ic*KC + kr*MR).
// Expects in scope: ic, MR, selector1..selector10, selector1_1,
// selector2_1, the __m512 temporaries a01/c01/e01 and the __m256
// last_piece. a0..f0, a01, c01, e01 are clobbered.
#define F32_ROW_MAJOR_K_PACK_LOOP(pack_a_buf, KC, kr) \
	/* 32-bit interleave of row pairs (a,b), (c,d), (e,f). */ \
	a01 = _mm512_unpacklo_ps( a0, b0 ); \
	a0 = _mm512_unpackhi_ps( a0, b0 ); \
	\
	c01 = _mm512_unpacklo_ps( c0, d0 ); \
	c0 = _mm512_unpackhi_ps( c0, d0 ); \
	\
	e01 = _mm512_unpacklo_ps( e0, f0 ); /* Elem 4 */ \
	e0 = _mm512_unpackhi_ps( e0, f0 ); /* Elem 5 */ \
	\
	/* 64-bit interleave merges the (a,b) and (c,d) pair streams. */ \
	b0 = _mm512_castpd_ps( _mm512_unpacklo_pd( _mm512_castps_pd( a01 ), \
					_mm512_castps_pd( c01 ) ) ); \
	a01 = _mm512_castpd_ps( _mm512_unpackhi_pd( _mm512_castps_pd( a01 ), \
					_mm512_castps_pd( c01 ) ) ); \
	\
	d0 = _mm512_castpd_ps( _mm512_unpacklo_pd( _mm512_castps_pd( a0 ), \
					_mm512_castps_pd( c0 ) ) ); \
	c01 = _mm512_castpd_ps( _mm512_unpackhi_pd( _mm512_castps_pd( a0 ), \
					_mm512_castps_pd( c0 ) ) ); \
	\
	/* Cross-lane permutes gather the interleaved 64-bit lanes. */ \
	a0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
					selector1, _mm512_castps_pd( a01 ) ) ); \
	c0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( d0 ), \
					selector1, _mm512_castps_pd( c01 ) ) ); \
	b0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
					selector1_1, _mm512_castps_pd( a01 ) ) ); \
	d0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( d0 ), \
					selector1_1, _mm512_castps_pd( c01 ) ) ); \
	\
	a01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
					selector2, _mm512_castps_pd( c0 ) ) ); /* a[0] */ \
	c01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
					selector2, _mm512_castps_pd( d0 ) ) ); /* a[2] */ \
	a0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
					selector2_1, _mm512_castps_pd( c0 ) ) ); /* a[1] */ \
	c0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
					selector2_1, _mm512_castps_pd( d0 ) ) ); /* a[3] */ \
	\
	/* First half: fold in the (e,f) stream and store 48 floats. */ \
	b0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a01 ), \
					selector3, _mm512_castps_pd( e01 ) ) ); /* 1st 16 */ \
	a01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a01 ), \
					selector4, _mm512_castps_pd( e0 ) ) ); /* 1st 8 */ \
	d0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
					selector5, _mm512_castps_pd( e01 ) ) ); /* 2nd 16 */ \
	a0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
					selector6, _mm512_castps_pd( e0 ) ) ); /* 2nd 4 */ \
	\
	_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 0 ) ) ), b0 ); \
	_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 16 ) ) ) , a01 ); \
	_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 24 ) ) ), d0 ); \
	/* Last piece: 8 floats from the low half of a0. */ \
	last_piece = _mm512_castps512_ps256( a0 ); \
	_mm256_mask_storeu_ps \
	( \
	  pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 40 ) ) ), \
	  _cvtu32_mask16( 0xFFFF), \
	  last_piece \
	); \
	\
	/* Second half: remaining 48 floats, offsets 48..95. */ \
	b0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c01 ), \
					selector7, _mm512_castps_pd( e01 ) ) ); /* 3rd 16 */ \
	c01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c01 ), \
					selector8, _mm512_castps_pd( e0 ) ) ); /* 3rd 8 */ \
	d0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c0 ), \
					selector9, _mm512_castps_pd( e01 ) ) ); /* 4th 16 */ \
	c0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c0 ), \
					selector10, _mm512_castps_pd( e0 ) ) ); /* 4th 8 */ \
	\
	_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 48 ) ) ), b0 ); \
	_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 64 ) ) ) , c01 ); \
	_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 72 ) ) ), d0 ); \
	/* Last piece: final 8 floats from the low half of c0. */ \
	last_piece = _mm512_castps512_ps256( c0 ); \
	_mm256_mask_storeu_ps \
	( \
	  pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 88 ) ) ), \
	  _cvtu32_mask16( 0xFFFF), \
	  last_piece \
	);
|
||||
|
||||
// Row Major Packing in blocks of MRxKC
|
||||
void packa_f32f32f32of32_row_major_avx512
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t lda,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
)
|
||||
{
|
||||
const dim_t MR = 6;
|
||||
const dim_t KR_NDIM = 16;
|
||||
|
||||
// Used for permuting the mm512i elements for use in vpdpbusd instruction.
|
||||
// These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3.
|
||||
// Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7.
|
||||
__m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB );
|
||||
__m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF );
|
||||
__m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB );
|
||||
__m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF );
|
||||
|
||||
// First half.
|
||||
__m512i selector3 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x2, 0x3, 0x9, 0x4, 0x5 ); // 64 elems
|
||||
__m512i selector4 = _mm512_setr_epi64( 0x8, 0x6, 0x7, 0x9, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
|
||||
__m512i selector5 = _mm512_setr_epi64( 0x0, 0x1, 0xA, 0x2, 0x3, 0xB, 0x4, 0x5 ); // 64 elems
|
||||
__m512i selector6 = _mm512_setr_epi64( 0xA, 0x6, 0x7, 0xB, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
|
||||
|
||||
// Second half.
|
||||
__m512i selector7 = _mm512_setr_epi64( 0x0, 0x1, 0xC, 0x2, 0x3, 0xD, 0x4, 0x5 ); // 64 elems
|
||||
__m512i selector8 = _mm512_setr_epi64( 0xC, 0x6, 0x7, 0xD, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
|
||||
__m512i selector9 = _mm512_setr_epi64( 0x0, 0x1, 0xE, 0x2, 0x3, 0xF, 0x4, 0x5 ); // 64 elems
|
||||
__m512i selector10 = _mm512_setr_epi64( 0xE, 0x6, 0x7, 0xF, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
|
||||
|
||||
dim_t m_full_pieces = MC / MR;
|
||||
dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
|
||||
dim_t m_partial_pieces = MC % MR;
|
||||
|
||||
dim_t kr_full_pieces = KC / KR_NDIM;
|
||||
dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
|
||||
dim_t kr_partial_pieces = KC % KR_NDIM;
|
||||
|
||||
__m512 a0;
|
||||
__m512 b0;
|
||||
__m512 c0;
|
||||
__m512 d0;
|
||||
__m512 e0;
|
||||
__m512 f0;
|
||||
__m512 a01;
|
||||
__m512 c01;
|
||||
__m512 e01;
|
||||
__m256 last_piece;
|
||||
|
||||
__mmask16 mmask[6];
|
||||
for ( dim_t ic = 0; ic < MC; ic += MR )
|
||||
{
|
||||
if ( ic == m_full_pieces_loop_limit )
|
||||
{
|
||||
for ( int ii = 0; ii < m_partial_pieces; ++ii )
|
||||
{
|
||||
mmask[ii] = _cvtu32_mask16( 0xFFFF );
|
||||
}
|
||||
for ( int ii = m_partial_pieces; ii < MR; ++ii )
|
||||
{
|
||||
mmask[ii] = _cvtu32_mask16( 0x0 );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for ( int ii = 0; ii < MR; ++ii )
|
||||
{
|
||||
mmask[ii] = _cvtu32_mask16( 0xFFFF );
|
||||
}
|
||||
}
|
||||
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
|
||||
{
|
||||
a0 = _mm512_maskz_loadu_ps( mmask[0], a + ( lda * ( ic + 0 ) ) + kr );
|
||||
b0 = _mm512_maskz_loadu_ps( mmask[1], a + ( lda * ( ic + 1 ) ) + kr );
|
||||
c0 = _mm512_maskz_loadu_ps( mmask[2], a + ( lda * ( ic + 2 ) ) + kr );
|
||||
d0 = _mm512_maskz_loadu_ps( mmask[3], a + ( lda * ( ic + 3 ) ) + kr );
|
||||
e0 = _mm512_maskz_loadu_ps( mmask[4], a + ( lda * ( ic + 4 ) ) + kr );
|
||||
f0 = _mm512_maskz_loadu_ps( mmask[5], a + ( lda * ( ic + 5 ) ) + kr );
|
||||
|
||||
F32_ROW_MAJOR_K_PACK_LOOP(pack_a_buf, KC, kr);
|
||||
}
|
||||
if ( kr_partial_pieces > 0 )
|
||||
{
|
||||
err_t r_val;
|
||||
size_t temp_size = MR * KR_NDIM * sizeof( float );
|
||||
float* temp_pack_a_buf = bli_malloc_user( temp_size, &r_val );
|
||||
|
||||
__mmask16 lmask = _cvtu32_mask16( 0xFFFF >> ( 16 - kr_partial_pieces ) );
|
||||
for ( int ii = 0; ii < MR; ++ii )
|
||||
{
|
||||
mmask[ii] = _mm512_kand( mmask[ii], lmask );
|
||||
}
|
||||
|
||||
a0 = _mm512_maskz_loadu_ps( mmask[0], a + ( lda * ( ic + 0 ) ) +
|
||||
kr_full_pieces_loop_limit );
|
||||
b0 = _mm512_maskz_loadu_ps( mmask[1], a + ( lda * ( ic + 1 ) ) +
|
||||
kr_full_pieces_loop_limit );
|
||||
c0 = _mm512_maskz_loadu_ps( mmask[2], a + ( lda * ( ic + 2 ) ) +
|
||||
kr_full_pieces_loop_limit );
|
||||
d0 = _mm512_maskz_loadu_ps( mmask[3], a + ( lda * ( ic + 3 ) ) +
|
||||
kr_full_pieces_loop_limit );
|
||||
e0 = _mm512_maskz_loadu_ps( mmask[4], a + ( lda * ( ic + 4 ) ) +
|
||||
kr_full_pieces_loop_limit );
|
||||
f0 = _mm512_maskz_loadu_ps( mmask[5], a + ( lda * ( ic + 5 ) ) +
|
||||
kr_full_pieces_loop_limit );
|
||||
|
||||
F32_ROW_MAJOR_K_PACK_LOOP(temp_pack_a_buf, 0, 0);
|
||||
|
||||
memcpy
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( kr_full_pieces_loop_limit * MR ),
|
||||
temp_pack_a_buf,
|
||||
kr_partial_pieces * MR * sizeof( float )
|
||||
);
|
||||
|
||||
bli_free_user( temp_pack_a_buf );
|
||||
}
|
||||
}
|
||||
|
||||
*rs_a = 1;
|
||||
*cs_a = 6;
|
||||
}
|
||||
|
||||
void packa_f32f32f32of32_col_major_avx512
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t lda,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
)
|
||||
{
|
||||
const dim_t MR = 6;
|
||||
const dim_t KR_NDIM = 16;
|
||||
|
||||
dim_t m_full_pieces = MC / MR;
|
||||
dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
|
||||
dim_t m_partial_pieces = MC % MR;
|
||||
|
||||
dim_t kr_full_pieces = KC / KR_NDIM;
|
||||
dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
|
||||
dim_t kr_partial_pieces = KC % KR_NDIM;
|
||||
|
||||
__m256 a0;
|
||||
__m256 b0;
|
||||
__m256 c0;
|
||||
__m256 d0;
|
||||
__m256 e0;
|
||||
__m256 f0;
|
||||
__m256 g0;
|
||||
__m256 h0;
|
||||
__m256 i0;
|
||||
__m256 j0;
|
||||
__m256 k0;
|
||||
__m256 l0;
|
||||
__m256 m0;
|
||||
__m256 n0;
|
||||
__m256 o0;
|
||||
__m256 p0;
|
||||
|
||||
__mmask16 mmask[16];
|
||||
|
||||
for ( dim_t ic = 0; ic < MC; ic += MR )
|
||||
{
|
||||
if ( ic == m_full_pieces_loop_limit )
|
||||
{
|
||||
for ( int ii = 0; ii < 16; ++ii )
|
||||
{
|
||||
mmask[ii] = _cvtu32_mask16( 0x3F >> ( MR - m_partial_pieces ) );
|
||||
}
|
||||
}
|
||||
/* Inside the kr loop, the mmask is modified. Need to reset it
|
||||
* at beginning of each ic loop iteration. */
|
||||
else
|
||||
{
|
||||
for ( int ii = 0; ii < 16; ++ii )
|
||||
{
|
||||
mmask[ii] = _cvtu32_mask16( 0x3F );
|
||||
}
|
||||
}
|
||||
for ( dim_t kr = 0; kr < KC; kr += KR_NDIM )
|
||||
{
|
||||
if ( kr == kr_full_pieces_loop_limit )
|
||||
{
|
||||
for ( int ii = kr_partial_pieces; ii < 16; ++ii )
|
||||
{
|
||||
mmask[ii] = _cvtu32_mask16( 0x0 );
|
||||
}
|
||||
}
|
||||
a0 = _mm256_maskz_loadu_ps( mmask[0], a + ic + ( lda * ( kr + 0 ) ) );
|
||||
b0 = _mm256_maskz_loadu_ps( mmask[1], a + ic + ( lda * ( kr + 1 ) ) );
|
||||
c0 = _mm256_maskz_loadu_ps( mmask[2], a + ic + ( lda * ( kr + 2 ) ) );
|
||||
d0 = _mm256_maskz_loadu_ps( mmask[3], a + ic + ( lda * ( kr + 3 ) ) );
|
||||
e0 = _mm256_maskz_loadu_ps( mmask[4], a + ic + ( lda * ( kr + 4 ) ) );
|
||||
f0 = _mm256_maskz_loadu_ps( mmask[5], a + ic + ( lda * ( kr + 5 ) ) );
|
||||
g0 = _mm256_maskz_loadu_ps( mmask[6], a + ic + ( lda * ( kr + 6 ) ) );
|
||||
h0 = _mm256_maskz_loadu_ps( mmask[7], a + ic + ( lda * ( kr + 7 ) ) );
|
||||
i0 = _mm256_maskz_loadu_ps( mmask[8], a + ic + ( lda * ( kr + 8 ) ) );
|
||||
j0 = _mm256_maskz_loadu_ps( mmask[9], a + ic + ( lda * ( kr + 9 ) ) );
|
||||
k0 = _mm256_maskz_loadu_ps( mmask[10], a + ic + ( lda * ( kr + 10 ) ) );
|
||||
l0 = _mm256_maskz_loadu_ps( mmask[11], a + ic + ( lda * ( kr + 11 ) ) );
|
||||
m0 = _mm256_maskz_loadu_ps( mmask[12], a + ic + ( lda * ( kr + 12 ) ) );
|
||||
n0 = _mm256_maskz_loadu_ps( mmask[13], a + ic + ( lda * ( kr + 13 ) ) );
|
||||
o0 = _mm256_maskz_loadu_ps( mmask[14], a + ic + ( lda * ( kr + 14 ) ) );
|
||||
p0 = _mm256_maskz_loadu_ps( mmask[15], a + ic + ( lda * ( kr + 15 ) ) );
|
||||
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ),
|
||||
mmask[0], a0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 1 ) ),
|
||||
mmask[1], b0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 2 ) ),
|
||||
mmask[2], c0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 3 ) ),
|
||||
mmask[3], d0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 4 ) ),
|
||||
mmask[4], e0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 5 ) ),
|
||||
mmask[5], f0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 6 ) ),
|
||||
mmask[6], g0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 7 ) ),
|
||||
mmask[7], h0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 8 ) ),
|
||||
mmask[8], i0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 9 ) ),
|
||||
mmask[9], j0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 10 ) ),
|
||||
mmask[10], k0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 11 ) ),
|
||||
mmask[11], l0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 12 ) ),
|
||||
mmask[12], m0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 13 ) ),
|
||||
mmask[13], n0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 14 ) ),
|
||||
mmask[14], o0
|
||||
);
|
||||
_mm256_mask_storeu_ps
|
||||
(
|
||||
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 15 ) ),
|
||||
mmask[15], p0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void packa_mr6_f32f32f32of32_avx512
|
||||
(
|
||||
float* pack_a_buf,
|
||||
const float* a,
|
||||
const dim_t rs,
|
||||
const dim_t cs,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
)
|
||||
{
|
||||
if( cs == 1 )
|
||||
{
|
||||
packa_f32f32f32of32_row_major_avx512
|
||||
( pack_a_buf, a, rs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
else
|
||||
{
|
||||
packa_f32f32f32of32_col_major_avx512
|
||||
( pack_a_buf, a, cs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -115,16 +115,16 @@ void packa_u8s8s32os32
|
||||
dim_t* cs_a
|
||||
)
|
||||
{
|
||||
if( cs == 1 )
|
||||
{
|
||||
packa_k64_u8s8s32o32
|
||||
( pack_a_buffer_u8s8s32o32, a, rs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
else
|
||||
if ( rs == 1 )
|
||||
{
|
||||
packa_mr16_u8s8s32o32_col_major
|
||||
( pack_a_buffer_u8s8s32o32, a, rs, cs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
else
|
||||
{
|
||||
packa_k64_u8s8s32o32
|
||||
( pack_a_buffer_u8s8s32o32, a, rs, MC, KC, rs_a, cs_a );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user