Implemented GEMV kernel for m=1 case. (#5)

* Implemented GEMV kernel for m=1 case.

Description:

- Added a new GEMV kernel for AVX2 where m=1.
- Added a new m=1 GEMV kernel for AVX512 that uses only 256-bit (ymm) registers.
This commit is contained in:
Negi, Deepak
2025-05-13 16:33:04 +05:30
committed by GitHub
parent cd83fc38b5
commit 121d81df16
5 changed files with 2324 additions and 48 deletions

View File

@@ -62,6 +62,31 @@ typedef void (*lpgemm_rowvar_f32)
lpgemm_post_op_attr
);
// Function-pointer type for the f32 LPGEMV kernels that handle the m == 1
// case. It lets the caller pick a kernel variant at run time and invoke it
// through a single call site.
//
// The parameters are unnamed in this declaration. The call made through a
// pointer of this type passes (nc0, k, a_use, rs_a_use, cs_a_use, mtag_a)
// as its first six arguments; the roles of the remaining parameters are
// not visible from this declaration.
// NOTE(review): presumably the second const float* / AOCL_MEMORY_TAG group
// describes B, the float* group describes C, and the two float values are
// alpha and beta - confirm against the kernel definitions.
typedef void (*lpgemv_m_one_ker_ft)
(
const dim_t,
const dim_t,
const float*,
const dim_t,
const dim_t,
const AOCL_MEMORY_TAG,
const float*,
dim_t,
const dim_t,
const AOCL_MEMORY_TAG,
float*,
const dim_t,
const dim_t,
const float,
const float,
dim_t,
const dim_t,
const dim_t,
const dim_t,
lpgemm_post_op*,
lpgemm_post_op_attr*
);
typedef void (*lpgemv_n_one_ker_ft)
(
const dim_t,
@@ -85,7 +110,7 @@ typedef void (*lpgemv_n_one_ker_ft)
lpgemm_post_op_attr*
);
typedef void (*lpgemv_n_one_a_pack_ft)
typedef void (*lpgemv_a_pack_ft)
(
float*,
const float*,
@@ -140,7 +165,7 @@ LPGEMV(float, float, float, f32f32f32of32)
dim_t MR;
lpgemv_n_one_ker_ft ker_fp;
lpgemv_n_one_a_pack_ft packa_fp;
lpgemv_a_pack_ft packa_fp;
// Workaround to select right kernel and blocksizes based on arch
// since GEMV parameters are not available in lpgemm context.
@@ -232,12 +257,6 @@ LPGEMV(float, float, float, f32f32f32of32)
post_op_list,
&post_ops_attr
);
if ( mtag_a == PACK )
{
// Release pack buffer for A.
bli_pba_release( rntm, &mem_a );
}
}
if ( ( mtag_a == PACK ) && ( bli_mem_is_alloc( &mem_a ) ) )
{
@@ -250,8 +269,24 @@ LPGEMV(float, float, float, f32f32f32of32)
}
else
{
// m = 1 case is not implemented yet for AVX2
#ifdef BLIS_KERNELS_ZEN4
lpgemv_m_one_ker_ft ker_fp;
lpgemv_a_pack_ft packa_fp;
#ifdef BLIS_KERNELS_ZEN4
if( lpgemm_get_enabled_arch() == BLIS_ARCH_ZEN3 )
{
ker_fp = lpgemv_m_one_f32f32f32of32_avx512_256;
packa_fp = packa_mr8_f32f32f32of32_col_major;
}
else
{
ker_fp = lpgemv_m_one_f32f32f32of32;
packa_fp = packa_mr16_f32f32f32of32_col_major;
}
#else
ker_fp = lpgemv_m_one_f32f32f32of32_avx2;
packa_fp = packa_mr8_f32f32f32of32_col_major;
#endif
// Compute the JC loop thread range for the current thread.
dim_t jc_start, jc_end;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
@@ -272,7 +307,7 @@ LPGEMV(float, float, float, f32f32f32of32)
pack_a_buffer_f32f32f32of32 =
( float* ) bli_mem_buffer( &mem_a );
packa_mr16_f32f32f32of32_col_major
packa_fp
(
pack_a_buffer_f32f32f32of32,
a_use, rs_a, cs_a,
@@ -349,9 +384,8 @@ LPGEMV(float, float, float, f32f32f32of32)
//update post-op pointer
post_ops_attr.post_op_c_j = jc;
// Call kernel
lpgemv_m_one_f32f32f32of32
ker_fp
(
nc0, k,
a_use, rs_a_use, cs_a_use, mtag_a,
@@ -376,7 +410,6 @@ LPGEMV(float, float, float, f32f32f32of32)
{
bli_pba_release( rntm, &mem_b );
}
#endif // m == 1 case is not implemented for AVX2 yet.
}
}
@@ -384,13 +417,8 @@ LPGEMV(float, float, float, f32f32f32of32)
LPGEMM_5LOOP(float, float, float, f32f32f32of32)
{
// Handle using LPGEMV when m or/and n equal to 1
#ifdef BLIS_KERNELS_ZEN4
if ( ( ( (m == 1) && (lpgemm_get_enabled_arch() != BLIS_ARCH_ZEN3) ) || ( n == 1 ) ) &&
( bli_cpuid_is_avx512_supported() == TRUE ) )
#else
// m=1 case is not implemented yet for AVX2
if ( ( ( n == 1 ) ) && ( bli_cpuid_is_avx2fma3_supported() == TRUE ) )
#endif
if ( ( (m == 1) || ( n == 1 ) ) &&
( ( bli_cpuid_is_avx512_supported() == TRUE ) || ( bli_cpuid_is_avx2fma3_supported() == TRUE ) ) )
{
lpgemv_rowvar_f32f32f32of32(m, n, k,
a, rs_a, cs_a, mtag_a,
@@ -405,7 +433,6 @@ LPGEMM_5LOOP(float, float, float, f32f32f32of32)
c_downscale);
return;
}
// Query the context for various blocksizes.
const dim_t NC = lcntx->blksz.NC;
const dim_t KC = lcntx->blksz.KC;

View File

@@ -753,6 +753,8 @@ void lpgemv_m_one_ ## LP_SFX \
) \
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32);
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32_avx2);
LPGEMV_M_EQ1_KERN(float, float, float,f32f32f32of32_avx512_256);
LPGEMV_M_EQ1_KERN(bfloat16,bfloat16,float,bf16bf16f32of32);
LPGEMV_M_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);

View File

@@ -3401,7 +3401,7 @@ LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x16)
&&POST_OPS_TANH_1x16F,
&&POST_OPS_SIGMOID_1x16F
};
fflush(stdout);
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k_iter = (uint64_t)k0;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff