mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Implemented GEMV kernel for m=1 case. (#5)
* Implemented GEMV kernel for m=1 case. Description: - Added a new GEMV kernel for AVX2 where m=1. - Added a new GEMV kernel for AVX512 with ymm registers where m=1.
This commit is contained in:
@@ -62,6 +62,31 @@ typedef void (*lpgemm_rowvar_f32)
|
||||
lpgemm_post_op_attr
|
||||
);
|
||||
|
||||
typedef void (*lpgemv_m_one_ker_ft)
|
||||
(
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const float*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const AOCL_MEMORY_TAG,
|
||||
const float*,
|
||||
dim_t,
|
||||
const dim_t,
|
||||
const AOCL_MEMORY_TAG,
|
||||
float*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const float,
|
||||
const float,
|
||||
dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
lpgemm_post_op*,
|
||||
lpgemm_post_op_attr*
|
||||
);
|
||||
|
||||
typedef void (*lpgemv_n_one_ker_ft)
|
||||
(
|
||||
const dim_t,
|
||||
@@ -85,7 +110,7 @@ typedef void (*lpgemv_n_one_ker_ft)
|
||||
lpgemm_post_op_attr*
|
||||
);
|
||||
|
||||
typedef void (*lpgemv_n_one_a_pack_ft)
|
||||
typedef void (*lpgemv_a_pack_ft)
|
||||
(
|
||||
float*,
|
||||
const float*,
|
||||
@@ -140,7 +165,7 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
|
||||
dim_t MR;
|
||||
lpgemv_n_one_ker_ft ker_fp;
|
||||
lpgemv_n_one_a_pack_ft packa_fp;
|
||||
lpgemv_a_pack_ft packa_fp;
|
||||
|
||||
// Workaround to select right kernel and blocksizes based on arch
|
||||
// since GEMV parameters are not available in lpgemm context.
|
||||
@@ -232,12 +257,6 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
);
|
||||
|
||||
if ( mtag_a == PACK )
|
||||
{
|
||||
// Release pack buffer for A.
|
||||
bli_pba_release( rntm, &mem_a );
|
||||
}
|
||||
}
|
||||
if ( ( mtag_a == PACK ) && ( bli_mem_is_alloc( &mem_a ) ) )
|
||||
{
|
||||
@@ -250,8 +269,24 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
}
|
||||
else
|
||||
{
|
||||
// m = 1 case is not implemented yet for AVX2
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
lpgemv_m_one_ker_ft ker_fp;
|
||||
lpgemv_a_pack_ft packa_fp;
|
||||
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
if( lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3 )
|
||||
{
|
||||
ker_fp = lpgemv_m_one_f32f32f32of32_avx512_256;
|
||||
packa_fp = packa_mr8_f32f32f32of32_col_major;
|
||||
}
|
||||
else
|
||||
{
|
||||
ker_fp = lpgemv_m_one_f32f32f32of32;
|
||||
packa_fp = packa_mr16_f32f32f32of32_col_major;
|
||||
}
|
||||
#else
|
||||
ker_fp = lpgemv_m_one_f32f32f32of32_avx2;
|
||||
packa_fp = packa_mr8_f32f32f32of32_col_major;
|
||||
#endif
|
||||
// Compute the JC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
||||
@@ -272,7 +307,7 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
pack_a_buffer_f32f32f32of32 =
|
||||
( float* ) bli_mem_buffer( &mem_a );
|
||||
|
||||
packa_mr16_f32f32f32of32_col_major
|
||||
packa_fp
|
||||
(
|
||||
pack_a_buffer_f32f32f32of32,
|
||||
a_use, rs_a, cs_a,
|
||||
@@ -349,9 +384,8 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
|
||||
//update post-op pointer
|
||||
post_ops_attr.post_op_c_j = jc;
|
||||
|
||||
// Call kernel
|
||||
lpgemv_m_one_f32f32f32of32
|
||||
ker_fp
|
||||
(
|
||||
nc0, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
@@ -376,7 +410,6 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
{
|
||||
bli_pba_release( rntm, &mem_b );
|
||||
}
|
||||
#endif // m == 1 case is not implemented for AVX2 yet.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -384,13 +417,8 @@ LPGEMV(float, float, float, f32f32f32of32)
|
||||
LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
{
|
||||
// Handle using LPGEMV when m or/and n equal to 1
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
if ( ( ( (m == 1) && (lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) ) || ( n == 1 ) ) &&
|
||||
( bli_cpuid_is_avx512_supported() == TRUE ) )
|
||||
#else
|
||||
// m=1 case is not implemented yet for AVX2
|
||||
if ( ( ( n == 1 ) ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE ) )
|
||||
#endif
|
||||
if ( ( (m == 1) || ( n == 1 ) ) &&
|
||||
( ( bli_cpuid_is_avx512_supported() == TRUE ) || ( bli_cpuid_is_avx2fma3_supported() == TRUE ) ) )
|
||||
{
|
||||
lpgemv_rowvar_f32f32f32of32(m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
@@ -405,7 +433,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
|
||||
c_downscale);
|
||||
return;
|
||||
}
|
||||
|
||||
// Query the context for various blocksizes.
|
||||
const dim_t NC = lcntx->blksz.NC;
|
||||
const dim_t KC = lcntx->blksz.KC;
|
||||
|
||||
@@ -753,6 +753,8 @@ void lpgemv_m_one_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32);
|
||||
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32_avx2);
|
||||
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32_avx512_256);
|
||||
LPGEMV_M_EQ1_KERN(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
LPGEMV_M_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
|
||||
LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
|
||||
|
||||
Reference in New Issue
Block a user