mirror of https://github.com/amd/blis.git (synced 2026-04-19 23:28:52 +00:00)

Implemented AVX2-based GEMV for the n=1 case.

- Added a new GEMV kernel with MR = 8, which will be used for cases where n = 1.
- Modified the GEMM and GEMV framework to choose the right GEMV kernel based on
  compile-time and run-time architecture parameters. This had to be done since
  GEMV kernels are not stored in, or retrieved from, the cntx.
- Added a pack kernel that packs the A matrix from column-major to row-major
  using AVX2 instructions.

AMD-Internal: [SWLCSG-3519]
Change-Id: Ibf7a8121d0bde37660eac58a160c5b9c9ebd2b5c
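For orientation, the dispatch described above amounts to selecting a function
pointer from compile-time (BLIS_KERNELS_ZEN4) and run-time
(lpgemm_get_enabled_arch()) information, since GEMV kernels do not live in the
cntx. A minimal sketch of that idea, with hypothetical stand-in names (the
real selection code appears in the hunks below):

	typedef long dim_t_;                          /* stand-in for BLIS dim_t */
	typedef void (*lpgemv_ker_ft_)( dim_t_ m, dim_t_ k );

	static void avx512_mr16_kernel( dim_t_ m, dim_t_ k ) { /* ... */ }
	static void avx2_mr8_kernel   ( dim_t_ m, dim_t_ k ) { /* ... */ }

	/* zen4_build models #ifdef BLIS_KERNELS_ZEN4; arch_is_zen3 models
	   lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3. */
	static lpgemv_ker_ft_ choose_lpgemv_kernel( int zen4_build, int arch_is_zen3 )
	{
		if ( zen4_build && !arch_is_zen3 )
			return avx512_mr16_kernel;   /* existing MR = 16 path */
		return avx2_mr8_kernel;          /* MR = 8 path added by this commit */
	}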
@@ -70,14 +70,12 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32)

	// Extra space since packing does width in multiples of NR.
	dim_t n_reorder;
#ifdef BLIS_KERNELS_ZEN4
-	if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
+	if( ( n == 1 ) ) //&& ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
	{
		//When n == 1, LPGEMV doesn't expect B to be reordered.
		n_reorder = 1;
	}
	else
#endif
	{
		n_reorder = ( ( n + NR - 1 ) / NR ) * NR;
	}

@@ -172,10 +170,10 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
	dim_t n_threads = bli_rntm_num_threads( &rntm_g );
	n_threads = ( n_threads > 0 ) ? n_threads : 1;

#ifdef BLIS_KERNELS_ZEN4
	//When n == 1, the B matrix becomes a vector.
	//Reordering is avoided so that LPGEMV can process it efficiently.
-	if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
+	if( ( n == 1 ) ) //&& ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
	{
		if(rs_b == 1)
		{
@@ -189,7 +187,6 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
		}
		return;
	}
#endif

#ifdef BLIS_ENABLE_OPENMP
	_Pragma( "omp parallel num_threads(n_threads)" )
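The rounding above pads the reorder width up to a whole number of NR-sized
panels; only n == 1 bypasses it. A quick worked example (values chosen purely
for illustration, not the configured blocksizes):

	dim_t NR = 16;                                    /* example blocksize */
	dim_t n  = 100;
	dim_t n_reorder = ( ( n + NR - 1 ) / NR ) * NR;   /* (115 / 16) * 16 = 112 */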
@@ -62,7 +62,41 @@ typedef void (*lpgemm_rowvar_f32)
	lpgemm_post_op_attr
);

#ifdef BLIS_KERNELS_ZEN4
typedef void (*lpgemv_n_one_ker_ft)
(
	const dim_t,
	const dim_t,
	const float*,
	const dim_t,
	const dim_t,
	const AOCL_MEMORY_TAG,
	const float*,
	const dim_t,
	const dim_t,
	const AOCL_MEMORY_TAG,
	float*,
	const dim_t,
	const dim_t,
	const float,
	const float,
	const dim_t,
	const dim_t,
	lpgemm_post_op*,
	lpgemm_post_op_attr*
);

typedef void (*lpgemv_n_one_a_pack_ft)
(
	float*,
	const float*,
	const dim_t,
	const dim_t,
	const dim_t,
	const dim_t,
	dim_t*,
	dim_t*
);

LPGEMV(float, float, float, f32f32f32of32)
{
	const float* a_use = (float*)a;
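Both GEMV kernels share the lpgemv_n_one_ker_ft signature, which is what lets
the framework bind either one to a single pointer and call it uniformly. A
usage sketch (the assignments mirror the selection code in the next hunk):

	lpgemv_n_one_ker_ft    ker_fp   = lpgemv_n_one_f32f32f32of32_avx2;
	lpgemv_n_one_a_pack_ft packa_fp = packa_mr8_f32f32f32of32_col_major;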
@@ -103,9 +137,32 @@ LPGEMV(float, float, float, f32f32f32of32)
	if(n == 1)
	{
		float* pack_b_buffer_f32f32f32of32;
		//TODO: AVX2 support needs to be added

		dim_t MR;
		lpgemv_n_one_ker_ft ker_fp;
		lpgemv_n_one_a_pack_ft packa_fp;

		// Workaround to select the right kernel and blocksizes based on arch,
		// since GEMV parameters are not available in the lpgemm context.
#ifdef BLIS_KERNELS_ZEN4
		if( lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3 )
		{
			MR = 8;
			ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
			packa_fp = packa_mr8_f32f32f32of32_col_major;
		}
		else
		{
			MR = 16;
			ker_fp = lpgemv_n_one_f32f32f32of32;
			packa_fp = packa_mr16_f32f32f32of32_col_major;
		}
#else
-		// Increased MR from 6 to 16 to make use of 32 ZMM registers
-		dim_t MR = 16;
+		MR = 8;
+		ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
+		packa_fp = packa_mr8_f32f32f32of32_col_major;
#endif
		// Pack B matrix if rs_b > 1
		if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
		{
@@ -144,6 +201,7 @@ LPGEMV(float, float, float, f32f32f32of32)
			c_use = c + ic * rs_c;
			post_ops_attr.post_op_c_i = ic;

+			// To-Do: pack A case needs to be handled for the AVX2 case.
			if( mtag_a == PACK && cs_a != 1 )
			{
				mem_a_size_req = sizeof(float) * mc0 * k;
@@ -154,7 +212,7 @@ LPGEMV(float, float, float, f32f32f32of32)
				);
				pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a );

-				packa_mr16_f32f32f32of32_col_major
+				packa_fp
				(
					pack_a_buffer_f32f32f32of32,
					a_use, rs_a, cs_a,
@@ -163,19 +221,23 @@ LPGEMV(float, float, float, f32f32f32of32)
				);
				a_use = pack_a_buffer_f32f32f32of32;
			}

			// Call the lpgemv_n_one kernel
-			lpgemv_n_one_f32f32f32of32
+			ker_fp
			(
				mc0, k,
				a_use, rs_a_use, cs_a_use, mtag_a,
				b_use, rs_b_use, cs_b_use, mtag_b,
				c_use, rs_c, cs_c,
				alpha, beta,
				MR, KC,
				post_op_list,
				&post_ops_attr
			);

			if ( mtag_a == PACK )
			{
				// Release pack buffer for A.
				bli_pba_release( rntm, &mem_a );
			}
		}
		if ( ( mtag_a == PACK ) && ( bli_mem_is_alloc( &mem_a ) ) )
		{
@@ -188,6 +250,8 @@ LPGEMV(float, float, float, f32f32f32of32)
		}
	}
	else
	{
+		// m = 1 case is not implemented yet for AVX2
+#ifdef BLIS_KERNELS_ZEN4
		// Compute the JC loop thread range for the current thread.
		dim_t jc_start, jc_end;
		thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
@@ -312,19 +376,21 @@ LPGEMV(float, float, float, f32f32f32of32)
		{
			bli_pba_release( rntm, &mem_b );
		}
+#endif // m == 1 case is not implemented for AVX2 yet.
	}
}
#endif
LPGEMM_5LOOP(float, float, float, f32f32f32of32)
{

#ifdef BLIS_KERNELS_ZEN4
	// Handle using LPGEMV when m or/and n equal to 1
-	// The avx512 check will be removed when avx2 kernels are added in future
-	if ( ( ( m == 1 ) || ( n == 1 ) ) &&
-	     ( bli_cpuid_is_avx512_supported() == TRUE ) &&
-	     ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
+	if ( ( ( (m == 1) && (lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) ) || ( n == 1 ) ) &&
+	     ( bli_cpuid_is_avx512_supported() == TRUE ) )
+#else
+	// m=1 case is not implemented yet for AVX2
+	if ( ( ( n == 1 ) ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE ) )
+#endif
	{
		lpgemv_rowvar_f32f32f32of32(m, n, k,
		                            a, rs_a, cs_a, mtag_a,
@@ -339,7 +405,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
		                            c_downscale);
		return;
	}
-#endif

	// Query the context for various blocksizes.
	const dim_t NC = lcntx->blksz.NC;
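Read as a predicate, the dispatch condition above can be summarized like this
(a sketch with a hypothetical helper name, not part of the commit):

	static bool use_lpgemv( dim_t m, dim_t n )
	{
	#ifdef BLIS_KERNELS_ZEN4
		/* AVX-512 builds: m == 1 additionally requires a non-ZEN3 arch;
		   n == 1 is always eligible. */
		return ( ( ( m == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) ) ||
		         ( n == 1 ) ) &&
		       ( bli_cpuid_is_avx512_supported() == TRUE );
	#else
		/* AVX2 builds: only n == 1 is handled so far. */
		return ( n == 1 ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE );
	#endif
	}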
@@ -62,7 +62,43 @@ typedef void (*lpgemm_rowvar_f32)
	lpgemm_post_op_attr
);

#ifdef BLIS_KERNELS_ZEN4

typedef void (*lpgemv_n_one_ker_ft)
(
	const dim_t,
	const dim_t,
	const float*,
	const dim_t,
	const dim_t,
	const AOCL_MEMORY_TAG,
	const float*,
	const dim_t,
	const dim_t,
	const AOCL_MEMORY_TAG,
	float*,
	const dim_t,
	const dim_t,
	const float,
	const float,
	const dim_t,
	const dim_t,
	lpgemm_post_op*,
	lpgemm_post_op_attr*
);

typedef void (*lpgemv_n_one_a_pack_ft)
(
	float*,
	const float*,
	const dim_t,
	const dim_t,
	const dim_t,
	const dim_t,
	dim_t*,
	dim_t*
);


LPGEMV_TINY(float, float, float, f32f32f32of32)
{
	const float* a_use = ( float* )a;
@@ -84,9 +120,32 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
	float* pack_b_buffer_f32f32f32of32 = NULL;
	err_t err = BLIS_SUCCESS;

	//TODO: AVX2 support needs to be added
	dim_t MR;
	lpgemv_n_one_ker_ft ker_fp;
	lpgemv_n_one_a_pack_ft packa_fp;

	// Workaround to select the right kernel and blocksizes based on arch,
	// since GEMV parameters are not available in the lpgemm context.
#ifdef BLIS_KERNELS_ZEN4
	if( lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3 )
	{
		MR = 8;
		ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
		packa_fp = packa_mr8_f32f32f32of32_col_major;
	}
	else
	{
		MR = 16;
		ker_fp = lpgemv_n_one_f32f32f32of32;
		packa_fp = packa_mr16_f32f32f32of32_col_major;
	}
#else
-	// Increased MR from 6 to 16 to make use of 32 ZMM registers
-	dim_t MR = 16;
+	MR = 8;
+	ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
+	packa_fp = packa_mr8_f32f32f32of32_col_major;
#endif


	// Pack B matrix if rs_b > 1
	if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
@@ -111,7 +170,7 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
		pack_a_buffer_f32f32f32of32 =
			( float* )bli_malloc_user(mem_a_size_req, &err);

-		packa_mr16_f32f32f32of32_col_major
+		packa_fp
		(
			pack_a_buffer_f32f32f32of32,
			a_use, rs_a, cs_a,
@@ -123,7 +182,7 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)

	post_ops_attr.post_op_c_i = 0;
	post_ops_attr.post_op_c_j = 0;
-	lpgemv_n_one_f32f32f32of32
+	ker_fp
	(
		m, k,
		a_use, rs_a_use, cs_a_use, mtag_a,
@@ -145,16 +204,19 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
		}
	}
}
#endif


LPGEMM_TINY(float,float,float,f32f32f32of32)
{

	// Handle using LPGEMV when m or/and n equal to 1
#ifdef BLIS_KERNELS_ZEN4
-	// Handle using LPGEMV when m or/and n equal to 1
-	// The avx512 check will be removed when avx2 kernels are added in future
-	if ( ( n == 1 ) &&
-	     ( bli_cpuid_is_avx512_supported() == TRUE ) &&
-	     ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
+	if ( ( ( (m == 1) && (lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) ) || ( n == 1 ) ) &&
+	     ( bli_cpuid_is_avx512_supported() == TRUE ) )
+#else
+	// m=1 case is not implemented yet for AVX2
+	if ( ( ( n == 1 ) ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE ) )
+#endif
	{
		lpgemv_rowvar_tiny_f32f32f32of32(m, n, k,
		                                 a, rs_a, cs_a, mtag_a,
@@ -167,7 +229,7 @@ LPGEMM_TINY(float,float,float,f32f32f32of32)
		                                 c_downscale);
		return;
	}
-#endif


	const dim_t NR = lcntx->blksz.NR;
	const dim_t MR = lcntx->blksz.MR;
@@ -46,6 +46,18 @@ void packa_mr16_f32f32f32of32_col_major
	dim_t* cs_p
);

void packa_mr8_f32f32f32of32_col_major
(
	float* pack_a_buffer,
	const float* a,
	const dim_t rs_a,
	const dim_t cs_a,
	const dim_t MC,
	const dim_t KC,
	dim_t* rs_p,
	dim_t* cs_p
);

void packa_mr6_f32f32f32of32_avx512
(
	float* pack_a_buf,
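The new pack kernel declared above handles fringe rows and columns with AVX2
masked loads and stores, where the sign bit of each 32-bit mask lane decides
whether that float is touched. A self-contained demo of the technique
(illustrative only, not part of the commit):

	#include <immintrin.h>

	/* Copy exactly three floats: lanes 3..7 read as zero on the load and
	   are left untouched by the store. */
	void copy3( float* dst, const float* src )
	{
		__m256i mask = _mm256_set_epi32( 0, 0, 0, 0, 0, -1, -1, -1 );
		__m256  v    = _mm256_maskload_ps( src, mask );
		_mm256_maskstore_ps( dst, mask, v );
	}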
@@ -782,6 +782,7 @@ void lpgemv_n_one_ ## LP_SFX \
) \

LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32);
+LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32_avx2);
LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
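The declaration macro pastes the suffix onto the kernel name, so the new
_avx2 line above declares lpgemv_n_one_f32f32f32of32_avx2. Its expanded
prototype, assuming the macro's parameter list mirrors the
lpgemv_n_one_ker_ft typedef introduced in this commit:

	void lpgemv_n_one_f32f32f32of32_avx2
	(
		const dim_t m0, const dim_t k,
		const float* a, const dim_t rs_a, const dim_t cs_a,
		const AOCL_MEMORY_TAG mtag_a,
		const float* b, const dim_t rs_b, const dim_t cs_b,
		const AOCL_MEMORY_TAG mtag_b,
		float* c, const dim_t rs_c, const dim_t cs_c,
		const float alpha, const float beta,
		const dim_t MR, const dim_t KC,
		lpgemm_post_op* post_op_list,
		lpgemm_post_op_attr* post_op_attr
	);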
kernels/zen/lpgemm/f32f32f32/lpgemm_pack_a_f32_amd256.c (new file, 467 lines)
@@ -0,0 +1,467 @@
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
	- Redistributions of source code must retain the above copyright
	  notice, this list of conditions and the following disclaimer.
	- Redistributions in binary form must reproduce the above copyright
	  notice, this list of conditions and the following disclaimer in the
	  documentation and/or other materials provided with the distribution.
	- Neither the name(s) of the copyright holder(s) nor the names of its
	  contributors may be used to endorse or promote products derived
	  from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include <immintrin.h>
#include <string.h>
#include "blis.h"

#ifdef BLIS_ADDON_LPGEMM

#define UNPACKLO_PS8 \
	b_reg[0] = _mm256_unpacklo_ps( a_reg[0], a_reg[1] ); \
	b_reg[1] = _mm256_unpacklo_ps( a_reg[2], a_reg[3] ); \
	b_reg[2] = _mm256_unpacklo_ps( a_reg[4], a_reg[5] ); \
	b_reg[3] = _mm256_unpacklo_ps( a_reg[6], a_reg[7] );

#define UNPACKHI_PS8 \
	b_reg[4] = _mm256_unpackhi_ps( a_reg[0], a_reg[1] ); \
	b_reg[5] = _mm256_unpackhi_ps( a_reg[2], a_reg[3] ); \
	b_reg[6] = _mm256_unpackhi_ps( a_reg[4], a_reg[5] ); \
	b_reg[7] = _mm256_unpackhi_ps( a_reg[6], a_reg[7] );

#define UNPACKLO_PD8 \
	a_reg[0] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[0], (__m256d)b_reg[1] ); \
	a_reg[1] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[2], (__m256d)b_reg[3] ); \
	a_reg[2] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[4], (__m256d)b_reg[5] ); \
	a_reg[3] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[6], (__m256d)b_reg[7] );

#define UNPACKHI_PD8 \
	a_reg[4] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[0], (__m256d)b_reg[1] ); \
	a_reg[5] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[2], (__m256d)b_reg[3] ); \
	a_reg[6] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[4], (__m256d)b_reg[5] ); \
	a_reg[7] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[6], (__m256d)b_reg[7] );

#define PERMUTE2F128_PS8 \
	b_reg[0] = _mm256_permute2f128_ps( a_reg[0], a_reg[1], 0x20 ); \
	b_reg[1] = _mm256_permute2f128_ps( a_reg[4], a_reg[5], 0x20 ); \
	b_reg[2] = _mm256_permute2f128_ps( a_reg[2], a_reg[3], 0x20 ); \
	b_reg[3] = _mm256_permute2f128_ps( a_reg[6], a_reg[7], 0x20 ); \
	b_reg[4] = _mm256_permute2f128_ps( a_reg[0], a_reg[1], 0x31 ); \
	b_reg[5] = _mm256_permute2f128_ps( a_reg[4], a_reg[5], 0x31 ); \
	b_reg[6] = _mm256_permute2f128_ps( a_reg[2], a_reg[3], 0x31 ); \
	b_reg[7] = _mm256_permute2f128_ps( a_reg[6], a_reg[7], 0x31 );

void packa_mr8_f32f32f32of32_col_major
(
	float* pack_a_buffer,
	const float* a,
	const dim_t rs_a,
	const dim_t cs_a,
	const dim_t MC,
	const dim_t KC,
	dim_t* rs_p,
	dim_t* cs_p
)
{
	dim_t MR = 8;
	dim_t ic, kr;

	__m256 a_reg[8], b_reg[8];

	__m256i k_masks[3] = {
		_mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, -1 ),     // 1 element
		_mm256_set_epi32( 0, 0, 0, 0, 0, 0, -1, -1 ),    // 2 elements
		_mm256_set_epi32( 0, 0, 0, 0, -1, -1, -1, -1 ),  // 4 elements
	};

	__m256i load_mask, store_mask;

	// These registers are set to zero to avoid compiler warnings.
	// To-do: to be removed when the pack code is optimized for fringe cases.
	a_reg[0] = _mm256_setzero_ps();
	a_reg[1] = _mm256_setzero_ps();
	a_reg[2] = _mm256_setzero_ps();
	a_reg[3] = _mm256_setzero_ps();
	a_reg[4] = _mm256_setzero_ps();
	a_reg[5] = _mm256_setzero_ps();
	a_reg[6] = _mm256_setzero_ps();
	a_reg[7] = _mm256_setzero_ps();

	for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR )
	{
		for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
		{
			// Transposing the 8x8 block of data
			a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
			a_reg[1] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) );
			a_reg[2] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) );
			a_reg[3] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) );
			a_reg[4] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ) );
			a_reg[5] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) );
			a_reg[6] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) );
			a_reg[7] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), b_reg[1] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), b_reg[2] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), b_reg[3] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), b_reg[4] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), b_reg[5] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), b_reg[6] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), b_reg[7] );
		}
		store_mask = k_masks[2]; // mask to store 4 elements
		for( ; ( kr + 3 ) < KC; kr += 4 )
		{
			// Transposing the 8x4 block of data
			a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
			a_reg[1] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) );
			a_reg[2] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) );
			a_reg[3] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), store_mask, b_reg[4] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), store_mask, b_reg[5] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), store_mask, b_reg[6] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), store_mask, b_reg[7] );
		}
		store_mask = k_masks[1]; // mask to store 2 elements
		for( ; ( kr + 1 ) < KC; kr += 2 )
		{
			// Transposing the 8x2 block of data
			a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
			a_reg[1] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), store_mask, b_reg[4] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), store_mask, b_reg[5] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), store_mask, b_reg[6] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), store_mask, b_reg[7] );
		}
		store_mask = k_masks[0]; // mask to store 1 element
		for( ; ( kr + 0 ) < KC; kr += 1 )
		{
			// Transposing the 8x1 block of data
			a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), store_mask, b_reg[4] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), store_mask, b_reg[5] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), store_mask, b_reg[6] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), store_mask, b_reg[7] );
		}
	}
	for( ; ( ic + 3 ) < MC; ic += 4 )
	{
		load_mask = k_masks[2]; // mask to load 4 elements
		for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
		{
			// Transposing the 4x8 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
			a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
			a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
			a_reg[4] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ), load_mask );
			a_reg[5] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ), load_mask );
			a_reg[6] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ), load_mask );
			a_reg[7] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), b_reg[1] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), b_reg[2] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), b_reg[3] );
		}
		store_mask = k_masks[2]; // mask to store 4 elements
		for( ; ( kr + 3 ) < KC; kr += 4 )
		{
			// Transposing the 4x4 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
			a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
			a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
		}
		store_mask = k_masks[1]; // mask to store 2 elements
		for( ; ( kr + 1 ) < KC; kr += 2 )
		{
			// Transposing the 4x2 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8
			UNPACKLO_PD8
			UNPACKHI_PD8
			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
		}
		store_mask = k_masks[0]; // mask to store 1 element
		for( ; ( kr + 0 ) < KC; kr += 1 )
		{
			// Transposing the 4x1 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8
			UNPACKLO_PD8
			UNPACKHI_PD8
			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
		}
	}
	for( ; ( ic + 1 ) < MC; ic += 2 )
	{
		load_mask = k_masks[1]; // mask to load 2 elements
		for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
		{
			// Transposing the 2x8 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
			a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
			a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
			a_reg[4] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ), load_mask );
			a_reg[5] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ), load_mask );
			a_reg[6] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ), load_mask );
			a_reg[7] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), b_reg[1] );
		}
		store_mask = k_masks[2]; // mask to store 4 elements
		for( ; ( kr + 3 ) < KC; kr += 4 )
		{
			// Transposing the 2x4 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
			a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
			a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
		}
		store_mask = k_masks[1]; // mask to store 2 elements
		for( ; ( kr + 1 ) < KC; kr += 2 )
		{
			// Transposing the 2x2 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
		}
		store_mask = k_masks[0]; // mask to store 1 element
		for( ; ( kr + 0 ) < KC; kr += 1 )
		{
			// Transposing the 2x1 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
		}
	}
	for( ; ( ic + 0 ) < MC; ic += 1 )
	{
		load_mask = k_masks[0]; // mask to load 1 element
		for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
		{
			// Transposing the 1x8 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
			a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
			a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
			a_reg[4] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ), load_mask );
			a_reg[5] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ), load_mask );
			a_reg[6] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ), load_mask );
			a_reg[7] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
		}
		store_mask = k_masks[2]; // mask to store 4 elements
		for( ; ( kr + 3 ) < KC; kr += 4 )
		{
			// Transposing the 1x4 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
			a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
			a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
		}
		store_mask = k_masks[1]; // mask to store 2 elements
		for( ; ( kr + 1 ) < KC; kr += 2 )
		{
			// Transposing the 1x2 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
			a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
		}
		store_mask = k_masks[0]; // mask to store 1 element
		for( ; ( kr + 0 ) < KC; kr += 1 )
		{
			// Transposing the 1x1 block of data
			a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );

			UNPACKLO_PS8
			UNPACKHI_PS8

			UNPACKLO_PD8
			UNPACKHI_PD8

			PERMUTE2F128_PS8

			_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
		}
	}
	// Set the row and column strides of the packed matrix.
	*rs_p = KC;
	*cs_p = 1;
}

#endif // BLIS_ADDON_LPGEMM
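The packed layout produced above is easy to state in scalar form: element
(i, kr) of the column-major source lands at pack[i * KC + kr], i.e. row-major
with rs_p = KC and cs_p = 1. A scalar reference (a sketch, useful as a
correctness oracle for the vectorized kernel; not part of the commit):

	/* Reference for packa_mr8_f32f32f32of32_col_major's output layout. */
	void packa_col_to_row_ref( float* pack, const float* a,
	                           long rs_a, long cs_a, long MC, long KC )
	{
		for ( long i = 0; i < MC; ++i )
			for ( long kr = 0; kr < KC; ++kr )
				pack[ i * KC + kr ] = a[ i * rs_a + kr * cs_a ];
	}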
@@ -43,33 +43,563 @@
// to produce C output of MRx1. The vectorization is done in the k loop and
// the horizontal reduction is done to produce one output from each
// accumulator register.
-void lpgemv_n_one_kernel_f32_avx2_ker_ft
-(
-	const dim_t m0,
-	const dim_t k,
-	const float *a,
-	const dim_t rs_a,
-	const dim_t cs_a,
-	const AOCL_MEMORY_TAG mtag_a,
-	const float *b,
-	const dim_t rs_b,
-	const dim_t cs_b,
-	const AOCL_MEMORY_TAG mtag_b,
-	float *c,
-	const dim_t rs_c,
-	const dim_t cs_c,
-	const float alpha,
-	const float beta,
-	const dim_t MR,
-	const dim_t KC,
-	lpgemm_post_op *post_op_list,
-	lpgemm_post_op_attr *post_op_attr
-)
-{
-	//TODO: Created dummy function as place holder to get
-	//rid of linking issues in other zen configurations.
-	//AVX2 varient wil be implemented in next commits.
-	//Code will take LPGEMM path for LPGEMV in AVX2 env.
-}

#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \
	ymm0 = _mm256_loadu_ps( paddr ); \
	ymm1 = _mm256_loadu_ps( paddr + stride ); \
	ymm2 = _mm256_loadu_ps( paddr + 2 * stride ); \
	ymm3 = _mm256_loadu_ps( paddr + 3 * stride );

#define LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, mask, paddr, stride ) \
	ymm0 = _mm256_maskload_ps( paddr, mask ); \
	ymm1 = _mm256_maskload_ps( paddr + stride, mask ); \
	ymm2 = _mm256_maskload_ps( paddr + 2 * stride, mask ); \
	ymm3 = _mm256_maskload_ps( paddr + 3 * stride, mask );

#define LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 ) \
	ymm8 = _mm256_fmadd_ps( ymm0, ymm7, ymm8 ); \
	ymm9 = _mm256_fmadd_ps( ymm1, ymm7, ymm9 ); \
	ymm10 = _mm256_fmadd_ps( ymm2, ymm7, ymm10 ); \
	ymm11 = _mm256_fmadd_ps( ymm3, ymm7, ymm11 );

#define LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, ymm0, ymm1, ymm2, ymm3, xmm0 ) \
	ymm0 = _mm256_hadd_ps( ymm8, ymm9 ); \
	ymm1 = _mm256_hadd_ps( ymm10, ymm11 ); \
	ymm0 = _mm256_hadd_ps( ymm0, ymm1 ); \
	xmm0 = _mm_add_ps(_mm256_extractf128_ps(ymm0, 0), _mm256_extractf128_ps(ymm0,1));

#define RELU_SCALE_OP_F32_AVX2( acc, scale_reg, zreg, scratch_reg ) \
	scratch_reg = _mm256_min_ps( acc, zreg ); \
	acc = _mm256_max_ps( acc, zreg ); \
	scratch_reg = _mm256_mul_ps( scratch_reg, scale_reg ); \
	acc = _mm256_or_ps( acc, scratch_reg );


LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32_avx2 )
{
	static void *post_ops_labels[] =
	{
		&&POST_OPS_1x16F_DISABLE,
		&&POST_OPS_BIAS_1x16F,
		&&POST_OPS_RELU_1x16F,
		&&POST_OPS_RELU_SCALE_1x16F,
		&&POST_OPS_GELU_TANH_1x16F,
		&&POST_OPS_GELU_ERF_1x16F,
		&&POST_OPS_CLIP_1x16F,
		&&POST_OPS_DOWNSCALE_1x16F,
		&&POST_OPS_MATRIX_ADD_1x16F,
		&&POST_OPS_SWISH_1x16F,
		&&POST_OPS_MATRIX_MUL_1x16F,
		&&POST_OPS_TANH_1x16F,
		&&POST_OPS_SIGMOID_1x16F
	};

	// Strides are updated based on matrix packing/reordering.
	const float *a_use = NULL;
	const float *b_use = NULL;
	float *c_use = NULL;

	lpgemm_post_op_attr post_ops_attr = *(post_op_attr);

	__m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm7;
	__m256 ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;

	__m128 xmm0, xmm1;

	__m256i masks[9] = {
		_mm256_set_epi32(  0,  0,  0,  0,  0,  0,  0,  0 ), // 0 elements
		_mm256_set_epi32(  0,  0,  0,  0,  0,  0,  0, -1 ), // 1 element
		_mm256_set_epi32(  0,  0,  0,  0,  0,  0, -1, -1 ), // 2 elements
		_mm256_set_epi32(  0,  0,  0,  0,  0, -1, -1, -1 ), // 3 elements
		_mm256_set_epi32(  0,  0,  0,  0, -1, -1, -1, -1 ), // 4 elements
		_mm256_set_epi32(  0,  0,  0, -1, -1, -1, -1, -1 ), // 5 elements
		_mm256_set_epi32(  0,  0, -1, -1, -1, -1, -1, -1 ), // 6 elements
		_mm256_set_epi32(  0, -1, -1, -1, -1, -1, -1, -1 ), // 7 elements
		_mm256_set_epi32( -1, -1, -1, -1, -1, -1, -1, -1 )  // 8 elements
	};


	// MR comes from the framework; it is set based on the underlying
	// hardware configuration.
	for (dim_t mr = 0; mr < m0; mr += MR)
	{
		dim_t mr0 = bli_min( m0 - mr, MR );

		dim_t k_iter = k / 8;
		dim_t k_rem = k % 8;

		const __m256i store_mask = masks[mr0];
		const __m256i k_rem_mask = masks[k_rem];

		/* zero the accumulator registers */
		ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 );
		ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 );

		//update pointers
		a_use = a + mr * rs_a;
		b_use = b;
		c_use = c + mr * rs_c;

		if( mr0 == MR )
		{
			for( dim_t k = 0; k < k_iter; k++ )
			{
				ymm7 = _mm256_loadu_ps( b_use );
				b_use += 8; // move b pointer to the next 8 elements

				// Load 4x8 from rows 0-3 of A
				LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a );
				a_use += 4 * rs_a; // move a pointer down to rows 4-7

				// Perform the dot product
				LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );

				// Load 4x8 from rows 4-7 of A
				LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a );
				a_use -= 4 * rs_a; // move a pointer back to row 0

				// Perform the dot product
				LPGEMV_N_KERNEL_4_FMA( ymm12, ymm13, ymm14, ymm15, ymm7, ymm0, ymm1, ymm2, ymm3 );

				a_use += 8; // advance 8 columns in k
			} // k-loop

			if( k_rem )
			{
				ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );

				// Load 4x8 from rows 0-3 of A
				LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, k_rem_mask, a_use, rs_a );
				a_use += 4 * rs_a; // move a pointer down to rows 4-7

				// Perform the dot product
				LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );

				// Load 4x8 from rows 4-7 of A
				LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, k_rem_mask, a_use, rs_a );
				a_use -= 4 * rs_a; // move a pointer back to row 0

				// Perform the dot product
				LPGEMV_N_KERNEL_4_FMA( ymm12, ymm13, ymm14, ymm15, ymm7, ymm0, ymm1, ymm2, ymm3 );
			}

			// Add the registers horizontally to get one output per row
			LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, ymm0, ymm1, ymm2, ymm3, xmm0 );
			LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, ymm4, ymm1, ymm2, ymm3, xmm1 );

			// compose outputs into one ymm to perform post-ops.
			ymm8 = _mm256_insertf128_ps( ymm8, xmm0, 0 );
			ymm8 = _mm256_insertf128_ps( ymm8, xmm1, 1 );
		}
		else
		{
			//Handle fringe cases when mr0 < MR
			const float *a_use_fringe = a_use;
			dim_t mr0_use = mr0;
			dim_t regidx = 0;

			// Dot product for mfringe 4
			if (mr0_use >= 4)
			{
				for( dim_t k = 0; k < k_iter; k++ )
				{
					ymm7 = _mm256_loadu_ps( b_use );
					b_use += 8; // move b pointer to the next 8 elements

					// Load 4x8 from rows 0-3 of A
					LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a );
					a_use += 8; // advance 8 columns in k

					// Perform the dot product
					LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );
				} // k-loop

				if( k_rem )
				{
					ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );

					// Load 4x8 from rows 0-3 of A
					LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, k_rem_mask, a_use, rs_a );

					// Perform the dot product
					LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );
				}

				//update pointers
				mr0_use -= 4;
				a_use = a_use_fringe + 4 * rs_a;
				a_use_fringe = a_use;
				b_use = b;

				//Horizontal add 4 ymm registers and get the output into 2 xmm registers
				LPGEMV_YMM2XMM(ymm8, ymm9, ymm10, ymm11, ymm0, ymm1, ymm2, ymm3, xmm0)

				// compose outputs into one ymm to perform post-ops.
				ymm8 = _mm256_insertf128_ps( ymm8, xmm0, 0 );
				regidx = 1;
			}
			// Dot product for <= 3
			if (mr0_use)
			{
				if( mr0_use >= 2 )
				{
					for (dim_t k = 0; k < k_iter; k++)
					{
						ymm7 = _mm256_loadu_ps( b_use );
						b_use += 8; // move b pointer to the next 8 elements

						// Load 2x8 from rows 0-1 of A
						ymm0 = _mm256_loadu_ps( a_use );
						ymm1 = _mm256_loadu_ps( a_use + rs_a );
						a_use += 8; // advance 8 columns in k

						ymm12 = _mm256_fmadd_ps( ymm0, ymm7, ymm12 );
						ymm13 = _mm256_fmadd_ps( ymm1, ymm7, ymm13 );

					} // k-loop
					if( k_rem )
					{
						ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );

						// Load 2x8 from rows 0-1 of A
						ymm0 = _mm256_maskload_ps( a_use, k_rem_mask );
						ymm1 = _mm256_maskload_ps( a_use + rs_a, k_rem_mask );

						ymm12 = _mm256_fmadd_ps( ymm0, ymm7, ymm12 );
						ymm13 = _mm256_fmadd_ps( ymm1, ymm7, ymm13 );
					}
					//update pointers
					mr0_use -= 2;
					a_use = a_use_fringe + 2 * rs_a;
					a_use_fringe = a_use;
					b_use = b;
				}
				if( mr0_use == 1 )
				{
					for (dim_t k = 0; k < k_iter; k++)
					{
						ymm7 = _mm256_loadu_ps( b_use );
						b_use += 8; // move b pointer to the next 8 elements

						// Load 1x8 from row 0 of A
						ymm0 = _mm256_loadu_ps( a_use );
						a_use += 8; // advance 8 columns in k

						ymm14 = _mm256_fmadd_ps( ymm0, ymm7, ymm14 );

					} // k-loop
					if( k_rem )
					{
						ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );

						// Load 1x8 from row 0 of A
						ymm0 = _mm256_maskload_ps( a_use, k_rem_mask );

						ymm14 = _mm256_fmadd_ps( ymm0, ymm7, ymm14 );
					}
					// When only fringe 1, update the registers to store in order
					if (!(mr0 & 0x2)) ymm12 = ymm14;
				}

				LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, ymm0, ymm1, ymm2, ymm3, xmm1 );
				if (regidx == 0) ymm8 = _mm256_insertf128_ps(ymm8, xmm1, 0);
				else ymm8 = _mm256_insertf128_ps(ymm8, xmm1, 1);
			}
		}

		// scale accumulated output with alpha
		ymm0 = _mm256_set1_ps( alpha );
		ymm8 = _mm256_mul_ps( ymm8, ymm0 );

		if( beta != 0.0f )
		{
			const float *_cbuf = c_use;

			//C = beta*C + alpha*A*B
			ymm3 = _mm256_set1_ps(beta);
			if( rs_c == 1 )
			{
				ymm0 = _mm256_maskload_ps( _cbuf, store_mask );
			}
			else
			{
				// load c into ymm0
				float ctemp[8] = { 0 };
				for( dim_t i = 0; i < 8; i++ )
				{
					ctemp[i] = _cbuf[i * rs_c];
				}
				ymm0 = _mm256_loadu_ps( ctemp );
			}

			// scale c with beta
			ymm8 = _mm256_fmadd_ps( ymm0, ymm3, ymm8 );
		}

		// post-ops
		post_ops_attr.is_last_k = TRUE;
		lpgemm_post_op *post_ops_list_temp = post_op;
		POST_OP_LABEL_LASTK_SAFE_JUMP

POST_OPS_BIAS_1x16F:
		{
			if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
			     ( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
			{
				ymm0 = _mm256_set1_ps(*( ( float * )post_ops_list_temp->op_args1 ) );
			}
			else
			{
				// If the original output was column major, then by the time
				// the kernel sees it, the matrix would be accessed as if it were
				// transposed. Due to this the bias array will be accessed by
				// the ic index, and each bias element corresponds to an
				// entire row of the transposed output array, instead of an
				// entire column.
				ymm0 = _mm256_maskload_ps( ( float* )post_ops_list_temp->op_args1 +
				                           post_ops_attr.post_op_c_i , store_mask );
			}
			ymm8 = _mm256_add_ps(ymm0, ymm8);
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}

POST_OPS_RELU_1x16F:
		{
			ymm0 = _mm256_setzero_ps();
			ymm8 = _mm256_max_ps( ymm8, ymm0 );
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_RELU_SCALE_1x16F:
		{
			ymm0 = _mm256_set1_ps( *( float* )post_ops_list_temp->op_args2 );
			ymm1 = _mm256_setzero_ps();

			RELU_SCALE_OP_F32_AVX2( ymm8, ymm0, ymm1, ymm2 );
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_GELU_TANH_1x16F:
		{
			__m256 dn, x_tanh;
			__m256i q;

			// c[0, 0-7]
			GELU_TANH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, dn, x_tanh, q)
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_GELU_ERF_1x16F:
		{
			// c[0, 0-7]
			GELU_ERF_F32S_AVX2(ymm8, ymm0, ymm1, ymm2)
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_CLIP_1x16F:
		{
			ymm0 = _mm256_set1_ps(*(float *)post_ops_list_temp->op_args2);
			ymm1 = _mm256_set1_ps(*(float *)post_ops_list_temp->op_args3);

			// c[0, 0-7]
			CLIP_F32S_AVX2(ymm8, ymm0, ymm1)

			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_DOWNSCALE_1x16F:
		{
			__m256 zero_point0 = _mm256_setzero_ps();
			__m256 selector1 = _mm256_setzero_ps();
			// Need to account for row vs column major swaps. For scalar
			// scale and zero point there are no implications.
			// Even though different registers are used for the scalar in the
			// column and row major downscale paths, all those registers will
			// contain the same value.
			if ( post_ops_list_temp->scale_factor_len == 1 )
			{
				selector1 =
					_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );

			}
			if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 )
			{
				zero_point0 = _mm256_set1_ps( *(float *)post_ops_list_temp->op_args1 );
			}
			if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
			     ( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
			{
				// Scale/zp len cannot be > 1, since original n = 1.
				F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0);
			}
			else
			{
				// If the original output was column major, then by the time
				// the kernel sees it, the matrix would be accessed as if it were
				// transposed. Due to this the scale as well as zp array will
				// be accessed by the ic index, and each scale/zp element
				// corresponds to an entire row of the transposed output array,
				// instead of an entire column.
				if( post_ops_list_temp->scale_factor_len > 1 )
				{
					selector1 = _mm256_maskload_ps( ( float* )post_ops_list_temp->scale_factor +
					                                post_ops_attr.post_op_c_i, store_mask );
				}
				if( *( dim_t*)post_ops_list_temp->op_args3 > 1 )
				{
					zero_point0 = _mm256_maskload_ps( ( float * )post_ops_list_temp->op_args1 +
					                                  post_ops_attr.post_op_c_i, store_mask );
				}
				F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0);
			}
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_MATRIX_ADD_1x16F:
		{

			__m256 selector1 = _mm256_setzero_ps();

			dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;

			__m256 scl_fctr1 = _mm256_setzero_ps();

			// Even though different registers are used for the scalar in the
			// column and row major case, all those registers will contain the
			// same value.
			// For column major, if m==1, then it means n=1 and scale_factor_len=1.
			if ( post_ops_list_temp->scale_factor_len == 1 )
			{
				scl_fctr1 =
					_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
			}
			else
			{
				if ( ( *( char* )post_ops_list_temp->op_args2 == 'c' ) ||
				     ( *( char* )post_ops_list_temp->op_args2 == 'C' ) )
				{
					scl_fctr1 =
						_mm256_maskload_ps( ( float* )post_ops_list_temp->scale_factor +
						                    post_ops_attr.post_op_c_i + ( 0 * 16 ), store_mask );
				}
			}
			float* matptr = ( float* )post_ops_list_temp->op_args1;

			if( ldm == 1 )
			{
				selector1 = _mm256_maskload_ps(( matptr +
				                                 post_ops_attr.post_op_c_i ), store_mask );
				selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
				ymm8 = _mm256_add_ps( selector1, ymm8 );

			}
			else
			{
				float ctemp[16];
				for( dim_t i = 0; i < mr0; i++ )
				{
					ctemp[i] = *( matptr +
					              ( ( post_ops_attr.post_op_c_i + i )
					                * ldm ) );
				}
				selector1 = _mm256_maskload_ps( ctemp, store_mask );
				selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
				ymm8 = _mm256_add_ps( selector1, ymm8 );
			}
			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_MATRIX_MUL_1x16F:
		{
			__m256 selector1 = _mm256_setzero_ps();

			dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;

			__m256 scl_fctr1 = _mm256_setzero_ps();

			// Even though different registers are used for the scalar in the
			// column and row major case, all those registers will contain the
			// same value.
			// For column major, if m==1, then it means n=1 and scale_factor_len=1.
			if ( post_ops_list_temp->scale_factor_len == 1 )
			{
				scl_fctr1 =
					_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
			}
			else
			{
				if ( ( *( char* )post_ops_list_temp->op_args2 == 'c' ) ||
				     ( *( char* )post_ops_list_temp->op_args2 == 'C' ) )
				{
					scl_fctr1 =
						_mm256_maskload_ps( ( float* )post_ops_list_temp->scale_factor +
						                    post_ops_attr.post_op_c_i + ( 0 * 16 ), store_mask );
				}
			}
			float* matptr = ( float* )post_ops_list_temp->op_args1;

			if( ldm == 1 )
			{
				selector1 = _mm256_maskload_ps(( matptr +
				                                 post_ops_attr.post_op_c_i ), store_mask );
				selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
				ymm8 = _mm256_mul_ps( selector1, ymm8 );
			}
			else
			{
				float ctemp[16];
				for( dim_t i = 0; i < mr0; i++ )
				{
					ctemp[i] = *( matptr +
					              ( ( post_ops_attr.post_op_c_i + i )
					                * ldm ) );
				}
				selector1 = _mm256_maskload_ps( ctemp, store_mask );
				selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
				ymm8 = _mm256_mul_ps( selector1, ymm8 );
			}

			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_SWISH_1x16F:
		{
			ymm7 =
				_mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
			__m256i ex_out;

			// c[0, 0-7]
			SWISH_F32_AVX2_DEF(ymm8, ymm7, ymm0, ymm1, ymm2, ymm3, ymm4, ex_out);

			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_TANH_1x16F:
		{
			__m256i ymm6;
			// c[0, 0-7]
			TANH_F32S_AVX2(ymm8, ymm0, ymm1, ymm2, ymm3, ymm4, ymm6)

			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}
POST_OPS_SIGMOID_1x16F:
		{
			__m256i ex_out;

			// c[0, 0-7]
			SIGMOID_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, ymm4, ex_out);

			POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
		}

POST_OPS_1x16F_DISABLE:
		{

			if( rs_c == 1 )
			{
				_mm256_maskstore_ps ( c_use, store_mask, ymm8 );
			}
			else
			{
				// store c from ymm8
				float ctemp[8];
				_mm256_storeu_ps( ctemp, ymm8 );
				for( dim_t i = 0; i < mr0; i++ )
				{
					c_use[i * rs_c] = ctemp[i];
				}
			}
		}
		post_ops_attr.post_op_c_i += MR;
	} // mr loop
}
#endif // BLIS_ADDON_LPGEMM
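Stripped of post-ops, masking, and the horizontal-add plumbing, the kernel
above computes c[i] = beta * c[i] + alpha * dot(A[i,:], b) for each of the m0
rows, and skips the beta term when beta == 0.0f so C is never read in that
case. A scalar model for checking the vectorized result (a sketch, not the
committed code):

	void lpgemv_n_one_ref( long m, long k,
	                       const float* a, long rs_a, long cs_a,
	                       const float* b, long rs_b,
	                       float* c, long rs_c,
	                       float alpha, float beta )
	{
		for ( long i = 0; i < m; ++i )
		{
			float acc = 0.0f;
			for ( long p = 0; p < k; ++p )
				acc += a[ i * rs_a + p * cs_a ] * b[ p * rs_b ];
			c[ i * rs_c ] = ( beta == 0.0f ) ? alpha * acc
			                                 : beta * c[ i * rs_c ] + alpha * acc;
		}
	}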