mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
GEMV support for S8S8S32O32 Symmetric Quantization
Introduced support for GEMV operations with group-level symmetric quantization for the S8S8S32OS32 API. Framework Changes: - Added macro definitions and function prototypes for GEMV with symmetric quantization in lpgemm_5loop_interface_apis.h and lpgemm_kernels.h. - LPGEMV_M_EQ1_KERN2 for the lpgemv_m_one_s8s8s32os32_sym_quant kernel, and - LPGEMV_N_EQ1_KERN2 for the lpgemv_n_one_s8s8s32os32_sym_quant kernel. - Implemented the main GEMV framework for symmetric quantization in lpgemm_s8s8s32_sym_quant.c. Kernel Changes: - lpgemv_m_one_s8s8s32os32_sym_quant for handling the case where M = 1, implemented in lpgemv_m_kernel_s8_grp_amd512vnni.c. - lpgemv_n_one_s8s8s32os32_sym_quant for handling the case where N = 1, implemented in lpgemv_n_kernel_s8_grp_amd512vnni.c. - Updated the buffer reordering logic for group quantization for N=1 cases in aocl_gemm_s8s8s32os32_utils.c. Notes - Ensure that group_size is a factor of both K (and KC when K > KC). - The B matrix must be provided in reordered format (mtag_b == REORDERED). AMD-Internal: [SWLCSG-3604]
This commit is contained in:
@@ -140,8 +140,7 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE_SYM_QUANT(s8s8s32os32_sym_quant)
|
||||
// loaded; and since k_dim needs to be at least 4, having n_dim at least 16
|
||||
// should give 4x16=64 elements, enough for 1 zmm register. The padding is
|
||||
// not rounded to NR (=64), since that would result in memory wastage.
|
||||
// Not supported yet
|
||||
#if 0 //def BLIS_KERNELS_ZEN4
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
dim_t n_reorder;
|
||||
if( n == 1 )
|
||||
{
|
||||
@@ -150,7 +149,6 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE_SYM_QUANT(s8s8s32os32_sym_quant)
|
||||
else
|
||||
{
|
||||
n_reorder = make_multiple_of_n( n, 16 );
|
||||
|
||||
}
|
||||
|
||||
// Extra space since packing does length in multiples of 4.
|
||||
@@ -364,24 +362,36 @@ AOCL_GEMM_REORDER_SYM_QUANT(int8_t,s8s8s32os32_sym_quant)
|
||||
{
|
||||
return; // A reorder not supported.
|
||||
}
|
||||
// Not supported yet
|
||||
#if 0 //def BLIS_KERNELS_ZEN4
|
||||
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
if( n == 1 )
|
||||
{
|
||||
int32_t* pack_b_column_sum = ( int32_t* ) ( reorder_buf_addr +
|
||||
( sizeof( int8_t ) * n * k ));
|
||||
// Calculate the address of the beginning of the column sum buffer that
|
||||
// is allocated after the reorder buffer.
|
||||
int32_t* pack_b_column_sum = ( int32_t* ) ( reorder_buf_addr +
|
||||
( k * sizeof( int8_t ) ) );
|
||||
|
||||
*pack_b_column_sum = 0;
|
||||
// NOTE We're working under the assumption that group_size is a factor
|
||||
// of k.
|
||||
for ( dim_t k0 = 0; k0 < k; k0 += group_size )
|
||||
{
|
||||
// Initialize the current column sum to 0.
|
||||
*pack_b_column_sum = 0;
|
||||
for ( dim_t group = 0; group < group_size; group++ )
|
||||
{
|
||||
reorder_buf_addr[ k0 + group ] = input_buf_addr[ ( k0 + group ) * rs_b ];
|
||||
*pack_b_column_sum += reorder_buf_addr[ k0 + group ];
|
||||
}
|
||||
|
||||
for( dim_t k0 = 0; k0 < k; k0++ )
|
||||
{
|
||||
reorder_buf_addr[k0] = input_buf_addr[ k0 * rs_b ];
|
||||
*pack_b_column_sum += reorder_buf_addr[k0];
|
||||
}
|
||||
*pack_b_column_sum *= 128;
|
||||
return;
|
||||
*pack_b_column_sum *= 128;
|
||||
// Move the pack_b_column_sum pointer one step to the next group.
|
||||
pack_b_column_sum += 1;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
|
||||
@@ -297,4 +297,33 @@ void lpgemv_rowvar_avx2_ ## LP_SFX \
|
||||
|
||||
LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32);
|
||||
|
||||
#define LPGEMV2(A_type, B_type, C_type, LP_SFX) \
|
||||
void lpgemv_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t m, \
|
||||
const dim_t n, \
|
||||
const dim_t k, \
|
||||
const A_type *a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const AOCL_MEMORY_TAG mtag_a, \
|
||||
const B_type *b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
const AOCL_MEMORY_TAG mtag_b, \
|
||||
float *c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t *rntm, \
|
||||
lpgemm_thrinfo_t *thread, \
|
||||
lpgemm_cntx_t *lcntx, \
|
||||
lpgemm_group_post_op *grp_post_op_list, \
|
||||
lpgemm_post_op *post_op_list, \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
|
||||
LPGEMV2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
|
||||
|
||||
#endif // LPGEMM_5LOOP_INTF_H
|
||||
|
||||
@@ -42,17 +42,27 @@
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_packa.h"
|
||||
|
||||
|
||||
// Not supported yet
|
||||
#if 0//def BLIS_KERNELS_ZEN4
|
||||
|
||||
LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
// NOTE
|
||||
// 1. Mandatory for matrix B to be reordered, i.e., mtag_b == REORDERED.
|
||||
// 2. K should be divisible by group_size.
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
LPGEMV2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
|
||||
{
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t MC = lcntx->blksz.MC;
|
||||
dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
// Group size should always be <= KC to make sure that entire group is processed
|
||||
// within one micro-kernel call.
|
||||
// If group size is greater than KC, then KC will be updated to group size.
|
||||
// This same change is done in reorder function to maintain consistency between
|
||||
// reorder and GEMM execution.
|
||||
if( grp_post_op_list->group_size > KC )
|
||||
{
|
||||
KC = grp_post_op_list->group_size;
|
||||
}
|
||||
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
int8_t* a_use = ( int8_t* )a;
|
||||
inc_t rs_a_use = rs_a;
|
||||
@@ -62,25 +72,49 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
dim_t rs_b_use = rs_b;
|
||||
inc_t cs_b_use = cs_b;
|
||||
|
||||
int32_t *c_use = NULL;
|
||||
|
||||
int32_t* pack_b_column_sum = NULL;
|
||||
float *c_use = NULL;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if (c_downscale < S32 || c_downscale == F32) post_ops_attr.buf_downscale = c;
|
||||
else post_ops_attr.buf_downscale = NULL;
|
||||
if ( c_downscale < F32 )
|
||||
{
|
||||
post_ops_attr.buf_downscale = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
post_ops_attr.buf_downscale = NULL;
|
||||
}
|
||||
|
||||
siz_t mem_a_size_req = 0;
|
||||
siz_t mem_b_size_req = 0;
|
||||
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
|
||||
int8_t* pack_b_buffer_s8s8s32os32;
|
||||
int8_t* pack_a_buffer_s8s8s32os32;
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
lpgemm_grp_post_op_attr grp_post_ops_attr;
|
||||
|
||||
dim_t group_size = grp_post_op_list->group_size;
|
||||
|
||||
// Initialize group post ops attributes.
|
||||
grp_post_ops_attr.a_scale_factor = grp_post_op_list->a_scale_factor;
|
||||
grp_post_ops_attr.a_scale_factor_len = grp_post_op_list->a_scale_factor_len;
|
||||
grp_post_ops_attr.b_scale_factor = grp_post_op_list->b_scale_factor;
|
||||
grp_post_ops_attr.b_scale_factor_len = grp_post_op_list->b_scale_factor_len;
|
||||
grp_post_ops_attr.a_zp = grp_post_op_list->a_zp;
|
||||
grp_post_ops_attr.b_zp = grp_post_op_list->b_zp;
|
||||
grp_post_ops_attr.a_zp_len = grp_post_op_list->a_zp_len;
|
||||
grp_post_ops_attr.b_zp_len = grp_post_op_list->b_zp_len;
|
||||
grp_post_ops_attr.group_size = group_size;
|
||||
grp_post_ops_attr.sf_stor_type = grp_post_op_list->sf_stor_type;
|
||||
grp_post_ops_attr.zp_stor_type = grp_post_op_list->zp_stor_type;
|
||||
|
||||
dim_t num_groups = ( k + group_size - 1 ) / group_size;
|
||||
grp_post_ops_attr.grp_post_op_lda = num_groups;
|
||||
grp_post_ops_attr.grp_post_op_ldb = n;
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
thrinfo_t thread_ic;
|
||||
|
||||
@@ -91,51 +125,34 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
// Increased MR from 6 to 16 to make use of 32 ZMM registers
|
||||
dim_t MR = 16;
|
||||
|
||||
// pack B matrix if rs_b > 1
|
||||
if( ( mtag_b == PACK ) )
|
||||
{
|
||||
mem_b_size_req = sizeof( int8_t ) * k + sizeof( int32_t );
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_b, rntm
|
||||
);
|
||||
|
||||
pack_b_buffer_s8s8s32os32 = ( int8_t* ) bli_mem_buffer( &mem_b );
|
||||
|
||||
int32_t* pack_b_column_sum = ( int32_t* ) ( pack_b_buffer_s8s8s32os32 +
|
||||
( sizeof( int8_t ) * k ));
|
||||
|
||||
*pack_b_column_sum = 0;
|
||||
|
||||
for( dim_t k0 = 0; k0 < k; k0++ )
|
||||
{
|
||||
pack_b_buffer_s8s8s32os32[k0] = b[ k0*rs_b ];
|
||||
*pack_b_column_sum += pack_b_buffer_s8s8s32os32[k0];
|
||||
}
|
||||
*pack_b_column_sum *= 128;
|
||||
post_ops_attr.b_col_sum_vec = pack_b_column_sum;
|
||||
|
||||
b_use = pack_b_buffer_s8s8s32os32;
|
||||
rs_b_use = 1;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
else if( mtag_b == REORDERED )
|
||||
if( mtag_b == REORDERED )
|
||||
{
|
||||
post_ops_attr.b_col_sum_vec = ( int32_t* )( b + k );
|
||||
}
|
||||
|
||||
// Compute the IC loop thread range for the current thread.
|
||||
else if( mtag_b == PACK )
|
||||
{
|
||||
// Unreordered B not supported.
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Unpacked B not supported.
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute the IC loop thread range for the current thread.
|
||||
dim_t ic_start, ic_end;
|
||||
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_ic.n_way );
|
||||
( thread->n_threads ) : ( thread_ic.n_way );
|
||||
thread_ic.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
||||
|
||||
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
|
||||
grp_post_ops_attr.grp_post_op_k = 0;
|
||||
for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
|
||||
{
|
||||
dim_t mc0 = bli_min((ic_end - ic), MC);
|
||||
grp_post_ops_attr.grp_post_op_i = ic;
|
||||
|
||||
dim_t mc0 = bli_min( ( ic_end - ic ), MC );
|
||||
|
||||
const int8_t *a_use = a + ic * rs_a;
|
||||
c_use = c + ic * rs_c;
|
||||
@@ -154,7 +171,7 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
&mem_a, rntm
|
||||
);
|
||||
|
||||
pack_a_buffer_s8s8s32os32 = (int8_t*)bli_mem_buffer( &mem_a );
|
||||
pack_a_buffer_s8s8s32os32 = ( int8_t* )bli_mem_buffer( &mem_a );
|
||||
|
||||
( ( packa_s32 ) lcntx->packa_fun_ptr )
|
||||
(
|
||||
@@ -165,8 +182,9 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
);
|
||||
a_use = pack_a_buffer_s8s8s32os32;
|
||||
}
|
||||
|
||||
// Call lpgemv_n_one kernel
|
||||
lpgemv_n_one_s8s8s32os32
|
||||
lpgemv_n_one_s8s8s32os32_sym_quant
|
||||
(
|
||||
mc0, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
@@ -174,6 +192,7 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
c_use, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
MR, KC,
|
||||
grp_post_ops_attr,
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
);
|
||||
@@ -182,11 +201,11 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
// Release pack buffers
|
||||
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
|
||||
{
|
||||
bli_pba_release(rntm, &mem_a);
|
||||
bli_pba_release( rntm, &mem_a );
|
||||
}
|
||||
if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
|
||||
{
|
||||
bli_pba_release(rntm, &mem_b);
|
||||
bli_pba_release( rntm, &mem_b );
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -195,19 +214,22 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
|
||||
dim_t jc_start, jc_end;
|
||||
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_jc.n_way );
|
||||
( thread->n_threads ) : ( thread_jc.n_way );
|
||||
thread_jc.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
|
||||
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
|
||||
|
||||
dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR();
|
||||
|
||||
// kc needs to be a multiple of 4 so that it can be used with vpdpbusd
|
||||
// instruction. Padding is added in cases this condition is not
|
||||
// satisfied, and therefore the k offset used for packed/reordered
|
||||
// buffer needs to be updated.
|
||||
dim_t k_updated = make_multiple_of_n( k, 4 );
|
||||
dim_t n_updated = make_multiple_of_n( n, 16 );
|
||||
|
||||
rs_a_use = rs_a;
|
||||
cs_a_use = 4;
|
||||
|
||||
|
||||
if ( mtag_a == PACK )
|
||||
{
|
||||
mem_a_size_req = sizeof( uint8_t ) * k;
|
||||
@@ -228,90 +250,66 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
1, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
|
||||
get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
&rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer_s8s8s32os32;
|
||||
}
|
||||
|
||||
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
|
||||
grp_post_ops_attr.grp_post_op_k = 0;
|
||||
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
|
||||
{
|
||||
dim_t nc0 = bli_min((jc_end - jc), NC);
|
||||
grp_post_ops_attr.grp_post_op_j = jc;
|
||||
|
||||
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
|
||||
c_use = c + jc;
|
||||
|
||||
dim_t jc_cur_loop = jc;
|
||||
dim_t jc_cur_loop_rem = 0;
|
||||
dim_t n_sub_updated = 0;
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
dim_t kc0_updated = make_multiple_of_n( k, 4 );
|
||||
|
||||
if ( mtag_b == REORDERED )
|
||||
{
|
||||
get_B_panel_reordered_start_offset_width(
|
||||
jc, n, NC, packb_min_NR,
|
||||
&jc_cur_loop, &jc_cur_loop_rem,
|
||||
&nc0, &n_sub_updated );
|
||||
|
||||
b_use = (int8_t*) ( b + (jc_cur_loop * k_updated ) );
|
||||
b_use = ( int8_t* ) ( b +
|
||||
( jc_cur_loop * k_updated ) +
|
||||
( jc_cur_loop_rem * kc0_updated )
|
||||
);
|
||||
|
||||
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
||||
|
||||
post_ops_attr.b_col_sum_vec = ( (int32_t*)( b +
|
||||
( k_updated * n_updated ) ) )
|
||||
+ jc;
|
||||
post_ops_attr.b_col_sum_vec = ( ( int32_t* )( b +
|
||||
( k_updated * n_updated ) ) ) +
|
||||
jc;
|
||||
|
||||
grp_post_ops_attr.grp_post_op_sum_ld = n_updated;
|
||||
}
|
||||
else if( mtag_b == PACK )
|
||||
{
|
||||
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
|
||||
|
||||
mem_b_size_req = sizeof( int8_t ) * nc0_updated * k_updated
|
||||
+ ( nc0_updated * sizeof( int32_t ) );
|
||||
|
||||
n_sub_updated = nc0_updated;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
|
||||
&mem_b, rntm
|
||||
);
|
||||
|
||||
pack_b_buffer_s8s8s32os32 =
|
||||
( int8_t* ) bli_mem_buffer( &mem_b );
|
||||
|
||||
|
||||
pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32os32
|
||||
+ ( sizeof( int8_t ) * nc0_updated
|
||||
* k_updated ) );
|
||||
|
||||
for (dim_t idx = 0; idx < nc0; idx++ )
|
||||
{
|
||||
*( pack_b_column_sum + idx ) = 0;
|
||||
}
|
||||
|
||||
for ( dim_t pc = 0; pc < k; pc += KC )
|
||||
{
|
||||
dim_t kc0 = bli_min( ( k - pc ), KC );
|
||||
|
||||
( ( packb_s32_s8 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
( pack_b_buffer_s8s8s32os32 ) +
|
||||
( n_sub_updated * pc ),
|
||||
pack_b_column_sum,
|
||||
( b + ( rs_b * pc ) + (jc * cs_b)),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use
|
||||
);
|
||||
}
|
||||
|
||||
b_use = pack_b_buffer_s8s8s32os32;
|
||||
post_ops_attr.b_col_sum_vec = pack_b_column_sum;
|
||||
// Unreordered B not supported.
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Unpacked B not supported.
|
||||
return;
|
||||
}
|
||||
|
||||
post_ops_attr.post_op_c_i = 0;
|
||||
post_ops_attr.post_op_c_j = jc;
|
||||
post_ops_attr.rs_c_downscale = rs_c;
|
||||
post_ops_attr.b_sum_offset = 0;
|
||||
|
||||
lpgemv_m_one_s8s8s32os32
|
||||
lpgemv_m_one_s8s8s32os32_sym_quant
|
||||
(
|
||||
nc0, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
@@ -321,13 +319,14 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
NR, KC,
|
||||
n_sub_updated,
|
||||
jc_cur_loop_rem,
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
grp_post_ops_attr,
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
);
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
if ( mtag_b == REORDERED )
|
||||
{
|
||||
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
|
||||
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
|
||||
}
|
||||
} // jc loop
|
||||
|
||||
@@ -342,8 +341,8 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// B should always be packed.
|
||||
LPGEMM_5LOOP2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
|
||||
{
|
||||
@@ -364,25 +363,38 @@ LPGEMM_5LOOP2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
|
||||
}
|
||||
if ( mtag_b == UNPACKED )
|
||||
{
|
||||
//Error: can only work with packed B now.
|
||||
// Error: can only work with packed B now.
|
||||
return;
|
||||
}
|
||||
// Not supported yet
|
||||
#if 0 //def BLIS_KERNELS_ZEN4
|
||||
|
||||
if( ( m == 1 ) || ( n == 1 ) )
|
||||
#ifdef BLIS_KERNELS_ZEN4
|
||||
// Invoke gemv kernels for m = 1 or n = 1.
|
||||
if ( ( ( m == 1 ) || ( n == 1 ) ) && ( mtag_b == REORDERED) )
|
||||
{
|
||||
lpgemv_rowvar_s8s8s32o32( m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha,
|
||||
beta,
|
||||
rntm,
|
||||
thread,
|
||||
lcntx,
|
||||
post_op_list,
|
||||
c_downscale );
|
||||
if ( ( k % grp_post_op_list->group_size != 0 ) ||
|
||||
( KC % grp_post_op_list->group_size != 0 ) )
|
||||
{
|
||||
bli_print_msg( "Quantized GEMV is only supported only when k and KC are "
|
||||
"divisible by group_size." , __FILE__, __LINE__ );
|
||||
return; // Error
|
||||
}
|
||||
|
||||
lpgemv_rowvar_s8s8s32o32_sym_quant
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha,
|
||||
beta,
|
||||
rntm,
|
||||
thread,
|
||||
lcntx,
|
||||
grp_post_op_list,
|
||||
post_op_list,
|
||||
c_downscale
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -889,6 +889,36 @@ LPGEMV_M_EQ1_KERN(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
LPGEMV_M_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
|
||||
LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
|
||||
|
||||
|
||||
#define LPGEMV_M_EQ1_KERN2(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemv_m_one_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t n0, \
|
||||
const dim_t k, \
|
||||
const A_type *a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const AOCL_MEMORY_TAG mtag_a, \
|
||||
const B_type *b, \
|
||||
dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
const AOCL_MEMORY_TAG mtag_b, \
|
||||
float *c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
dim_t NR, \
|
||||
const dim_t KC, \
|
||||
const dim_t n_sub_updated, \
|
||||
const dim_t jc_cur_loop_rem, \
|
||||
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
||||
lpgemm_post_op *post_op, \
|
||||
lpgemm_post_op_attr *post_op_attr \
|
||||
) \
|
||||
|
||||
LPGEMV_M_EQ1_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
|
||||
|
||||
#define LPGEMV_N_EQ1_KERN(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemv_n_one_ ## LP_SFX \
|
||||
( \
|
||||
@@ -920,4 +950,32 @@ LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
|
||||
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
|
||||
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
|
||||
|
||||
|
||||
#define LPGEMV_N_EQ1_KERN2(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemv_n_one_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t m0, \
|
||||
const dim_t k, \
|
||||
const A_type *a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const AOCL_MEMORY_TAG mtag_a, \
|
||||
const B_type *b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
const AOCL_MEMORY_TAG mtag_b, \
|
||||
float *c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
const dim_t MR, \
|
||||
const dim_t KC, \
|
||||
lpgemm_grp_post_op_attr grp_post_ops_attr, \
|
||||
lpgemm_post_op *post_op, \
|
||||
lpgemm_post_op_attr *post_op_attr \
|
||||
) \
|
||||
|
||||
LPGEMV_N_EQ1_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
|
||||
|
||||
#endif //BLIS_LPGEMM_KERN_H
|
||||
|
||||
Reference in New Issue
Block a user