New A packing kernels for F32 API in LPGEMM.

-New packing kernels for the A matrix, based on both the AVX512 and AVX2
ISAs and supporting both row- and column-major storage, are added as part
of this change. The dependency on the Haswell A packing kernels is removed
by this change.
-Tiny GEMM thresholds are further tuned for BF16 and F32 APIs.

AMD-Internal: [SWLCSG-3380, SWLCSG-3415]

Change-Id: I7330defacbacc9d07037ce1baf4a441f941e59be
This commit is contained in:
Mithun Mohan
2025-02-21 14:19:37 +00:00
parent 8a69141294
commit 7394aafd1e
13 changed files with 1106 additions and 96 deletions

View File

@@ -56,11 +56,19 @@ static inline bool is_tiny_input_bf16obf16
const dim_t NC = lcntx->blksz.NC;
const dim_t MC = lcntx->blksz.MC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MR = lcntx->blksz.MR;
const dim_t NR = lcntx->blksz.NR;
dim_t mnk = m * n * k;
const dim_t mnk_magic_num = 36 * 128 * 128;
const dim_t m_thresh = 6 * MR;
const dim_t n_thresh = 2 * NR;
const dim_t k_thresh = 1024;
// Need to explicitly check for MC, NC boundaries for safety.
if ( ( k < 256 ) && ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
if ( ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
( mnk < mnk_magic_num ) ) )
{
is_tiny = TRUE;
}

View File

@@ -56,11 +56,19 @@ static inline bool is_tiny_input_bf16of32
const dim_t NC = lcntx->blksz.NC;
const dim_t MC = lcntx->blksz.MC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MR = lcntx->blksz.MR;
const dim_t NR = lcntx->blksz.NR;
dim_t mnk = m * n * k;
const dim_t mnk_magic_num = 36 * 128 * 128;
const dim_t m_thresh = 6 * MR;
const dim_t n_thresh = 2 * NR;
const dim_t k_thresh = 1024;
// Need to explicitly check for MC, NC boundaries for safety.
if ( ( k < 256 ) && ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
if ( ( m <= MC ) && ( n < NC ) && ( k < KC ) &&
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
( mnk < mnk_magic_num ) ) )
{
is_tiny = TRUE;
}

View File

@@ -55,11 +55,20 @@ static inline bool is_tiny_input_f32
const dim_t NC = lcntx->blksz.NC;
const dim_t MC = lcntx->blksz.MC;
const dim_t KC = lcntx->blksz.KC;
const dim_t MR = lcntx->blksz.MR;
const dim_t NR = lcntx->blksz.NR;
dim_t mnk = m * n * k;
const dim_t mnk_magic_num = 12 * 64 * 496;
const dim_t m_thresh = 6 * MR;
const dim_t n_thresh = 2 * NR;
const dim_t k_thresh = 480;
// Need to explicitly check for MC, NC boundaries for safety.
if ( ( k < 128 ) && ( m <= MC ) && ( n < NC ) &&
( ( ( m <= 36 ) && ( n <= 64 ) ) ||
( ( m <= 12 ) && ( n <= 128 ) ) ) )
if ( ( k < KC ) && ( m <= MC ) && ( n < NC ) &&
( ( m <= m_thresh ) && ( n <= n_thresh ) && ( k <= k_thresh ) &&
( mnk < mnk_magic_num ) ) )
{
is_tiny = TRUE;
}

View File

@@ -245,6 +245,7 @@ static void _lpgemm_cntx_init_func_map()
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2;
}
}
@@ -260,6 +261,7 @@ static void _lpgemm_cntx_init_func_map()
if ( global_lpgemm_enable_arch == BLIS_ARCH_ZEN3 )
{
LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2
LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2;
}
}

View File

@@ -59,10 +59,14 @@
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx512) \
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
#define LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
@@ -110,10 +114,14 @@
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx512) \
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
#define LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
@@ -147,10 +155,14 @@
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx512) \
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
#define LPGEMM_PACKA_FUNC_UPD_MAP_AVX512_TO_AVX2 \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
#define LPGEMM_PACKB_FUNC_MAP_AVX512 \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
@@ -160,6 +172,9 @@
PBMACRO(BF16S4F32OF32, NULL) \
PBSMACRO(BF16S4F32OF32, NULL)
#define LPGEMM_PACKB_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512 \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
@@ -180,6 +195,7 @@
#define LPGEMM_PACKA_FUNC_MAP_AVX2 \
PAMACRO(U8S8S32OS32, NULL) \
PAMACRO(BF16BF16F32OF32, NULL) \
PAMACRO(F32F32F32OF32, packa_mr6_f32f32f32of32_avx2) \
KMACRO(BF16S4F32OF32, NULL) \
PAMACRO(S8S8S32OS32, NULL) \

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -62,19 +62,6 @@ typedef void (*lpgemm_rowvar_f32)
lpgemm_post_op_attr
);
void lpgemm_pack_a_f32f32f32of32
(
const float* input_buf_addr_a,
float* reorder_buf_addr_a,
const dim_t m,
const dim_t k,
const dim_t rs_a,
const dim_t cs_a,
const dim_t ps_p,
const dim_t MR,
cntx_t* cntx
);
#ifdef BLIS_KERNELS_ZEN4
LPGEMV(float, float, float, f32f32f32of32)
{
@@ -352,8 +339,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
return;
}
#endif
// Query the global cntx.
cntx_t* cntx = bli_gks_query_cntx();
// Query the context for various blocksizes.
const dim_t NC = lcntx->blksz.NC;
@@ -385,6 +370,7 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
auxinfo_t aux;
// Check if packing of A is required.
// TODO: mtag_a for tranpose needs to be honored.
bool should_pack_A = bli_rntm_pack_a( rntm );
// Pack buffer for A.
@@ -594,13 +580,13 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
cs_a_use = MR;
ps_a_use = MR * kc0;
lpgemm_pack_a_f32f32f32of32
( ( lpgemm_pack_f32 )lcntx->packa_fun_ptr )
(
( a + ( rs_a * ic ) + ( pc * cs_a) ),
pack_a_buffer_f32f32f32of32,
( a + ( rs_a * ic ) + ( pc * cs_a) ),
rs_a, cs_a,
mc0, kc0,
rs_a, cs_a, ps_a_use, MR,
cntx
&rs_a_use, &cs_a_use
);
a_use = pack_a_buffer_f32f32f32of32;
@@ -671,58 +657,3 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
}
}
}
// Packs an m x k block of A (strides rs_a/cs_a) into reorder_buf_addr_a as
// column-stored row panels of height MR, delegating the per-panel copy to the
// BLIS framework kernel packm_cxk. ps_p is the stride (in elements) between
// consecutive packed panels. Superseded by the new packa_mr6 kernels.
void lpgemm_pack_a_f32f32f32of32
(
const float* input_buf_addr_a,
float* reorder_buf_addr_a,
const dim_t m,
const dim_t k,
const dim_t rs_a,
const dim_t cs_a,
const dim_t ps_p,
const dim_t MR,
cntx_t* cntx
)
{
// Packing scales by kappa; use 1.0 so values are copied unchanged.
float one_local = *PASTEMAC(s,1);
float* restrict kappa_cast = &one_local;
// Set the schema to "column stored row panels" to indicate packing to conventional
// column-stored row panels.
pack_t schema = BLIS_PACKED_ROW_PANELS;
trans_t transc = BLIS_NO_TRANSPOSE;
conj_t conjc = bli_extract_conj( transc );
// Compute the total number of iterations (number of MR-high panels,
// rounding up to cover the fringe) we'll need.
dim_t m_iter = ( m + MR - 1 ) / MR;
inc_t cs_p = MR;
float* p_temp = reorder_buf_addr_a;
dim_t ir, it;
// Iterate over every logical micropanel in the source matrix.
for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 )
{
// Fringe panel may be shorter than MR; packm_cxk zero-pads it.
dim_t panel_dim_i = bli_min( MR, m - ir );
const float* a_use = input_buf_addr_a + ( ir * rs_a );
float* p_use = p_temp;
PASTEMAC(s,packm_cxk)
(
conjc,
schema,
panel_dim_i,
MR,
k,
k,
kappa_cast,
( float* )a_use, rs_a, cs_a,
p_use, cs_p,
cntx
);
p_temp += ps_p;
}
}

View File

@@ -250,7 +250,7 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
);
get_packa_strides_mfringe_u8s8s32os32
(
&rs_a_use, &cs_a_use, gemm_MR, 1
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
);
a_use = pack_a_buffer_s8s8s32os32;

View File

@@ -231,7 +231,7 @@ LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32)
);
get_packa_strides_mfringe_u8s8s32os32
(
&rs_a_use, &cs_a_use, gemm_MR, 1
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
);
a_use = pack_a_buffer;

View File

@@ -46,6 +46,30 @@ void packa_mr16_f32f32f32of32_col_major
dim_t* cs_p
);
// AVX512 variant: packs an MC x KC block of A (input strides rs/cs) into
// pack_a_buf with MR = 6 interleaving; the strides of the packed buffer are
// returned through rs_a/cs_a.
void packa_mr6_f32f32f32of32_avx512
(
float* pack_a_buf,
const float* a,
const dim_t rs,
const dim_t cs,
const dim_t MC,
const dim_t KC,
dim_t* rs_a,
dim_t* cs_a
);
// AVX2 variant of the MR = 6 F32 A packing kernel; same contract as the
// AVX512 variant.
void packa_mr6_f32f32f32of32_avx2
(
float* pack_a_buf,
const float* a,
const dim_t rs,
const dim_t cs,
const dim_t MC,
const dim_t KC,
dim_t* rs_a,
dim_t* cs_a
);
typedef void (*lpgemm_pack_f32)
(
float*,

View File

@@ -39,14 +39,20 @@
// for different schemas used to pack A fringe cases.
// Computes the strides to be used on the packed A buffer when an m fringe
// ( m_fringe < MR ) micro-kernel is invoked. Reads the original matrix
// strides rs/cs and adjusts *rs_use/*cs_use in place; for non row major
// inputs the strides passed in are left untouched.
BLIS_INLINE void get_packa_strides_mfringe_u8s8s32os32
(
const dim_t rs,
const dim_t cs,
dim_t* rs_use,
dim_t* cs_use,
dim_t MR,
dim_t m_fringe
)
{
// Only applicable for row major packing ( rs != 1, cs == 1 ), and only
// when the packed column stride still reflects a full MR-high panel.
if ( ( rs != 1 ) && ( cs == 1 ) && ( ( *cs_use ) > MR ) )
{
// NOTE(review): 4 appears to be the k-group width of the s32 packing
// schema -- confirm against the u8s8s32 pack kernels.
( *rs_use ) = 4;
( *cs_use ) = ( ( *cs_use ) / MR ) * m_fringe;
}
}
typedef void (*packa_s32)

View File

@@ -0,0 +1,602 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include <string.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
#define F32_ROW_MAJOR_K_PACK_LOOP_AVX2() \
a0 = _mm256_unpacklo_ps( a01, b0 ); \
b0 = _mm256_unpackhi_ps( a01, b0 ); \
\
c0 = _mm256_unpacklo_ps( c01, d0 ); \
d0 = _mm256_unpackhi_ps( c01, d0 ); \
\
e0 = _mm256_unpacklo_ps( e01, f0 ); \
f0 = _mm256_unpackhi_ps( e01, f0 ); \
\
a01 = _mm256_castpd_ps( _mm256_unpacklo_pd( _mm256_castps_pd( a0 ), \
_mm256_castps_pd( c0 ) ) ); \
a0 = _mm256_castpd_ps( _mm256_unpackhi_pd( _mm256_castps_pd( a0 ), \
_mm256_castps_pd( c0 ) ) ); \
\
c01 = _mm256_castpd_ps( _mm256_unpacklo_pd( _mm256_castps_pd( b0 ), \
_mm256_castps_pd( d0 ) ) ); \
c0 = _mm256_castpd_ps( _mm256_unpackhi_pd( _mm256_castps_pd( b0 ), \
_mm256_castps_pd( d0 ) ) ); \
\
a0_128 = _mm256_castps256_ps128( a01 ); \
b0_128 = _mm256_castps256_ps128( a0 ); \
c0_128 = _mm256_castps256_ps128( c01 ); \
d0_128 = _mm256_castps256_ps128( c0 ); \
e0_128 = _mm256_castps256_ps128( e0 ); \
f0_128 = _mm256_castps256_ps128( f0 ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 0, a0_128 ); \
_mm_storel_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 4 ), \
_mm_castps_pd( e0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 6, b0_128 ); \
_mm_storeh_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 10 ), \
_mm_castps_pd( e0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 12, c0_128 ); \
_mm_storel_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 16 ), \
_mm_castps_pd( f0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 18, d0_128 ); \
_mm_storeh_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 22 ), \
_mm_castps_pd( f0_128 ) ); \
\
a0_128 = _mm256_extractf128_ps( a01, 0x1 ); \
b0_128 = _mm256_extractf128_ps( a0, 0x1 ); \
c0_128 = _mm256_extractf128_ps( c01, 0x1 ); \
d0_128 = _mm256_extractf128_ps( c0, 0x1 ); \
e0_128 = _mm256_extractf128_ps( e0, 0x1 ); \
f0_128 = _mm256_extractf128_ps( f0, 0x1 ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 24, a0_128 ); \
_mm_storel_pd( ( double* )( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 28 ), \
_mm_castps_pd( e0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 30, b0_128 ); \
_mm_storeh_pd( ( double*)(pack_a_buf + ( ic * KC ) + ( kr * MR ) + 34 ), \
_mm_castps_pd( e0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 36, c0_128 ); \
_mm_storel_pd( ( double*)(pack_a_buf + ( ic * KC ) + ( kr * MR ) + 40 ), \
_mm_castps_pd( f0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 42, d0_128 ); \
_mm_storeh_pd( ( double*)(pack_a_buf + ( ic * KC ) + ( kr * MR ) + 46 ), \
_mm_castps_pd( f0_128 ) ); \
// 128-bit variant of the row major pack loop: transposes one MR x 4 (6 rows x
// 4 k-values) tile held in a01_128,b0_128,c01_128,d0_128,e01_128,f0_128 into
// the pack buffer (24 floats). Used for the k fringe when 4 <= k-remainder < 8.
#define F32_ROW_MAJOR_K_PACK_LOOP_SSE() \
a0_128 = _mm_unpacklo_ps( a01_128, b0_128 ); \
b0_128 = _mm_unpackhi_ps( a01_128, b0_128 ); \
\
c0_128 = _mm_unpacklo_ps( c01_128, d0_128 ); \
d0_128 = _mm_unpackhi_ps( c01_128, d0_128 ); \
\
e0_128 = _mm_unpacklo_ps( e01_128, f0_128 ); \
f0_128 = _mm_unpackhi_ps( e01_128, f0_128 ); \
\
a01_128 = _mm_castpd_ps( _mm_unpacklo_pd( _mm_castps_pd( a0_128 ), \
_mm_castps_pd( c0_128 ) ) ); \
a0_128 = _mm_castpd_ps( _mm_unpackhi_pd( _mm_castps_pd( a0_128 ), \
_mm_castps_pd( c0_128 ) ) ); \
\
c01_128 = _mm_castpd_ps( _mm_unpacklo_pd( _mm_castps_pd( b0_128 ), \
_mm_castps_pd( d0_128 ) ) ); \
c0_128 = _mm_castpd_ps( _mm_unpackhi_pd( _mm_castps_pd( b0_128 ), \
_mm_castps_pd( d0_128 ) ) ); \
\
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 0, a01_128 ); \
_mm_storel_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 4 ), \
_mm_castps_pd( e0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 6, a0_128 ); \
_mm_storeh_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 10 ), \
_mm_castps_pd( e0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 12, c01_128 ); \
_mm_storel_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 16 ), \
_mm_castps_pd( f0_128 ) ); \
_mm_storeu_ps( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 18, c0_128 ); \
_mm_storeh_pd( ( double*)( pack_a_buf + ( ic * KC ) + ( kr * MR ) + 22 ), \
_mm_castps_pd( f0_128 ) ); \
// Row Major Packing in blocks of MRxKC.
// Packs an MC x KC block of a row-stored A (lda = row stride) into
// pack_a_buf so that element ( i, k ) lands at
// pack_a_buf[ ( i / MR ) * MR * KC + k * MR + ( i % MR ) ], i.e. column
// stored row panels with MR = 6. Fringe rows are zero padded (SIMD paths)
// or copied element-wise (scalar tail). On exit *rs_a = 1, *cs_a = 6.
void packa_f32f32f32of32_row_major_avx2
(
float* pack_a_buf,
const float* a,
const dim_t lda,
const dim_t MC,
const dim_t KC,
dim_t* rs_a,
dim_t* cs_a
)
{
const dim_t MR = 6;
const dim_t KR_NDIM = 8;
// Split m into full MR-high panels plus one fringe panel, and k into full
// 8-wide chunks plus a fringe handled 4-at-a-time and then scalar.
dim_t m_full_pieces = MC / MR;
dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
dim_t m_partial_pieces = MC % MR;
dim_t kr_full_pieces = KC / KR_NDIM;
dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
dim_t kr_partial_pieces = KC % KR_NDIM;
__m256 a0;
__m256 b0;
__m256 c0;
__m256 d0;
__m256 e0;
__m256 f0;
__m256 a01;
__m256 c01;
__m256 e01;
__m128 a0_128;
__m128 b0_128;
__m128 c0_128;
__m128 d0_128;
__m128 e0_128;
__m128 f0_128;
// Full MR-high panels.
for ( dim_t ic = 0; ic < m_full_pieces_loop_limit; ic += MR )
{
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
{
// Load 8 k-elements from each of the 6 rows and transpose-store.
a01 = _mm256_loadu_ps( a + ( lda * ( ic + 0 ) ) + kr );
b0 = _mm256_loadu_ps( a + ( lda * ( ic + 1 ) ) + kr );
c01 = _mm256_loadu_ps( a + ( lda * ( ic + 2 ) ) + kr );
d0 = _mm256_loadu_ps( a + ( lda * ( ic + 3 ) ) + kr );
e01 = _mm256_loadu_ps( a + ( lda * ( ic + 4 ) ) + kr );
f0 = _mm256_loadu_ps( a + ( lda * ( ic + 5 ) ) + kr );
F32_ROW_MAJOR_K_PACK_LOOP_AVX2();
}
if ( kr_partial_pieces > 0 )
{
// k fringe: a 4-wide SSE chunk first, then scalar copies.
dim_t kr_partial_pieces_4 = ( kr_partial_pieces / 4 ) * 4;
dim_t kr_partial_pieces_rem = kr_partial_pieces - kr_partial_pieces_4;
dim_t kr = kr_full_pieces_loop_limit;
if ( kr_partial_pieces_4 > 0 )
{
__m128 a01_128 = _mm_loadu_ps( a + ( lda * ( ic + 0 ) ) + kr );
b0_128 = _mm_loadu_ps( a + ( lda * ( ic + 1 ) ) + kr );
__m128 c01_128 = _mm_loadu_ps( a + ( lda * ( ic + 2 ) ) + kr );
d0_128 = _mm_loadu_ps( a + ( lda * ( ic + 3 ) ) + kr );
__m128 e01_128 = _mm_loadu_ps( a + ( lda * ( ic + 4 ) ) + kr );
f0_128 = _mm_loadu_ps( a + ( lda * ( ic + 5 ) ) + kr );
F32_ROW_MAJOR_K_PACK_LOOP_SSE();
}
kr += kr_partial_pieces_4;
if ( kr_partial_pieces_rem > 0 )
{
// Scalar copy of the last ( k % 4 ) columns for all 6 rows.
for ( int ii = 0; ii < kr_partial_pieces_rem; ++ii )
{
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 0 ) =
*( a + ( lda * ( ic + 0 ) ) + ( kr + ii ) );
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 1 ) =
*( a + ( lda * ( ic + 1 ) ) + ( kr + ii ) );
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 2 ) =
*( a + ( lda * ( ic + 2 ) ) + ( kr + ii ) );
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 3 ) =
*( a + ( lda * ( ic + 3 ) ) + ( kr + ii ) );
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 4 ) =
*( a + ( lda * ( ic + 4 ) ) + ( kr + ii ) );
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + 5 ) =
*( a + ( lda * ( ic + 5 ) ) + ( kr + ii ) );
}
}
}
}
// m fringe panel ( MC % MR rows ): missing rows are zero padded so the
// same transpose macros can be reused.
if ( m_partial_pieces > 0 )
{
dim_t ic = m_full_pieces_loop_limit;
__m256 temp_a_reg[6];
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
{
for ( int ii = 0; ii < m_partial_pieces; ++ii )
{
temp_a_reg[ii] = _mm256_loadu_ps( a + ( lda * ( ic + ii ) ) + kr );
}
for ( int ii = m_partial_pieces; ii < MR; ++ii )
{
temp_a_reg[ii] = _mm256_setzero_ps();
}
a01 = temp_a_reg[0];
b0 = temp_a_reg[1];
c01 = temp_a_reg[2];
d0 = temp_a_reg[3];
e01 = temp_a_reg[4];
f0 = temp_a_reg[5];
F32_ROW_MAJOR_K_PACK_LOOP_AVX2();
}
if ( kr_partial_pieces > 0 )
{
dim_t kr_partial_pieces_4 = ( kr_partial_pieces / 4 ) * 4;
dim_t kr_partial_pieces_rem = kr_partial_pieces - kr_partial_pieces_4;
dim_t kr = kr_full_pieces_loop_limit;
if ( kr_partial_pieces_4 > 0 )
{
__m128 temp_a_reg_128[6] = {0};
for ( int ii = 0; ii < m_partial_pieces; ++ii )
{
temp_a_reg_128[ii] = _mm_loadu_ps( a + ( lda * ( ic + ii ) ) + kr );
}
for ( int ii = m_partial_pieces; ii < MR; ++ii )
{
temp_a_reg_128[ii] = _mm_setzero_ps();
}
__m128 a01_128 = temp_a_reg_128[0];
b0_128 = temp_a_reg_128[1];
__m128 c01_128 = temp_a_reg_128[2];
d0_128 = temp_a_reg_128[3];
__m128 e01_128 = temp_a_reg_128[4];
f0_128 = temp_a_reg_128[5];
F32_ROW_MAJOR_K_PACK_LOOP_SSE();
}
kr += kr_partial_pieces_4;
if ( kr_partial_pieces_rem > 0 )
{
// Scalar tail: only the valid m_partial_pieces rows are written.
for ( int ii = 0; ii < kr_partial_pieces_rem; ++ii )
{
for ( int jj = 0; jj < m_partial_pieces; ++jj )
{
*( pack_a_buf + ( ic * KC ) + ( ( kr + ii ) * MR ) + jj ) =
*( a + ( lda * ( ic + jj ) ) + ( kr + ii ) );
}
}
}
}
}
// Strides of the packed buffer.
*rs_a = 1;
*cs_a = 6;
}
// Stores 8 packed k-columns of 6 row elements each: for k value j (0..7),
// registers {a0..h0}[j] hold rows 0..3 and {a0_2e..h0_2e}[j] hold rows 4..5
// in their low 64 bits. Writes 4 + 2 floats per k at offset kr * MR + MR * j.
// Relies on ic/kr/MR/KC/pack_a_buf from the invoking scope.
#define F32_COL_MAJOR_K_PACK_STORE_SSE() \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ), \
a0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) + 4 ) ), \
_mm_castps_pd( a0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 1 ) ), \
b0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 1 ) + 4 ) ), \
_mm_castps_pd( b0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 2 ) ), \
c0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 2 ) + 4 ) ), \
_mm_castps_pd( c0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 3 ) ), \
d0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 3 ) + 4 ) ), \
_mm_castps_pd( d0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 4 ) ), \
e0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 4 ) + 4 ) ), \
_mm_castps_pd( e0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 5 ) ), \
f0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 5 ) + 4 ) ), \
_mm_castps_pd( f0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 6 ) ), \
g0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 6 ) + 4 ) ), \
_mm_castps_pd( g0_2e ) \
); \
_mm_storeu_ps \
( \
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 7 ) ), \
h0 \
); \
_mm_store_sd \
( \
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 7 ) + 4 ) ), \
_mm_castps_pd( h0_2e ) \
); \
// Copies m_partial_pieces ( < MR ) floats from src to dst in 4/2/1 element
// chunks (m_partial_4/m_partial_2/m_partial_1 come from the invoking scope);
// cp_st ends as the number of elements copied. Remaining dst entries are left
// untouched -- the caller keeps the staging buffer zero initialized.
#define F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, dst, src) \
cp_st = 0; \
if ( m_partial_4 > 0 ) \
{ \
( dst )[cp_st + 0] = ( src )[cp_st + 0]; \
( dst )[cp_st + 1] = ( src )[cp_st + 1]; \
( dst )[cp_st + 2] = ( src )[cp_st + 2]; \
( dst )[cp_st + 3] = ( src )[cp_st + 3]; \
cp_st += 4; \
} \
if ( m_partial_2 > 0 ) \
{ \
( dst )[cp_st + 0] = ( src )[cp_st + 0]; \
( dst )[cp_st + 1] = ( src )[cp_st + 1]; \
cp_st += 2; \
} \
if ( m_partial_1 > 0 ) \
{ \
( dst )[cp_st + 0] = ( src )[cp_st + 0]; \
cp_st += 1; \
} \
// Column Major Packing in blocks of MRxKC.
// Packs an MC x KC block of a column-stored A (lda = column stride) into
// pack_a_buf in the same MR = 6 interleaved layout as the row major kernel:
// element ( i, k ) lands at
// pack_a_buf[ ( i / MR ) * MR * KC + k * MR + ( i % MR ) ]. Each packed
// column is written as 4 + 2 floats per k value. Fringe rows ( MC % MR ) are
// staged through a zero initialized temp buffer so full 6-element stores can
// be reused. On exit *rs_a = 1, *cs_a = 6.
void packa_f32f32f32of32_col_major_avx2
(
float* pack_a_buf,
const float* a,
const dim_t lda,
const dim_t MC,
const dim_t KC,
dim_t* rs_a,
dim_t* cs_a
)
{
const dim_t MR = 6;
const dim_t KR_NDIM = 8;
dim_t m_full_pieces = MC / MR;
dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
dim_t m_partial_pieces = MC % MR;
dim_t kr_full_pieces = KC / KR_NDIM;
dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
dim_t kr_partial_pieces = KC % KR_NDIM;
__m128 a0;
__m128 b0;
__m128 c0;
__m128 d0;
__m128 e0;
__m128 f0;
__m128 g0;
__m128 h0;
__m128 a0_2e;
__m128 b0_2e;
__m128 c0_2e;
__m128 d0_2e;
__m128 e0_2e;
__m128 f0_2e;
__m128 g0_2e;
__m128 h0_2e;
// Full MR-high panels: columns are already contiguous, so each k value is
// loaded as rows 0..3 plus rows 4..5 and stored directly.
for ( dim_t ic = 0; ic < m_full_pieces_loop_limit; ic += MR )
{
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
{
// First 4 elements.
a0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 0 ) ) );
b0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 1 ) ) );
c0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 2 ) ) );
d0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 3 ) ) );
e0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 4 ) ) );
f0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 5 ) ) );
g0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 6 ) ) );
h0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 7 ) ) );
// Last 2 elements.
a0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 0 ) ) ) ) );
b0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 1 ) ) ) ) );
c0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 2 ) ) ) ) );
d0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 3 ) ) ) ) );
e0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 4 ) ) ) ) );
f0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 5 ) ) ) ) );
g0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 6 ) ) ) ) );
h0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 7 ) ) ) ) );
F32_COL_MAJOR_K_PACK_STORE_SSE();
}
if ( kr_partial_pieces )
{
// k fringe: one column ( 4 + 2 floats ) per remaining k value.
for ( dim_t kr = kr_full_pieces_loop_limit; kr < KC; ++kr )
{
a0 = _mm_loadu_ps( a + ic + ( lda * ( kr + 0 ) ) );
a0_2e = _mm_castpd_ps( _mm_load_sd( ( double* )( a + ( ic + 4 ) + \
( lda * ( kr + 0 ) ) ) ) );
_mm_storeu_ps
(
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ),
a0
);
_mm_store_sd
(
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + \
( MR * 0 ) + 4 ) ),
_mm_castps_pd( a0_2e )
);
}
}
}
// m fringe panel: copy the valid rows into a zeroed 6-element staging
// buffer, then reuse the full-panel store path (zero padding the rest).
if ( m_partial_pieces > 0 )
{
dim_t ic = m_full_pieces_loop_limit;
dim_t m_partial_4 = ( m_partial_pieces / 4 ) * 4;
dim_t m_partial_2 = ( ( m_partial_pieces - m_partial_4 ) / 2 ) * 2;
dim_t m_partial_1 = m_partial_pieces - ( m_partial_4 + m_partial_2 );
dim_t cp_st = 0;
float temp_pack_a_buf[6] = { 0 };
for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
{
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 0 ) ) );
a0 = _mm_loadu_ps( temp_pack_a_buf );
a0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 1 ) ) );
b0 = _mm_loadu_ps( temp_pack_a_buf );
b0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 2 ) ) );
c0 = _mm_loadu_ps( temp_pack_a_buf );
c0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 3 ) ) );
d0 = _mm_loadu_ps( temp_pack_a_buf );
d0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 4 ) ) );
e0 = _mm_loadu_ps( temp_pack_a_buf );
e0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 5 ) ) );
f0 = _mm_loadu_ps( temp_pack_a_buf );
f0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 6 ) ) );
g0 = _mm_loadu_ps( temp_pack_a_buf );
g0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 7 ) ) );
h0 = _mm_loadu_ps( temp_pack_a_buf );
h0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
F32_COL_MAJOR_K_PACK_STORE_SSE();
}
if ( kr_partial_pieces )
{
for ( dim_t kr = kr_full_pieces_loop_limit; kr < KC; ++kr )
{
F32_COL_MAJOR_K_PACK_LOAD_MEMCPY(cp_st, temp_pack_a_buf, \
a + ic + ( lda * ( kr + 0 ) ) );
a0 = _mm_loadu_ps( temp_pack_a_buf );
a0_2e = _mm_castpd_ps( _mm_load_sd(
( double* )( temp_pack_a_buf + 4 ) ) );
_mm_storeu_ps
(
pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * 0 ) ),
a0
);
_mm_store_sd
(
( double* )( pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + \
( MR * 0 ) + 4 ) ),
_mm_castps_pd( a0_2e )
);
}
}
}
// Strides of the packed buffer: unit row stride within a panel, MR = 6
// elements between consecutive k values. The row major kernel sets these
// on exit; set them here too so callers see valid strides on both paths
// (previously they were left unset on the column major path).
*rs_a = 1;
*cs_a = 6;
}
// Entry point for AVX2 F32 A packing with MR = 6. Row-stored A (unit column
// stride) takes the row major path with the row stride as leading dimension;
// anything else takes the column major path with the column stride as leading
// dimension. Packed-buffer strides are returned through rs_a/cs_a.
void packa_mr6_f32f32f32of32_avx2
(
float* pack_a_buf,
const float* a,
const dim_t rs,
const dim_t cs,
const dim_t MC,
const dim_t KC,
dim_t* rs_a,
dim_t* cs_a
)
{
	// Column (or general) storage dispatch.
	if ( cs != 1 )
	{
		packa_f32f32f32of32_col_major_avx2
		( pack_a_buf, a, cs, MC, KC, rs_a, cs_a );
		return;
	}

	// Row storage dispatch.
	packa_f32f32f32of32_row_major_avx2
	( pack_a_buf, a, rs, MC, KC, rs_a, cs_a );
}
#endif

View File

@@ -759,4 +759,408 @@ void packa_mr16_f32f32f32of32_col_major
*rs_p = KC;
*cs_p = 1;
}
// AVX512 row major pack loop: transposes one MR x 16 (6 rows x 16 k-values)
// tile held in a0..f0 (rows 0..5) into the pack buffer so each k value's 6
// row elements are contiguous (96 floats total). Uses a01/c01/e01 and the
// selector* permute indices from the invoking scope as scratch.
// NOTE(review): the _mm256_mask_storeu_ps calls pass _cvtu32_mask16( 0xFFFF )
// for an 8-lane store -- only the low 8 mask bits apply; confirm intended.
#define F32_ROW_MAJOR_K_PACK_LOOP(pack_a_buf, KC, kr) \
a01 = _mm512_unpacklo_ps( a0, b0 ); \
a0 = _mm512_unpackhi_ps( a0, b0 ); \
\
c01 = _mm512_unpacklo_ps( c0, d0 ); \
c0 = _mm512_unpackhi_ps( c0, d0 ); \
\
e01 = _mm512_unpacklo_ps( e0, f0 ); /* Elem 4 */ \
e0 = _mm512_unpackhi_ps( e0, f0 ); /* Elem 5 */ \
\
b0 = _mm512_castpd_ps( _mm512_unpacklo_pd( _mm512_castps_pd( a01 ), \
_mm512_castps_pd( c01 ) ) ); \
a01 = _mm512_castpd_ps( _mm512_unpackhi_pd( _mm512_castps_pd( a01 ), \
_mm512_castps_pd( c01 ) ) ); \
\
d0 = _mm512_castpd_ps( _mm512_unpacklo_pd( _mm512_castps_pd( a0 ), \
_mm512_castps_pd( c0 ) ) ); \
c01 = _mm512_castpd_ps( _mm512_unpackhi_pd( _mm512_castps_pd( a0 ), \
_mm512_castps_pd( c0 ) ) ); \
\
a0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
selector1, _mm512_castps_pd( a01 ) ) ); \
c0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( d0 ), \
selector1, _mm512_castps_pd( c01 ) ) ); \
b0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
selector1_1, _mm512_castps_pd( a01 ) ) ); \
d0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( d0 ), \
selector1_1, _mm512_castps_pd( c01 ) ) ); \
\
a01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
selector2, _mm512_castps_pd( c0 ) ) ); /* a[0] */ \
c01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
selector2, _mm512_castps_pd( d0 ) ) ); /* a[2] */ \
a0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
selector2_1, _mm512_castps_pd( c0 ) ) ); /* a[1] */ \
c0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( b0 ), \
selector2_1, _mm512_castps_pd( d0 ) ) ); /* a[3] */ \
\
/* First half */ \
b0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a01 ), \
selector3, _mm512_castps_pd( e01 ) ) ); /* 1st 16 */ \
a01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a01 ), \
selector4, _mm512_castps_pd( e0 ) ) ); /* 1st 8 */ \
d0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
selector5, _mm512_castps_pd( e01 ) ) ); /* 2nd 16 */ \
a0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( a0 ), \
selector6, _mm512_castps_pd( e0 ) ) ); /* 2nd 4 */ \
\
_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 0 ) ) ), b0 ); \
_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 16 ) ) ) , a01 ); \
_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 24 ) ) ), d0 ); \
/* Last piece */ \
last_piece = _mm512_castps512_ps256( a0 ); \
_mm256_mask_storeu_ps \
( \
pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 40 ) ) ), \
_cvtu32_mask16( 0xFFFF), \
last_piece \
); \
\
/* Second half */ \
b0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c01 ), \
selector7, _mm512_castps_pd( e01 ) ) ); /* 3rd 16 */ \
c01 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c01 ), \
selector8, _mm512_castps_pd( e0 ) ) ); /* 3rd 8 */ \
d0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c0 ), \
selector9, _mm512_castps_pd( e01 ) ) ); /* 4th 16 */ \
c0 = _mm512_castpd_ps( _mm512_permutex2var_pd( _mm512_castps_pd( c0 ), \
selector10, _mm512_castps_pd( e0 ) ) ); /* 4th 8 */ \
\
_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 48 ) ) ), b0 ); \
_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 64 ) ) ) , c01 ); \
_mm512_storeu_ps( pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 72 ) ) ), d0 ); \
/* Last piece */ \
last_piece = _mm512_castps512_ps256( c0 ); \
_mm256_mask_storeu_ps \
( \
pack_a_buf + ( ( ic * KC ) + ( ( kr * MR ) + ( 88 ) ) ), \
_cvtu32_mask16( 0xFFFF), \
last_piece \
); \
// Row Major Packing in blocks of MRxKC
void packa_f32f32f32of32_row_major_avx512
     (
       float*       pack_a_buf,
       const float* a,
       const dim_t  lda,
       const dim_t  MC,
       const dim_t  KC,
       dim_t*       rs_a,
       dim_t*       cs_a
     )
{
	// Panel height (m dim) and k-unroll width of one packed tile.
	const dim_t MR = 6;
	const dim_t KR_NDIM = 16;
	// Permute indexes used while interleaving the MR row registers into
	// the packed MRxKR layout. Indexes are in pd (64b = 2 float) units;
	// e.g. 0x0,0x1,0x8,0x9,... interleaves pairs of floats from the two
	// source registers as a0-a1-b0-b1-a2-a3-b2-b3 etc.
	__m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB );
	__m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF );
	__m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB );
	__m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF );
	// First half.
	__m512i selector3 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x2, 0x3, 0x9, 0x4, 0x5 ); // 64 elems
	__m512i selector4 = _mm512_setr_epi64( 0x8, 0x6, 0x7, 0x9, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
	__m512i selector5 = _mm512_setr_epi64( 0x0, 0x1, 0xA, 0x2, 0x3, 0xB, 0x4, 0x5 ); // 64 elems
	__m512i selector6 = _mm512_setr_epi64( 0xA, 0x6, 0x7, 0xB, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
	// Second half.
	__m512i selector7 = _mm512_setr_epi64( 0x0, 0x1, 0xC, 0x2, 0x3, 0xD, 0x4, 0x5 ); // 64 elems
	__m512i selector8 = _mm512_setr_epi64( 0xC, 0x6, 0x7, 0xD, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
	__m512i selector9 = _mm512_setr_epi64( 0x0, 0x1, 0xE, 0x2, 0x3, 0xF, 0x4, 0x5 ); // 64 elems
	__m512i selector10 = _mm512_setr_epi64( 0xE, 0x6, 0x7, 0xF, 0x0, 0x0, 0x0, 0x0 ); // 32 elems
	dim_t m_full_pieces = MC / MR;
	dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
	dim_t m_partial_pieces = MC % MR;
	dim_t kr_full_pieces = KC / KR_NDIM;
	dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
	dim_t kr_partial_pieces = KC % KR_NDIM;
	__m512 a0;
	__m512 b0;
	__m512 c0;
	__m512 d0;
	__m512 e0;
	__m512 f0;
	__m512 a01;
	__m512 c01;
	__m512 e01;
	__m256 last_piece;
	__mmask16 mmask[6];
	for ( dim_t ic = 0; ic < MC; ic += MR )
	{
		// Zero the load masks of rows beyond MC on the trailing m panel
		// so the pack loop reads zeros instead of out-of-bounds memory.
		if ( ic == m_full_pieces_loop_limit )
		{
			for ( int ii = 0; ii < m_partial_pieces; ++ii )
			{
				mmask[ii] = _cvtu32_mask16( 0xFFFF );
			}
			for ( int ii = m_partial_pieces; ii < MR; ++ii )
			{
				mmask[ii] = _cvtu32_mask16( 0x0 );
			}
		}
		else
		{
			for ( int ii = 0; ii < MR; ++ii )
			{
				mmask[ii] = _cvtu32_mask16( 0xFFFF );
			}
		}
		for ( dim_t kr = 0; kr < kr_full_pieces_loop_limit; kr += KR_NDIM )
		{
			a0 = _mm512_maskz_loadu_ps( mmask[0], a + ( lda * ( ic + 0 ) ) + kr );
			b0 = _mm512_maskz_loadu_ps( mmask[1], a + ( lda * ( ic + 1 ) ) + kr );
			c0 = _mm512_maskz_loadu_ps( mmask[2], a + ( lda * ( ic + 2 ) ) + kr );
			d0 = _mm512_maskz_loadu_ps( mmask[3], a + ( lda * ( ic + 3 ) ) + kr );
			e0 = _mm512_maskz_loadu_ps( mmask[4], a + ( lda * ( ic + 4 ) ) + kr );
			f0 = _mm512_maskz_loadu_ps( mmask[5], a + ( lda * ( ic + 5 ) ) + kr );
			F32_ROW_MAJOR_K_PACK_LOOP(pack_a_buf, KC, kr);
		}
		if ( kr_partial_pieces > 0 )
		{
			// The pack loop always stores a full MR x KR_NDIM tile, so the
			// k remainder is staged through a scratch tile and only the
			// valid kr_partial_pieces * MR floats are copied out. The tile
			// has a fixed size of MR * KR_NDIM = 6 * 16 = 96 floats, so it
			// lives on the stack; this avoids a heap allocation (whose
			// failure was previously unchecked) on every MR panel.
			float temp_pack_a_buf[ 96 ];
			// Restrict each row's load mask to the valid k remainder.
			__mmask16 lmask = _cvtu32_mask16( 0xFFFF >> ( 16 - kr_partial_pieces ) );
			for ( int ii = 0; ii < MR; ++ii )
			{
				mmask[ii] = _mm512_kand( mmask[ii], lmask );
			}
			a0 = _mm512_maskz_loadu_ps( mmask[0], a + ( lda * ( ic + 0 ) ) +
							kr_full_pieces_loop_limit );
			b0 = _mm512_maskz_loadu_ps( mmask[1], a + ( lda * ( ic + 1 ) ) +
							kr_full_pieces_loop_limit );
			c0 = _mm512_maskz_loadu_ps( mmask[2], a + ( lda * ( ic + 2 ) ) +
							kr_full_pieces_loop_limit );
			d0 = _mm512_maskz_loadu_ps( mmask[3], a + ( lda * ( ic + 3 ) ) +
							kr_full_pieces_loop_limit );
			e0 = _mm512_maskz_loadu_ps( mmask[4], a + ( lda * ( ic + 4 ) ) +
							kr_full_pieces_loop_limit );
			f0 = _mm512_maskz_loadu_ps( mmask[5], a + ( lda * ( ic + 5 ) ) +
							kr_full_pieces_loop_limit );
			// KC=0/kr=0 arguments make the pack loop store at tile-local
			// offsets 0..95 inside the scratch buffer.
			F32_ROW_MAJOR_K_PACK_LOOP(temp_pack_a_buf, 0, 0);
			memcpy
			(
			  pack_a_buf + ( ic * KC ) + ( kr_full_pieces_loop_limit * MR ),
			  temp_pack_a_buf,
			  kr_partial_pieces * MR * sizeof( float )
			);
		}
	}
	// Packed block layout: MR contiguous m elements per k iteration.
	*rs_a = 1;
	*cs_a = 6;
}
// Column Major Packing in blocks of MRxKC.
// Packs an MC x KC panel of a column major (unit row stride) A matrix into
// pack_a_buf. Each k column contributes MR contiguous m elements at offset
// ( kr * MR ), which is the same packed layout the row major kernel
// produces (rs = 1, cs = MR).
void packa_f32f32f32of32_col_major_avx512
     (
       float*       pack_a_buf,
       const float* a,
       const dim_t  lda,
       const dim_t  MC,
       const dim_t  KC,
       dim_t*       rs_a,
       dim_t*       cs_a
     )
{
	// Panel height (m dim) and k-unroll width of one packed tile.
	const dim_t MR = 6;
	const dim_t KR_NDIM = 16;
	dim_t m_full_pieces = MC / MR;
	dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
	dim_t m_partial_pieces = MC % MR;
	dim_t kr_full_pieces = KC / KR_NDIM;
	dim_t kr_full_pieces_loop_limit = kr_full_pieces * KR_NDIM;
	dim_t kr_partial_pieces = KC % KR_NDIM;
	__m256 col_vec;
	// One 6-bit (0x3F) m mask per k column of the KR_NDIM wide tile.
	__mmask16 mmask[16];
	for ( dim_t ic = 0; ic < MC; ic += MR )
	{
		/* Inside the kr loop, the mmask is modified. Need to reset it
		 * at beginning of each ic loop iteration. */
		if ( ic == m_full_pieces_loop_limit )
		{
			// Trailing m panel: keep only m_partial_pieces low bits so
			// loads/stores never touch rows beyond MC.
			for ( int ii = 0; ii < 16; ++ii )
			{
				mmask[ii] = _cvtu32_mask16( 0x3F >> ( MR - m_partial_pieces ) );
			}
		}
		else
		{
			for ( int ii = 0; ii < 16; ++ii )
			{
				mmask[ii] = _cvtu32_mask16( 0x3F );
			}
		}
		for ( dim_t kr = 0; kr < KC; kr += KR_NDIM )
		{
			// Trailing k tile: zero the masks of the columns beyond KC so
			// those lanes load/store nothing.
			if ( kr == kr_full_pieces_loop_limit )
			{
				for ( int ii = kr_partial_pieces; ii < 16; ++ii )
				{
					mmask[ii] = _cvtu32_mask16( 0x0 );
				}
			}
			// Copy each of the KR_NDIM k columns (MR m elements apiece)
			// of this tile into the pack buffer. Source and destination
			// do not overlap, so the fused load/store per column matches
			// the original load-all-then-store-all sequence exactly.
			for ( dim_t jr = 0; jr < KR_NDIM; ++jr )
			{
				col_vec = _mm256_maskz_loadu_ps
						( mmask[jr], a + ic + ( lda * ( kr + jr ) ) );
				_mm256_mask_storeu_ps
				(
				  pack_a_buf + ( ic * KC ) + ( ( kr * MR ) + ( MR * jr ) ),
				  mmask[jr], col_vec
				);
			}
		}
	}
	// Report the packed block strides, mirroring the row major kernel;
	// previously these output parameters were left uninitialized.
	*rs_a = 1;
	*cs_a = 6;
}
// Dispatches A packing to the row or column major AVX512 kernel based on
// the storage of A: unit column stride selects the row major path, every
// other layout is handled by the column major path.
void packa_mr6_f32f32f32of32_avx512
     (
       float*       pack_a_buf,
       const float* a,
       const dim_t  rs,
       const dim_t  cs,
       const dim_t  MC,
       const dim_t  KC,
       dim_t*       rs_a,
       dim_t*       cs_a
     )
{
	if ( cs != 1 )
	{
		// Column major storage; rs is implicit (unit), cs is the leading dim.
		packa_f32f32f32of32_col_major_avx512
		( pack_a_buf, a, cs, MC, KC, rs_a, cs_a );
	}
	else
	{
		// Row major storage; rs is the leading dim.
		packa_f32f32f32of32_row_major_avx512
		( pack_a_buf, a, rs, MC, KC, rs_a, cs_a );
	}
}
#endif

View File

@@ -115,16 +115,16 @@ void packa_u8s8s32os32
dim_t* cs_a
)
{
if( cs == 1 )
{
packa_k64_u8s8s32o32
( pack_a_buffer_u8s8s32o32, a, rs, MC, KC, rs_a, cs_a );
}
else
if ( rs == 1 )
{
packa_mr16_u8s8s32o32_col_major
( pack_a_buffer_u8s8s32o32, a, rs, cs, MC, KC, rs_a, cs_a );
}
else
{
packa_k64_u8s8s32o32
( pack_a_buffer_u8s8s32o32, a, rs, MC, KC, rs_a, cs_a );
}
}