GEMV support for S8S8S32O32 Symmetric Quantization

Introduced support for GEMV operations with group-level symmetric quantization for the S8S8S32OS32 API.

Framework Changes:
- Added macro definitions and function prototypes for GEMV with symmetric quantization in lpgemm_5loop_interface_apis.h and lpgemm_kernels.h.
  - LPGEMV_M_EQ1_KERN2 for the lpgemv_m_one_s8s8s32os32_sym_quant kernel, and
  - LPGEMV_N_EQ1_KERN2 for the lpgemv_n_one_s8s8s32os32_sym_quant kernel.
- Implemented the main GEMV framework for symmetric quantization in lpgemm_s8s8s32_sym_quant.c.

Kernel Changes:
- lpgemv_m_one_s8s8s32os32_sym_quant handles the case where M = 1 and is implemented in lpgemv_m_kernel_s8_grp_amd512vnni.c.
- lpgemv_n_one_s8s8s32os32_sym_quant handles the case where N = 1 and is implemented in lpgemv_n_kernel_s8_grp_amd512vnni.c.
- Updated the buffer reordering logic for group quantization for N=1 cases in aocl_gemm_s8s8s32os32_utils.c.

Notes
- Ensure that group_size is a factor of K; when K > KC, it must also be a factor of KC.
- The B matrix must be provided in reordered format (mtag_b == REORDERED).

AMD-Internal: [SWLCSG-3604]
This commit is contained in:
Sharma, Arnav
2025-08-14 13:41:25 +05:30
committed by GitHub
parent 3a14417ce1
commit 76c4872718
6 changed files with 3280 additions and 142 deletions

View File

@@ -140,8 +140,7 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE_SYM_QUANT(s8s8s32os32_sym_quant)
// loaded; and since k_dim needs to be atleast 4, having n_dim atleast 16
// should give 4x16=64 elements, enough for 1 zmm register.The padding is
// not rounded to NR (=64), since that would result in memory wastage.
// Not supported yet
#if 0 //def BLIS_KERNELS_ZEN4
#ifdef BLIS_KERNELS_ZEN4
dim_t n_reorder;
if( n == 1 )
{
@@ -150,7 +149,6 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE_SYM_QUANT(s8s8s32os32_sym_quant)
else
{
n_reorder = make_multiple_of_n( n, 16 );
}
// Extra space since packing does length in multiples of 4.
@@ -364,24 +362,36 @@ AOCL_GEMM_REORDER_SYM_QUANT(int8_t,s8s8s32os32_sym_quant)
{
return; // A reorder not supported.
}
// Not supported yet
#if 0 //def BLIS_KERNELS_ZEN4
#ifdef BLIS_KERNELS_ZEN4
if( n == 1 )
{
int32_t* pack_b_column_sum = ( int32_t* ) ( reorder_buf_addr +
( sizeof( int8_t ) * n * k ));
// Calculate the address of the beginning of the column sum buffer that
// is allocated after the reorder buffer.
int32_t* pack_b_column_sum = ( int32_t* ) ( reorder_buf_addr +
( k * sizeof( int8_t ) ) );
*pack_b_column_sum = 0;
// NOTE We're working under the assumption that group_size is a factor
// of k.
for ( dim_t k0 = 0; k0 < k; k0 += group_size )
{
// Initialize the current column sum to 0.
*pack_b_column_sum = 0;
for ( dim_t group = 0; group < group_size; group++ )
{
reorder_buf_addr[ k0 + group ] = input_buf_addr[ ( k0 + group ) * rs_b ];
*pack_b_column_sum += reorder_buf_addr[ k0 + group ];
}
for( dim_t k0 = 0; k0 < k; k0++ )
{
reorder_buf_addr[k0] = input_buf_addr[ k0 * rs_b ];
*pack_b_column_sum += reorder_buf_addr[k0];
}
*pack_b_column_sum *= 128;
return;
*pack_b_column_sum *= 128;
// Move the pack_b_column_sum pointer one step to the next group.
pack_b_column_sum += 1;
}
return;
}
#endif
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;

View File

@@ -297,4 +297,33 @@ void lpgemv_rowvar_avx2_ ## LP_SFX \
LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32);
#define LPGEMV2(A_type, B_type, C_type, LP_SFX) \
void lpgemv_rowvar_ ## LP_SFX \
( \
const dim_t m, \
const dim_t n, \
const dim_t k, \
const A_type *a, \
const dim_t rs_a, \
const dim_t cs_a, \
const AOCL_MEMORY_TAG mtag_a, \
const B_type *b, \
const dim_t rs_b, \
const dim_t cs_b, \
const AOCL_MEMORY_TAG mtag_b, \
float *c, \
const dim_t rs_c, \
const dim_t cs_c, \
const C_type alpha, \
const C_type beta, \
rntm_t *rntm, \
lpgemm_thrinfo_t *thread, \
lpgemm_cntx_t *lcntx, \
lpgemm_group_post_op *grp_post_op_list, \
lpgemm_post_op *post_op_list, \
AOCL_STORAGE_TYPE c_downscale \
) \
LPGEMV2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
#endif // LPGEMM_5LOOP_INTF_H

View File

@@ -42,17 +42,27 @@
#include "lpgemm_config.h"
#include "lpgemm_packa.h"
// Not supported yet
#if 0//def BLIS_KERNELS_ZEN4
LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
// NOTE
// 1. Mandatory for matrix B to be reordered, i.e., mtag_b == REORDERED.
// 2. K should be divisible by group_size.
#ifdef BLIS_KERNELS_ZEN4
LPGEMV2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
{
dim_t NC = lcntx->blksz.NC;
dim_t KC = lcntx->blksz.KC;
dim_t MC = lcntx->blksz.MC;
dim_t NR = lcntx->blksz.NR;
// Group size should always be <= KC to make sure that entire group is processed
// within one micro-kernel call.
// If group size is greater than KC, then KC will be updated to group size.
// This same change is done in reorder function to maintain consistency between
// reorder and GEMM execution.
if( grp_post_op_list->group_size > KC )
{
KC = grp_post_op_list->group_size;
}
// Strides are updated based on matrix packing/reordering.
int8_t* a_use = ( int8_t* )a;
inc_t rs_a_use = rs_a;
@@ -62,25 +72,49 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
dim_t rs_b_use = rs_b;
inc_t cs_b_use = cs_b;
int32_t *c_use = NULL;
int32_t* pack_b_column_sum = NULL;
float *c_use = NULL;
lpgemm_post_op_attr post_ops_attr;
post_ops_attr.c_stor_type = c_downscale;
if (c_downscale < S32 || c_downscale == F32) post_ops_attr.buf_downscale = c;
else post_ops_attr.buf_downscale = NULL;
if ( c_downscale < F32 )
{
post_ops_attr.buf_downscale = c;
}
else
{
post_ops_attr.buf_downscale = NULL;
}
siz_t mem_a_size_req = 0;
siz_t mem_b_size_req = 0;
mem_t mem_a = BLIS_MEM_INITIALIZER;
mem_t mem_b = BLIS_MEM_INITIALIZER;
int8_t* pack_b_buffer_s8s8s32os32;
int8_t* pack_a_buffer_s8s8s32os32;
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
lpgemm_grp_post_op_attr grp_post_ops_attr;
dim_t group_size = grp_post_op_list->group_size;
// Initialize group post ops attributes.
grp_post_ops_attr.a_scale_factor = grp_post_op_list->a_scale_factor;
grp_post_ops_attr.a_scale_factor_len = grp_post_op_list->a_scale_factor_len;
grp_post_ops_attr.b_scale_factor = grp_post_op_list->b_scale_factor;
grp_post_ops_attr.b_scale_factor_len = grp_post_op_list->b_scale_factor_len;
grp_post_ops_attr.a_zp = grp_post_op_list->a_zp;
grp_post_ops_attr.b_zp = grp_post_op_list->b_zp;
grp_post_ops_attr.a_zp_len = grp_post_op_list->a_zp_len;
grp_post_ops_attr.b_zp_len = grp_post_op_list->b_zp_len;
grp_post_ops_attr.group_size = group_size;
grp_post_ops_attr.sf_stor_type = grp_post_op_list->sf_stor_type;
grp_post_ops_attr.zp_stor_type = grp_post_op_list->zp_stor_type;
dim_t num_groups = ( k + group_size - 1 ) / group_size;
grp_post_ops_attr.grp_post_op_lda = num_groups;
grp_post_ops_attr.grp_post_op_ldb = n;
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
thrinfo_t thread_jc;
thrinfo_t thread_ic;
@@ -91,51 +125,34 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
// Increased MR from 6 to 16 to make use of 32 ZMM registers
dim_t MR = 16;
// pack B matrix if rs_b > 1
if( ( mtag_b == PACK ) )
{
mem_b_size_req = sizeof( int8_t ) * k + sizeof( int32_t );
lpgemm_alloc_mem_panel
(
mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE,
&mem_b, rntm
);
pack_b_buffer_s8s8s32os32 = ( int8_t* ) bli_mem_buffer( &mem_b );
int32_t* pack_b_column_sum = ( int32_t* ) ( pack_b_buffer_s8s8s32os32 +
( sizeof( int8_t ) * k ));
*pack_b_column_sum = 0;
for( dim_t k0 = 0; k0 < k; k0++ )
{
pack_b_buffer_s8s8s32os32[k0] = b[ k0*rs_b ];
*pack_b_column_sum += pack_b_buffer_s8s8s32os32[k0];
}
*pack_b_column_sum *= 128;
post_ops_attr.b_col_sum_vec = pack_b_column_sum;
b_use = pack_b_buffer_s8s8s32os32;
rs_b_use = 1;
cs_b_use = 1;
}
else if( mtag_b == REORDERED )
if( mtag_b == REORDERED )
{
post_ops_attr.b_col_sum_vec = ( int32_t* )( b + k );
}
// Compute the IC loop thread range for the current thread.
else if( mtag_b == PACK )
{
// Unreordered B not supported.
return;
}
else
{
// Unpacked B not supported.
return;
}
// Compute the IC loop thread range for the current thread.
dim_t ic_start, ic_end;
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
( thread->n_threads ) : ( thread_ic.n_way );
( thread->n_threads ) : ( thread_ic.n_way );
thread_ic.work_id = thread->tid;
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
grp_post_ops_attr.grp_post_op_k = 0;
for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
{
dim_t mc0 = bli_min((ic_end - ic), MC);
grp_post_ops_attr.grp_post_op_i = ic;
dim_t mc0 = bli_min( ( ic_end - ic ), MC );
const int8_t *a_use = a + ic * rs_a;
c_use = c + ic * rs_c;
@@ -154,7 +171,7 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
&mem_a, rntm
);
pack_a_buffer_s8s8s32os32 = (int8_t*)bli_mem_buffer( &mem_a );
pack_a_buffer_s8s8s32os32 = ( int8_t* )bli_mem_buffer( &mem_a );
( ( packa_s32 ) lcntx->packa_fun_ptr )
(
@@ -165,8 +182,9 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
);
a_use = pack_a_buffer_s8s8s32os32;
}
// Call lpgemv_n_one kernel
lpgemv_n_one_s8s8s32os32
lpgemv_n_one_s8s8s32os32_sym_quant
(
mc0, k,
a_use, rs_a_use, cs_a_use, mtag_a,
@@ -174,6 +192,7 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
c_use, rs_c, cs_c,
alpha, beta,
MR, KC,
grp_post_ops_attr,
post_op_list,
&post_ops_attr
);
@@ -182,11 +201,11 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
// Release pack buffers
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
{
bli_pba_release(rntm, &mem_a);
bli_pba_release( rntm, &mem_a );
}
if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
{
bli_pba_release(rntm, &mem_b);
bli_pba_release( rntm, &mem_b );
}
}
else
@@ -195,19 +214,22 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
dim_t jc_start, jc_end;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
( thread->n_threads ) : ( thread_jc.n_way );
( thread->n_threads ) : ( thread_jc.n_way );
thread_jc.work_id = thread->tid;
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR();
// kc needs to be a multiple of 4 so that it can be used with vpdpbusd
// instruction. Padding is added in cases this condition is not
// satisfied, and therefore the k offset used for packed/reordered
// buffer needs to be updated.
dim_t k_updated = make_multiple_of_n( k, 4 );
dim_t n_updated = make_multiple_of_n( n, 16 );
rs_a_use = rs_a;
cs_a_use = 4;
if ( mtag_a == PACK )
{
mem_a_size_req = sizeof( uint8_t ) * k;
@@ -228,90 +250,66 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
1, k,
&rs_a_use, &cs_a_use
);
get_packa_strides_mfringe_u8s8s32os32
(
&rs_a_use, &cs_a_use, gemm_MR, 1
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
);
a_use = pack_a_buffer_s8s8s32os32;
}
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
grp_post_ops_attr.grp_post_op_k = 0;
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
{
dim_t nc0 = bli_min((jc_end - jc), NC);
grp_post_ops_attr.grp_post_op_j = jc;
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
c_use = c + jc;
dim_t jc_cur_loop = jc;
dim_t jc_cur_loop_rem = 0;
dim_t n_sub_updated = 0;
if (mtag_b == REORDERED)
dim_t kc0_updated = make_multiple_of_n( k, 4 );
if ( mtag_b == REORDERED )
{
get_B_panel_reordered_start_offset_width(
jc, n, NC, packb_min_NR,
&jc_cur_loop, &jc_cur_loop_rem,
&nc0, &n_sub_updated );
b_use = (int8_t*) ( b + (jc_cur_loop * k_updated ) );
b_use = ( int8_t* ) ( b +
( jc_cur_loop * k_updated ) +
( jc_cur_loop_rem * kc0_updated )
);
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
post_ops_attr.b_col_sum_vec = ( (int32_t*)( b +
( k_updated * n_updated ) ) )
+ jc;
post_ops_attr.b_col_sum_vec = ( ( int32_t* )( b +
( k_updated * n_updated ) ) ) +
jc;
grp_post_ops_attr.grp_post_op_sum_ld = n_updated;
}
else if( mtag_b == PACK )
{
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
mem_b_size_req = sizeof( int8_t ) * nc0_updated * k_updated
+ ( nc0_updated * sizeof( int32_t ) );
n_sub_updated = nc0_updated;
lpgemm_alloc_mem_panel
(
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
&mem_b, rntm
);
pack_b_buffer_s8s8s32os32 =
( int8_t* ) bli_mem_buffer( &mem_b );
pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32os32
+ ( sizeof( int8_t ) * nc0_updated
* k_updated ) );
for (dim_t idx = 0; idx < nc0; idx++ )
{
*( pack_b_column_sum + idx ) = 0;
}
for ( dim_t pc = 0; pc < k; pc += KC )
{
dim_t kc0 = bli_min( ( k - pc ), KC );
( ( packb_s32_s8 )lcntx->packb_fun_ptr )
(
( pack_b_buffer_s8s8s32os32 ) +
( n_sub_updated * pc ),
pack_b_column_sum,
( b + ( rs_b * pc ) + (jc * cs_b)),
rs_b, cs_b, nc0, kc0, &rs_b_use, &cs_b_use
);
}
b_use = pack_b_buffer_s8s8s32os32;
post_ops_attr.b_col_sum_vec = pack_b_column_sum;
// Unreordered B not supported.
return;
}
else
{
// Unpacked B not supported.
return;
}
post_ops_attr.post_op_c_i = 0;
post_ops_attr.post_op_c_j = jc;
post_ops_attr.rs_c_downscale = rs_c;
post_ops_attr.b_sum_offset = 0;
lpgemv_m_one_s8s8s32os32
lpgemv_m_one_s8s8s32os32_sym_quant
(
nc0, k,
a_use, rs_a_use, cs_a_use, mtag_a,
@@ -321,13 +319,14 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
NR, KC,
n_sub_updated,
jc_cur_loop_rem,
post_op_list,
&post_ops_attr
grp_post_ops_attr,
post_op_list,
&post_ops_attr
);
if (mtag_b == REORDERED)
if ( mtag_b == REORDERED )
{
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
}
} // jc loop
@@ -342,8 +341,8 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
}
}
}
#endif
// B should always be packed.
LPGEMM_5LOOP2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
{
@@ -364,25 +363,38 @@ LPGEMM_5LOOP2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
}
if ( mtag_b == UNPACKED )
{
//Error: can only work with packed B now.
// Error: can only work with packed B now.
return;
}
// Not supported yet
#if 0 //def BLIS_KERNELS_ZEN4
if( ( m == 1 ) || ( n == 1 ) )
#ifdef BLIS_KERNELS_ZEN4
// Invoke gemv kernels for m = 1 or n = 1.
if ( ( ( m == 1 ) || ( n == 1 ) ) && ( mtag_b == REORDERED) )
{
lpgemv_rowvar_s8s8s32o32( m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha,
beta,
rntm,
thread,
lcntx,
post_op_list,
c_downscale );
if ( ( k % grp_post_op_list->group_size != 0 ) ||
( KC % grp_post_op_list->group_size != 0 ) )
{
bli_print_msg( "Quantized GEMV is only supported only when k and KC are "
"divisible by group_size." , __FILE__, __LINE__ );
return; // Error
}
lpgemv_rowvar_s8s8s32o32_sym_quant
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha,
beta,
rntm,
thread,
lcntx,
grp_post_op_list,
post_op_list,
c_downscale
);
return;
}

View File

@@ -889,6 +889,36 @@ LPGEMV_M_EQ1_KERN(bfloat16,bfloat16,float,bf16bf16f32of32);
LPGEMV_M_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
LPGEMV_M_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
#define LPGEMV_M_EQ1_KERN2(A_type,B_type,C_type,LP_SFX) \
void lpgemv_m_one_ ## LP_SFX \
( \
const dim_t n0, \
const dim_t k, \
const A_type *a, \
const dim_t rs_a, \
const dim_t cs_a, \
const AOCL_MEMORY_TAG mtag_a, \
const B_type *b, \
dim_t rs_b, \
const dim_t cs_b, \
const AOCL_MEMORY_TAG mtag_b, \
float *c, \
const dim_t rs_c, \
const dim_t cs_c, \
const C_type alpha, \
const C_type beta, \
dim_t NR, \
const dim_t KC, \
const dim_t n_sub_updated, \
const dim_t jc_cur_loop_rem, \
lpgemm_grp_post_op_attr grp_post_ops_attr, \
lpgemm_post_op *post_op, \
lpgemm_post_op_attr *post_op_attr \
) \
LPGEMV_M_EQ1_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
#define LPGEMV_N_EQ1_KERN(A_type,B_type,C_type,LP_SFX) \
void lpgemv_n_one_ ## LP_SFX \
( \
@@ -920,4 +950,32 @@ LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
#define LPGEMV_N_EQ1_KERN2(A_type,B_type,C_type,LP_SFX) \
void lpgemv_n_one_ ## LP_SFX \
( \
const dim_t m0, \
const dim_t k, \
const A_type *a, \
const dim_t rs_a, \
const dim_t cs_a, \
const AOCL_MEMORY_TAG mtag_a, \
const B_type *b, \
const dim_t rs_b, \
const dim_t cs_b, \
const AOCL_MEMORY_TAG mtag_b, \
float *c, \
const dim_t rs_c, \
const dim_t cs_c, \
const C_type alpha, \
const C_type beta, \
const dim_t MR, \
const dim_t KC, \
lpgemm_grp_post_op_attr grp_post_ops_attr, \
lpgemm_post_op *post_op, \
lpgemm_post_op_attr *post_op_attr \
) \
LPGEMV_N_EQ1_KERN2(int8_t,int8_t,int32_t,s8s8s32os32_sym_quant);
#endif //BLIS_LPGEMM_KERN_H