Implemented AVX2-based GEMV for the n=1 case.

- Added a new GEMV kernel with MR = 8 that is used for cases
  where n = 1 (a scalar reference sketch of this computation is
  given below).
- Modified the GEMM and GEMV frameworks to choose the right GEMV
  kernel based on compile-time and run-time architecture parameters.
  This had to be done since GEMV kernels are not stored in or
  retrieved from the cntx.
- Added a pack kernel that packs the A matrix from column-major to
  row-major using AVX2 instructions.
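
For reference, a minimal scalar sketch of the computation this path
targets (the function name and signature below are illustrative only,
not the LPGEMM API): with n = 1, B degenerates to a k-length vector and
C to an m-length vector, so each output is a dot product scaled by
alpha and beta.

    // Reference sketch only; the new kernel vectorizes the k loop with
    // AVX2 and horizontally reduces each accumulator. dim_t is the BLIS
    // integer type.
    void gemv_n1_ref( dim_t m, dim_t k,
                      const float* a, dim_t rs_a, dim_t cs_a,
                      const float* b, float alpha, float beta,
                      float* c, dim_t rs_c )
    {
        for ( dim_t i = 0; i < m; ++i )
        {
            float acc = 0.0f;
            for ( dim_t p = 0; p < k; ++p )
            {
                acc += a[ i * rs_a + p * cs_a ] * b[ p ];
            }
            float out = alpha * acc;
            if ( beta != 0.0f ) out += beta * c[ i * rs_c ];
            c[ i * rs_c ] = out;
        }
    }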

AMD-Internal: [SWLCSG-3519]
Change-Id: Ibf7a8121d0bde37660eac58a160c5b9c9ebd2b5c
Meghana Vankadari
2025-04-23 12:14:57 +05:30
parent a285cf4b27
commit 21aa63eca1
7 changed files with 1203 additions and 69 deletions


@@ -70,14 +70,12 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32)
// Extra space since packing does width in multiples of NR.
dim_t n_reorder;
#ifdef BLIS_KERNELS_ZEN4
if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
if( ( n == 1 ) ) //&& ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
{
//When n == 1, LPGEMV doesn't expect B to be reordered.
n_reorder = 1;
}
else
#endif
{
n_reorder = ( ( n + NR - 1 ) / NR ) * NR;
}
@@ -172,10 +170,10 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
dim_t n_threads = bli_rntm_num_threads( &rntm_g );
n_threads = ( n_threads > 0 ) ? n_threads : 1;
#ifdef BLIS_KERNELS_ZEN4
//When n == 1, the B matrix becomes a vector.
//Reordering is avoided so that LPGEMV can process it efficiently.
if( ( n == 1 ) && ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
if( ( n == 1 ) ) //&& ( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
{
if(rs_b == 1)
{
@@ -189,7 +187,6 @@ AOCL_GEMM_REORDER(float,f32f32f32of32)
}
return;
}
#endif
#ifdef BLIS_ENABLE_OPENMP
_Pragma( "omp parallel num_threads(n_threads)" )


@@ -62,7 +62,41 @@ typedef void (*lpgemm_rowvar_f32)
lpgemm_post_op_attr
);
#ifdef BLIS_KERNELS_ZEN4
typedef void (*lpgemv_n_one_ker_ft)
(
const dim_t,
const dim_t,
const float*,
const dim_t,
const dim_t,
const AOCL_MEMORY_TAG,
const float*,
const dim_t,
const dim_t,
const AOCL_MEMORY_TAG,
float*,
const dim_t,
const dim_t,
const float,
const float,
const dim_t,
const dim_t,
lpgemm_post_op*,
lpgemm_post_op_attr*
);
typedef void (*lpgemv_n_one_a_pack_ft)
(
float*,
const float*,
const dim_t,
const dim_t,
const dim_t,
const dim_t,
dim_t*,
dim_t*
);
LPGEMV(float, float, float, f32f32f32of32)
{
const float* a_use = (float*)a;
@@ -103,9 +137,32 @@ LPGEMV(float, float, float, f32f32f32of32)
if(n == 1)
{
float* pack_b_buffer_f32f32f32of32;
//TODO: AVX2 support need to be added
dim_t MR;
lpgemv_n_one_ker_ft ker_fp;
lpgemv_n_one_a_pack_ft packa_fp;
// Workaround to select the right kernel and block sizes based on the arch,
// since GEMV parameters are not available in the lpgemm context.
#ifdef BLIS_KERNELS_ZEN4
if( lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3 )
{
MR = 8;
ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
packa_fp = packa_mr8_f32f32f32of32_col_major;
}
else
{
MR = 16;
ker_fp = lpgemv_n_one_f32f32f32of32;
packa_fp = packa_mr16_f32f32f32of32_col_major;
}
#else
// Increased MR from 6 to 16 to make use of 32 ZMM registers
dim_t MR = 16;
MR = 8;
ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
packa_fp = packa_mr8_f32f32f32of32_col_major;
#endif
// Pack B matrix if rs_b > 1
if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
{
@@ -144,6 +201,7 @@ LPGEMV(float, float, float, f32f32f32of32)
c_use = c + ic * rs_c;
post_ops_attr.post_op_c_i = ic;
// To-Do: pack A case needs to be handled for AVX2 case.
if( mtag_a == PACK && cs_a != 1 )
{
mem_a_size_req = sizeof(float) * mc0 * k;
@@ -154,7 +212,7 @@ LPGEMV(float, float, float, f32f32f32of32)
);
pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a );
packa_mr16_f32f32f32of32_col_major
packa_fp
(
pack_a_buffer_f32f32f32of32,
a_use, rs_a, cs_a,
@@ -163,19 +221,23 @@ LPGEMV(float, float, float, f32f32f32of32)
);
a_use = pack_a_buffer_f32f32f32of32;
}
// Call lpgemv_n_one kernel
lpgemv_n_one_f32f32f32of32
ker_fp
(
mc0, k,
a_use, rs_a_use, cs_a_use, mtag_a,
b_use, rs_b_use, cs_b_use, mtag_b,
c_use, rs_c, cs_c,
alpha, beta,
MR, KC,
post_op_list,
&post_ops_attr
mc0, k,
a_use, rs_a_use, cs_a_use, mtag_a,
b_use, rs_b_use, cs_b_use, mtag_b,
c_use, rs_c, cs_c,
alpha, beta,
MR, KC,
post_op_list,
&post_ops_attr
);
if ( mtag_a == PACK )
{
// Release pack buffer for A.
bli_pba_release( rntm, &mem_a );
}
}
if ( ( mtag_a == PACK ) && ( bli_mem_is_alloc( &mem_a ) ) )
{
@@ -188,6 +250,8 @@ LPGEMV(float, float, float, f32f32f32of32)
}
else
{
// m = 1 case is not implemented yet for AVX2
#ifdef BLIS_KERNELS_ZEN4
// Compute the JC loop thread range for the current thread.
dim_t jc_start, jc_end;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
@@ -312,19 +376,21 @@ LPGEMV(float, float, float, f32f32f32of32)
{
bli_pba_release( rntm, &mem_b );
}
#endif // m == 1 case is not implemented for AVX2 yet.
}
}
#endif
LPGEMM_5LOOP(float, float, float, f32f32f32of32)
{
#ifdef BLIS_KERNELS_ZEN4
// Handle using LPGEMV when m or/and n equal to 1
// The avx512 check will be removed when avx2 kernels added in future
if ( ( ( m == 1 ) || ( n == 1 ) ) &&
( bli_cpuid_is_avx512_supported() == TRUE ) &&
( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
#ifdef BLIS_KERNELS_ZEN4
if ( ( ( (m == 1) && (lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) ) || ( n == 1 ) ) &&
( bli_cpuid_is_avx512_supported() == TRUE ) )
#else
// m=1 case is not implemented yet for AVX2
if ( ( ( n == 1 ) ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE ) )
#endif
{
lpgemv_rowvar_f32f32f32of32(m, n, k,
a, rs_a, cs_a, mtag_a,
@@ -339,7 +405,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
c_downscale);
return;
}
#endif
// Query the context for various blocksizes.
const dim_t NC = lcntx->blksz.NC;


@@ -62,7 +62,43 @@ typedef void (*lpgemm_rowvar_f32)
lpgemm_post_op_attr
);
#ifdef BLIS_KERNELS_ZEN4
typedef void (*lpgemv_n_one_ker_ft)
(
const dim_t,
const dim_t,
const float*,
const dim_t,
const dim_t,
const AOCL_MEMORY_TAG,
const float*,
const dim_t,
const dim_t,
const AOCL_MEMORY_TAG,
float*,
const dim_t,
const dim_t,
const float,
const float,
const dim_t,
const dim_t,
lpgemm_post_op*,
lpgemm_post_op_attr*
);
typedef void (*lpgemv_n_one_a_pack_ft)
(
float*,
const float*,
const dim_t,
const dim_t,
const dim_t,
const dim_t,
dim_t*,
dim_t*
);
LPGEMV_TINY(float, float, float, f32f32f32of32)
{
const float* a_use = ( float* )a;
@@ -84,9 +120,32 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
float* pack_b_buffer_f32f32f32of32 = NULL;
err_t err = BLIS_SUCCESS;
//TODO: AVX2 support need to be added
dim_t MR;
lpgemv_n_one_ker_ft ker_fp;
lpgemv_n_one_a_pack_ft packa_fp;
// Workaround to select the right kernel and block sizes based on the arch,
// since GEMV parameters are not available in the lpgemm context.
#ifdef BLIS_KERNELS_ZEN4
if( lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3 )
{
MR = 8;
ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
packa_fp = packa_mr8_f32f32f32of32_col_major;
}
else
{
MR = 16;
ker_fp = lpgemv_n_one_f32f32f32of32;
packa_fp = packa_mr16_f32f32f32of32_col_major;
}
#else
// Increased MR from 6 to 16 to make use of 32 ZMM registers
dim_t MR = 16;
MR = 8;
ker_fp = lpgemv_n_one_f32f32f32of32_avx2;
packa_fp = packa_mr8_f32f32f32of32_col_major;
#endif
// Pack B matrix if rs_b > 1
if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
@@ -111,7 +170,7 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
pack_a_buffer_f32f32f32of32 =
( float* )bli_malloc_user(mem_a_size_req, &err);
packa_mr16_f32f32f32of32_col_major
packa_fp
(
pack_a_buffer_f32f32f32of32,
a_use, rs_a, cs_a,
@@ -123,7 +182,7 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_j = 0;
lpgemv_n_one_f32f32f32of32
ker_fp
(
m, k,
a_use, rs_a_use, cs_a_use, mtag_a,
@@ -145,16 +204,19 @@ LPGEMV_TINY(float, float, float, f32f32f32of32)
}
}
}
#endif
LPGEMM_TINY(float,float,float,f32f32f32of32)
{
// Handle using LPGEMV when m or/and n equal to 1
#ifdef BLIS_KERNELS_ZEN4
// Handle using LPGEMV when m or/and n equal to 1
// The avx512 check will be removed when avx2 kernels added in future
if ( ( n == 1 ) &&
( bli_cpuid_is_avx512_supported() == TRUE ) &&
( lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3 ) )
if ( ( ( (m == 1) && (lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) ) || ( n == 1 ) ) &&
( bli_cpuid_is_avx512_supported() == TRUE ) )
#else
// m=1 case is not implemented yet for AVX2
if ( ( ( n == 1 ) ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE ) )
#endif
{
lpgemv_rowvar_tiny_f32f32f32of32(m, n, k,
a, rs_a, cs_a, mtag_a,
@@ -167,7 +229,7 @@ LPGEMM_TINY(float,float,float,f32f32f32of32)
c_downscale);
return;
}
#endif
const dim_t NR = lcntx->blksz.NR;
const dim_t MR = lcntx->blksz.MR;


@@ -46,6 +46,18 @@ void packa_mr16_f32f32f32of32_col_major
dim_t* cs_p
);
void packa_mr8_f32f32f32of32_col_major
(
float* pack_a_buffer,
const float* a,
const dim_t rs_a,
const dim_t cs_a,
const dim_t MC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
);
void packa_mr6_f32f32f32of32_avx512
(
float* pack_a_buf,


@@ -782,6 +782,7 @@ void lpgemv_n_one_ ## LP_SFX \
) \
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32);
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32_avx2);
LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);


@@ -0,0 +1,467 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include <string.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
#define UNPACKLO_PS8 \
b_reg[0] = _mm256_unpacklo_ps( a_reg[0], a_reg[1] ); \
b_reg[1] = _mm256_unpacklo_ps( a_reg[2], a_reg[3] ); \
b_reg[2] = _mm256_unpacklo_ps( a_reg[4], a_reg[5] ); \
b_reg[3] = _mm256_unpacklo_ps( a_reg[6], a_reg[7] );
#define UNPACKHI_PS8 \
b_reg[4] = _mm256_unpackhi_ps( a_reg[0], a_reg[1] ); \
b_reg[5] = _mm256_unpackhi_ps( a_reg[2], a_reg[3] ); \
b_reg[6] = _mm256_unpackhi_ps( a_reg[4], a_reg[5] ); \
b_reg[7] = _mm256_unpackhi_ps( a_reg[6], a_reg[7] );
#define UNPACKLO_PD8 \
a_reg[0] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[0], (__m256d)b_reg[1] ); \
a_reg[1] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[2], (__m256d)b_reg[3] ); \
a_reg[2] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[4], (__m256d)b_reg[5] ); \
a_reg[3] = (__m256)_mm256_unpacklo_pd( (__m256d)b_reg[6], (__m256d)b_reg[7] );
#define UNPACKHI_PD8 \
a_reg[4] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[0], (__m256d)b_reg[1] ); \
a_reg[5] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[2], (__m256d)b_reg[3] ); \
a_reg[6] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[4], (__m256d)b_reg[5] ); \
a_reg[7] = (__m256)_mm256_unpackhi_pd( (__m256d)b_reg[6], (__m256d)b_reg[7] );
#define PERMUTE2F128_PS8 \
b_reg[0] = _mm256_permute2f128_ps( a_reg[0], a_reg[1], 0x20 ); \
b_reg[1] = _mm256_permute2f128_ps( a_reg[4], a_reg[5], 0x20 ); \
b_reg[2] = _mm256_permute2f128_ps( a_reg[2], a_reg[3], 0x20 ); \
b_reg[3] = _mm256_permute2f128_ps( a_reg[6], a_reg[7], 0x20 ); \
b_reg[4] = _mm256_permute2f128_ps( a_reg[0], a_reg[1], 0x31 ); \
b_reg[5] = _mm256_permute2f128_ps( a_reg[4], a_reg[5], 0x31 ); \
b_reg[6] = _mm256_permute2f128_ps( a_reg[2], a_reg[3], 0x31 ); \
b_reg[7] = _mm256_permute2f128_ps( a_reg[6], a_reg[7], 0x31 );
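// Taken together, the three macro groups above perform an 8x8 float
// transpose: UNPACKLO_PS8/UNPACKHI_PS8 interleave adjacent columns at
// 32-bit granularity, UNPACKLO_PD8/UNPACKHI_PD8 interleave those results
// at 64-bit granularity, and PERMUTE2F128_PS8 exchanges 128-bit lanes so
// that b_reg[i] ends up holding row i of the transposed 8x8 block that
// was loaded column-wise into a_reg[0..7].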
void packa_mr8_f32f32f32of32_col_major
(
float* pack_a_buffer,
const float* a,
const dim_t rs_a,
const dim_t cs_a,
const dim_t MC,
const dim_t KC,
dim_t* rs_p,
dim_t* cs_p
)
{
dim_t MR = 8;
dim_t ic, kr;
__m256 a_reg[8], b_reg[8];
__m256i k_masks[3] = {
_mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, -1), // 1 element
_mm256_set_epi32( 0, 0, 0, 0, 0, 0, -1, -1), // 2 elements
_mm256_set_epi32( 0, 0, 0, 0, -1, -1, -1, -1), // 4 elements
};
__m256i load_mask, store_mask;
// These registers are set with zeroes to avoid compiler warnings
// To-Do: To be removed when the pack code is optimized for fringe cases.
a_reg[0] = _mm256_setzero_ps();
a_reg[1] = _mm256_setzero_ps();
a_reg[2] = _mm256_setzero_ps();
a_reg[3] = _mm256_setzero_ps();
a_reg[4] = _mm256_setzero_ps();
a_reg[5] = _mm256_setzero_ps();
a_reg[6] = _mm256_setzero_ps();
a_reg[7] = _mm256_setzero_ps();
for( ic = 0; ( ic + MR -1 ) < MC; ic += MR )
{
for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
{
// Transposing the 8x8 block of data
a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
a_reg[1] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) );
a_reg[2] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) );
a_reg[3] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) );
a_reg[4] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ) );
a_reg[5] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) );
a_reg[6] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) );
a_reg[7] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), b_reg[1] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), b_reg[2] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), b_reg[3] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), b_reg[4] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), b_reg[5] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), b_reg[6] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), b_reg[7] );
}
store_mask = k_masks[2]; // mask to store 4 elements
for( ; ( kr + 3 ) < KC; kr += 4 )
{
// Transposing the 8x4 block of data
a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
a_reg[1] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) );
a_reg[2] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) );
a_reg[3] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), store_mask, b_reg[4] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), store_mask, b_reg[5] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), store_mask, b_reg[6] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), store_mask, b_reg[7] );
}
store_mask = k_masks[1]; // mask to store 2 elements
for( ; ( kr + 1 ) < KC; kr += 2 )
{
// Transposing the 8x2 block of data
a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
a_reg[1] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), store_mask, b_reg[4] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), store_mask, b_reg[5] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), store_mask, b_reg[6] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), store_mask, b_reg[7] );
}
store_mask = k_masks[0]; // mask to store 1 element
for( ; ( kr + 0 ) < KC; kr += 1 )
{
// Transposing the 8x1 block of data
a_reg[0] = _mm256_loadu_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), store_mask, b_reg[4] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), store_mask, b_reg[5] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), store_mask, b_reg[6] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), store_mask, b_reg[7] );
}
}
for( ; ( ic + 3 ) < MC; ic += 4 )
{
load_mask = k_masks[2]; // mask to load 4 elements
for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
{
// Transposing the 4x8 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
a_reg[4] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ), load_mask );
a_reg[5] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ), load_mask );
a_reg[6] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ), load_mask );
a_reg[7] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), b_reg[1] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), b_reg[2] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), b_reg[3] );
}
store_mask = k_masks[2]; // mask to store 4 elements
for( ; ( kr + 3 ) < KC; kr += 4 )
{
// Transposing the 4x4 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
}
store_mask = k_masks[1]; // mask to store 2 elements
for( ; ( kr + 1 ) < KC; kr += 2 )
{
// transposing the 4x2 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
}
store_mask = k_masks[0]; // mask to store 1 element
for( ; ( kr + 0 ) < KC; kr += 1 )
{
// transposing the 4x1 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), store_mask, b_reg[2] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), store_mask, b_reg[3] );
}
}
for( ; ( ic + 1 ) < MC; ic += 2 )
{
load_mask = k_masks[1]; // mask to load 2 elements
for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
{
// Transposing the 2x8 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
a_reg[4] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ), load_mask );
a_reg[5] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ), load_mask );
a_reg[6] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ), load_mask );
a_reg[7] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), b_reg[1] );
}
store_mask = k_masks[2]; // mask to store 4 elements
for( ; ( kr + 3 ) < KC; kr += 4 )
{
// Transposing the 2x4 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
}
store_mask = k_masks[1]; // mask to store 2 elements
for( ; ( kr + 1 ) < KC; kr += 2 )
{
// Transposing the 2x2 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
}
store_mask = k_masks[0]; // mask to store 1 element
for( ; ( kr + 0 ) < KC; kr += 1 )
{
// Transposing the 2x1 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), store_mask, b_reg[1] );
}
}
for( ; ( ic + 0 ) < MC; ic += 1 )
{
load_mask = k_masks[0]; // mask to load 1 element
for( kr = 0; ( kr + 7 ) < KC; kr += 8 )
{
// Transposing the 1x8 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
a_reg[4] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ), load_mask );
a_reg[5] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ), load_mask );
a_reg[6] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ), load_mask );
a_reg[7] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_storeu_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), b_reg[0] );
}
store_mask = k_masks[2]; // mask to store 4 elements
for( ; ( kr + 3 ) < KC; kr += 4 )
{
// Transposing the 1x4 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
a_reg[2] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ), load_mask );
a_reg[3] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
}
store_mask = k_masks[1]; // mask to store 2 elements
for( ; ( kr + 1 ) < KC; kr += 2 )
{
// Transposing the 1x2 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
a_reg[1] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
}
store_mask = k_masks[0]; // mask to store 1 element
for( ; ( kr + 0 ) < KC; kr += 1 )
{
// Transposing the 1x1 block of data
a_reg[0] = _mm256_maskload_ps( ( float const* )( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ), load_mask );
UNPACKLO_PS8
UNPACKHI_PS8
UNPACKLO_PD8
UNPACKHI_PD8
PERMUTE2F128_PS8
_mm256_maskstore_ps( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), store_mask, b_reg[0] );
}
}
// Set the row and column strides of the packed matrix.
*rs_p = KC;
*cs_p = 1;
}
#endif // BLIS_ADDON_LPGEMM


@@ -43,33 +43,563 @@
// to produce C output of MRx1. The vectorization is done in the k loop,
// and a horizontal reduction is done to produce one output from each
// accumulator register.
void lpgemv_n_one_kernel_f32_avx2_ker_ft
(
const dim_t m0,
const dim_t k,
const float *a,
const dim_t rs_a,
const dim_t cs_a,
const AOCL_MEMORY_TAG mtag_a,
const float *b,
const dim_t rs_b,
const dim_t cs_b,
const AOCL_MEMORY_TAG mtag_b,
float *c,
const dim_t rs_c,
const dim_t cs_c,
const float alpha,
const float beta,
const dim_t MR,
const dim_t KC,
lpgemm_post_op *post_op_list,
lpgemm_post_op_attr *post_op_attr
)
{
//TODO: Created dummy function as place holder to get
//rid of linking issues in other zen configurations.
//AVX2 varient wil be implemented in next commits.
//Code will take LPGEMM path for LPGEMV in AVX2 env.
}
#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \
ymm0 = _mm256_loadu_ps( paddr ); \
ymm1 = _mm256_loadu_ps( paddr + stride ); \
ymm2 = _mm256_loadu_ps( paddr + 2 * stride ); \
ymm3 = _mm256_loadu_ps( paddr + 3 * stride );
#define LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, mask, paddr, stride ) \
ymm0 = _mm256_maskload_ps( paddr, mask ); \
ymm1 = _mm256_maskload_ps( paddr + stride, mask ); \
ymm2 = _mm256_maskload_ps( paddr + 2 * stride, mask ); \
ymm3 = _mm256_maskload_ps( paddr + 3 * stride, mask );
#define LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 ) \
ymm8 = _mm256_fmadd_ps( ymm0, ymm7, ymm8 ); \
ymm9 = _mm256_fmadd_ps( ymm1, ymm7, ymm9 ); \
ymm10 = _mm256_fmadd_ps( ymm2, ymm7, ymm10 ); \
ymm11 = _mm256_fmadd_ps( ymm3, ymm7, ymm11 );
#define LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, ymm0, ymm1, ymm2, ymm3, xmm0 ) \
ymm0 = _mm256_hadd_ps( ymm8, ymm9 ); \
ymm1 = _mm256_hadd_ps( ymm10, ymm11 ); \
ymm0 = _mm256_hadd_ps( ymm0, ymm1 ); \
xmm0 = _mm_add_ps(_mm256_extractf128_ps(ymm0, 0), _mm256_extractf128_ps(ymm0,1));
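// LPGEMV_YMM2XMM reduces four 8-wide accumulators to four scalars: the
// two rounds of hadd leave per-register partial sums in each 128-bit
// lane, and the final extract/add folds the lanes so that
// xmm0 = { sum(ymm8), sum(ymm9), sum(ymm10), sum(ymm11) }, i.e. one
// dot-product result per row of the current block.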
#define RELU_SCALE_OP_F32_AVX2( acc, scale_reg, zreg, scratch_reg ) \
scratch_reg = _mm256_min_ps( acc, zreg ); \
acc = _mm256_max_ps( acc, zreg ); \
scratch_reg = _mm256_mul_ps( scratch_reg, scale_reg ); \
acc = _mm256_or_ps( acc, scratch_reg );
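// RELU_SCALE_OP_F32_AVX2 implements a leaky-ReLU style operation:
// min() with the zero register isolates the negative elements, max()
// keeps the positive ones, the negative part is scaled by scale_reg,
// and the bitwise OR merges the two halves (safe because every element
// is zero in at least one of the two registers).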
LPGEMV_N_EQ1_KERN( float, float, float, f32f32f32of32_avx2 )
{
static void *post_ops_labels[] =
{
&&POST_OPS_1x16F_DISABLE,
&&POST_OPS_BIAS_1x16F,
&&POST_OPS_RELU_1x16F,
&&POST_OPS_RELU_SCALE_1x16F,
&&POST_OPS_GELU_TANH_1x16F,
&&POST_OPS_GELU_ERF_1x16F,
&&POST_OPS_CLIP_1x16F,
&&POST_OPS_DOWNSCALE_1x16F,
&&POST_OPS_MATRIX_ADD_1x16F,
&&POST_OPS_SWISH_1x16F,
&&POST_OPS_MATRIX_MUL_1x16F,
&&POST_OPS_TANH_1x16F,
&&POST_OPS_SIGMOID_1x16F
};
// Strides are updated based on matrix packing/reordering.
const float *a_use = NULL;
const float *b_use = NULL;
float *c_use = NULL;
lpgemm_post_op_attr post_ops_attr = *(post_op_attr);
__m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm7;
__m256 ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;
__m128 xmm0, xmm1;
__m256i masks[9] = {
_mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, 0), // 0 elements
_mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, -1), // 1 element
_mm256_set_epi32( 0, 0, 0, 0, 0, 0, -1, -1), // 2 elements
_mm256_set_epi32( 0, 0, 0, 0, 0, -1, -1, -1), // 3 elements
_mm256_set_epi32( 0, 0, 0, 0, -1, -1, -1, -1), // 4 elements
_mm256_set_epi32( 0, 0, 0, -1, -1, -1, -1, -1), // 5 elements
_mm256_set_epi32( 0, 0, -1, -1, -1, -1, -1, -1), // 6 elements
_mm256_set_epi32( 0, -1, -1, -1, -1, -1, -1, -1), // 7 elements
_mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1) // 8 elements
};
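// masks[e] enables the low e lanes of a ymm register for maskload/
// maskstore. It is used both for the m fringe ( mr0 < MR outputs ) and
// for the k remainder ( k % 8 columns ) so that no out-of-bounds
// elements are loaded or stored.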
// MR comes from the framework; it is chosen based on the underlying hardware configuration.
for (dim_t mr = 0; mr < m0; mr += MR)
{
dim_t mr0 = bli_min( m0 - mr, MR );
dim_t k_iter = k / 8;
dim_t k_rem = k % 8;
const __m256i store_mask = masks[mr0];
const __m256i k_rem_mask = masks[k_rem];
/* zero the accumulator registers */
ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 );
ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 );
//update pointers
a_use = a + mr * rs_a;
b_use = b;
c_use = c + mr * rs_c;
if( mr0 == MR )
{
for( dim_t k = 0; k < k_iter; k++ )
{
ymm7 = _mm256_loadu_ps( b_use );
b_use += 8; // move b pointer to next 8 elements
// Load 4x8 from rows 0-3 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a );
a_use += 4 * rs_a; // move the A pointer down to rows 4-7
// Perform the dot product
LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );
// Load 4x8 from rows 4-7 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a );
a_use -= 4 * rs_a; // move the A pointer back up to rows 0-3
// Perform the dot product
LPGEMV_N_KERNEL_4_FMA( ymm12, ymm13, ymm14, ymm15, ymm7, ymm0, ymm1, ymm2, ymm3 );
a_use += 8;
} // k-loop
if( k_rem )
{
ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );
// Load 4x8 from rows 0-3 of A
LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, k_rem_mask, a_use, rs_a );
a_use += 4 * rs_a; // move the A pointer down to rows 4-7
// Perform the dot product
LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );
// Load 4x8 from rows 4-7 of A
LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, k_rem_mask, a_use, rs_a );
a_use -= 4 * rs_a; // move the A pointer back up to rows 0-3
// Perform the dot product
LPGEMV_N_KERNEL_4_FMA( ymm12, ymm13, ymm14, ymm15, ymm7, ymm0, ymm1, ymm2, ymm3 );
}
// Reduce the accumulators horizontally to get one output per row
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, ymm0, ymm1, ymm2, ymm3, xmm0 );
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, ymm4, ymm1, ymm2, ymm3, xmm1 );
// compose outputs into one ymm to perform post-ops.
ymm8 = _mm256_insertf128_ps( ymm8, xmm0, 0 );
ymm8 = _mm256_insertf128_ps( ymm8, xmm1, 1 );
}
else
{
//Handle fringe cases when mr0 < MR
const float *a_use_fringe = a_use;
dim_t mr0_use = mr0;
dim_t regidx = 0;
// Dot product for mfringe 4
if (mr0_use >= 4)
{
for( dim_t k = 0; k < k_iter; k++ )
{
ymm7 = _mm256_loadu_ps( b_use );
b_use += 8; // move b pointer to next 8 elements
// Load 4x8 from rows 0-3 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a );
a_use += 8; // move the A pointer to the next 8 columns
// Perform the dot product
LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );
} // k-loop
if( k_rem )
{
ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );
// Load 4x8 from rows 0-3 of A
LPGEMV_N_KERNEL_4_MASKLOADS( ymm0, ymm1, ymm2, ymm3, k_rem_mask, a_use, rs_a );
// Perform the dot product
LPGEMV_N_KERNEL_4_FMA( ymm8, ymm9, ymm10, ymm11, ymm7, ymm0, ymm1, ymm2, ymm3 );
}
//update pointers
mr0_use -= 4;
a_use = a_use_fringe + 4 * rs_a;
a_use_fringe = a_use;
b_use = b;
// Horizontally add the 4 ymm accumulators to get 4 outputs in one xmm register
LPGEMV_YMM2XMM(ymm8, ymm9, ymm10, ymm11, ymm0, ymm1, ymm2, ymm3, xmm0)
// compose outputs into one ymm to perform post-ops.
ymm8 = _mm256_insertf128_ps( ymm8, xmm0, 0 );
regidx = 1;
}
// Dot product for <= 3
if (mr0_use)
{
if( mr0_use >= 2 )
{
for (dim_t k = 0; k < k_iter; k++)
{
ymm7 = _mm256_loadu_ps( b_use );
b_use += 8; // move b pointer to next 8 elements
// Load 2x8 from rows 0-1 of A
ymm0 = _mm256_loadu_ps( a_use );
ymm1 = _mm256_loadu_ps( a_use + rs_a );
a_use += 8; // move the A pointer to the next 8 columns
ymm12 = _mm256_fmadd_ps( ymm0, ymm7, ymm12 );
ymm13 = _mm256_fmadd_ps( ymm1, ymm7, ymm13 );
} // k-loop
if( k_rem )
{
ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );
// Load 2x8 from rows 0-1 of A
ymm0 = _mm256_maskload_ps( a_use, k_rem_mask );
ymm1 = _mm256_maskload_ps( a_use + rs_a, k_rem_mask );
ymm12 = _mm256_fmadd_ps( ymm0, ymm7, ymm12 );
ymm13 = _mm256_fmadd_ps( ymm1, ymm7, ymm13 );
}
//update pointers
mr0_use -= 2;
a_use = a_use_fringe + 2 * rs_a;
a_use_fringe = a_use;
b_use = b;
}
if( mr0_use == 1 )
{
for (dim_t k = 0; k < k_iter; k++)
{
ymm7 = _mm256_loadu_ps( b_use );
b_use += 8; // move b pointer to next 8 elements
// Load 1x8 from row 0 of A
ymm0 = _mm256_loadu_ps( a_use );
a_use += 8; // move the A pointer to the next 8 columns
ymm14 = _mm256_fmadd_ps( ymm0, ymm7, ymm14 );
} // k-loop
if( k_rem )
{
ymm7 = _mm256_maskload_ps( b_use, k_rem_mask );
// Load 1x8 from row 0 of A
ymm0 = _mm256_maskload_ps( a_use, k_rem_mask );
ymm14 = _mm256_fmadd_ps( ymm0, ymm7, ymm14 );
}
// When only the 1-row fringe ran (no 2-row fringe), move the result so outputs are stored in order
if (!(mr0 & 0x2)) ymm12 = ymm14;
}
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, ymm0, ymm1, ymm2, ymm3, xmm1 );
if (regidx == 0) ymm8 = _mm256_insertf128_ps(ymm8, xmm1, 0);
else ymm8 = _mm256_insertf128_ps(ymm8, xmm1, 1);
}
}
// scale accumulated output with alpha
ymm0 = _mm256_set1_ps( alpha );
ymm8 = _mm256_mul_ps( ymm8, ymm0 );
if( beta != 0.0f )
{
const float *_cbuf = c_use;
//C = beta*C + alpha*A*B
ymm3 = _mm256_set1_ps(beta);
if( rs_c == 1 )
{
ymm0 = _mm256_maskload_ps( _cbuf, store_mask );
}
else
{
// load c into ymm0
float ctemp[8] = { 0 };
for( dim_t i = 0; i < 8; i++ )
{
ctemp[i] = _cbuf[i * rs_c];
}
ymm0 = _mm256_loadu_ps( ctemp );
}
// scale c with beta
ymm8 = _mm256_fmadd_ps( ymm0, ymm3, ymm8 );
}
// post-ops
post_ops_attr.is_last_k = TRUE;
lpgemm_post_op *post_ops_list_temp = post_op;
POST_OP_LABEL_LASTK_SAFE_JUMP
POST_OPS_BIAS_1x16F:
{
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
ymm0 = _mm256_set1_ps(*( ( float * )post_ops_list_temp->op_args1 ) );
}
else
{
// If the original output was column-major, then by the time
// kernel sees it, the matrix would be accessed as if it were
// transposed. Due to this the bias array will be accessed by
// the ic index, and each bias element corresponds to an
// entire row of the transposed output array, instead of an
// entire column.
ymm0 = _mm256_maskload_ps( ( float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i , store_mask );
}
ymm8 = _mm256_add_ps(ymm0, ymm8);
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU_1x16F:
{
ymm0 = _mm256_setzero_ps();
ymm8 = _mm256_max_ps( ymm8, ymm0 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU_SCALE_1x16F:
{
ymm0 = _mm256_set1_ps( *( float* )post_ops_list_temp->op_args2 );
ymm1 = _mm256_setzero_ps();
RELU_SCALE_OP_F32_AVX2( ymm8, ymm0, ymm1, ymm2 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_TANH_1x16F:
{
__m256 dn, x_tanh;
__m256i q;
// c[0-7, 0]
GELU_TANH_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, dn, x_tanh, q)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_ERF_1x16F:
{
// c[0-7, 0]
GELU_ERF_F32S_AVX2(ymm8, ymm0, ymm1, ymm2)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_CLIP_1x16F:
{
ymm0 = _mm256_set1_ps(*(float *)post_ops_list_temp->op_args2);
ymm1 = _mm256_set1_ps(*(float *)post_ops_list_temp->op_args3);
// c[0-7, 0]
CLIP_F32S_AVX2(ymm8, ymm0, ymm1)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_DOWNSCALE_1x16F:
{
__m256 zero_point0 = _mm256_setzero_ps();
__m256 selector1 = _mm256_setzero_ps();
// Need to account for row- vs column-major swaps. For scalar
// scale and zero point values there are no implications.
// Even though different registers are used for scalar in column
// and row major downscale path, all those registers will contain
// the same value.
if ( post_ops_list_temp->scale_factor_len == 1 )
{
selector1 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
}
if ( *( ( dim_t* )post_ops_list_temp->op_args3 ) == 1 )
{
zero_point0 = _mm256_set1_ps( *(float *)post_ops_list_temp->op_args1 );
}
if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'R' ) )
{
// Scale/zp len cannot be > 1, since the original n = 1.
F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0);
}
else
{
// If the original output was column-major, then by the time
// kernel sees it, the matrix would be accessed as if it were
// transposed. Due to this the scale as well as zp array will
// be accessed by the ic index, and each scale/zp element
// corresponds to an entire row of the transposed output array,
// instead of an entire column.
if( post_ops_list_temp->scale_factor_len > 1 )
{
selector1 = _mm256_maskload_ps( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_i, store_mask );
}
if( *( dim_t*)post_ops_list_temp->op_args3 > 1 )
{
zero_point0 = _mm256_maskload_ps( ( float * )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i, store_mask );
}
F32_SCL_MULRND_AVX2(ymm8, selector1, zero_point0);
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_MATRIX_ADD_1x16F:
{
__m256 selector1 = _mm256_setzero_ps();
dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;
__m256 scl_fctr1 = _mm256_setzero_ps();
// Even though different registers are used for scalar in column and
// row major case, all those registers will contain the same value.
// For column major, if m==1, then it means n=1 and scale_factor_len=1.
if ( post_ops_list_temp->scale_factor_len == 1 )
{
scl_fctr1 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
}
else
{
if ( ( *( char* )post_ops_list_temp->op_args2 == 'c' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'C' ) )
{
scl_fctr1 =
_mm256_maskload_ps( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_i + ( 0 * 16 ), store_mask );
}
}
float* matptr = ( float* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
selector1 = _mm256_maskload_ps(( matptr +
post_ops_attr.post_op_c_i ), store_mask );
selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
ymm8 = _mm256_add_ps( selector1, ymm8 );
}
else
{
float ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_maskload_ps( ctemp, store_mask );
selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
ymm8 = _mm256_add_ps( selector1, ymm8 );
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_MATRIX_MUL_1x16F:
{
__m256 selector1 = _mm256_setzero_ps();
dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;
__m256 scl_fctr1 = _mm256_setzero_ps();
// Even though different registers are used for scalar in column and
// row major case, all those registers will contain the same value.
// For column major, if m==1, then it means n=1 and scale_factor_len=1.
if ( post_ops_list_temp->scale_factor_len == 1 )
{
scl_fctr1 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
}
else
{
if ( ( *( char* )post_ops_list_temp->op_args2 == 'c' ) ||
( *( char* )post_ops_list_temp->op_args2 == 'C' ) )
{
scl_fctr1 =
_mm256_maskload_ps( ( float* )post_ops_list_temp->scale_factor +
post_ops_attr.post_op_c_i + ( 0 * 16 ), store_mask );
}
}
float* matptr = ( float* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
selector1 = _mm256_maskload_ps(( matptr +
post_ops_attr.post_op_c_i ), store_mask );
selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
ymm8 = _mm256_mul_ps( selector1, ymm8 );
}
else
{
float ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_maskload_ps( ctemp, store_mask );
selector1 = _mm256_mul_ps( selector1, scl_fctr1 );
ymm8 = _mm256_mul_ps( selector1, ymm8 );
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SWISH_1x16F:
{
ymm7 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) );
__m256i ex_out;
// c[0-7, 0]
SWISH_F32_AVX2_DEF(ymm8, ymm7, ymm0, ymm1, ymm2, ymm3, ymm4, ex_out);
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_TANH_1x16F:
{
__m256i ymm6;
// c[0-7, 0]
TANH_F32S_AVX2(ymm8, ymm0, ymm1, ymm2, ymm3, ymm4, ymm6)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SIGMOID_1x16F:
{
__m256i ex_out;
// c[0-7, 0]
SIGMOID_F32_AVX2_DEF(ymm8, ymm0, ymm1, ymm2, ymm3, ymm4, ex_out);
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_1x16F_DISABLE:
{
if( rs_c == 1 )
{
_mm256_maskstore_ps ( c_use, store_mask, ymm8 );
}
else
{
// store c from ymm0
float ctemp[8];
_mm256_storeu_ps( ctemp, ymm8 );
for( dim_t i = 0; i < mr0; i++ )
{
c_use[i * rs_c] = ctemp[i];
}
}
}
post_ops_attr.post_op_c_i += MR;
} // mr loop
}
#endif // BLIS_ADDON_LPGEMM