Files
blis/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32_sym_quant.c
Sharma, Arnav 76c4872718 GEMV support for S8S8S32O32 Symmetric Quantization
Introduced support for GEMV operations with group-level symmetric quantization for the S8S8S32O32 API.

Framework Changes:
- Added macro definitions and function prototypes for GEMV with symmetric quantization in lpgemm_5loop_interface_apis.h and lpgemm_kernels.h.
  - LPGEMV_M_EQ1_KERN2 for the lpgemv_m_one_s8s8s32os32_sym_quant kernel, and
  - LPGEMV_N_EQ1_KERN2 for the lpgemv_n_one_s8s8s32os32_sym_quant kernel.
- Implemented the main GEMV framework for symmetric quantization in lpgemm_s8s8s32_sym_quant.c.

Kernel Changes:
- lpgemv_m_one_s8s8s32os32_sym_quant for handling the case where M = 1 and implemented in lpgemv_m_kernel_s8_grp_amd512vnni.c.
- lpgemv_n_one_s8s8s32os32_sym_quant for handling the case where N = 1 and implemented in lpgemv_n_kernel_s8_grp_amd512vnni.c.
- Updated the buffer reordering logic for group quantization for N=1 cases in aocl_gemm_s8s8s32os32_utils.c.

Notes
- Ensure that group_size is a factor of K (and also of KC when K > KC).
- The B matrix must be provided in reordered format (mtag_b == REORDERED).

AMD-Internal: [SWLCSG-3604]
2025-08-14 13:41:25 +05:30

905 lines
26 KiB
C

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_packa_s8.h"
#include "lpgemm_packb_s8.h"
#include "lpgemm_kernels.h"
#include "lpgemm_utils_s8.h"
#include "lpgemm_thrinfo_utils.h"
#include "lpgemm_config.h"
#include "lpgemm_packa.h"
// NOTE
// 1. Mandatory for matrix B to be reordered, i.e., mtag_b == REORDERED.
// 2. K should be divisible by group_size.
#ifdef BLIS_KERNELS_ZEN4
LPGEMV2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
{
    /*
     * GEMV driver for int8 x int8 with group-level symmetric quantization.
     *
     * Two shapes are handled:
     *   - n == 1 : the m dimension is split across threads in units of MR
     *              and lpgemv_n_one_s8s8s32os32_sym_quant is called per
     *              MC block.
     *   - else   : (the 5-loop driver in this file only dispatches here with
     *              m == 1 or n == 1) the n dimension is split across threads
     *              in units of NR and lpgemv_m_one_s8s8s32os32_sym_quant is
     *              called per NC block.
     *
     * Only a reordered B matrix is supported. Any other mtag_b makes the
     * function return without touching C.
     */

    // Reject unsupported B formats before any pack buffer is checked out.
    // Previously this check sat inside the jc loop of the m == 1 path, after
    // the A pack buffer had already been allocated, and the early return
    // skipped the release of that buffer.
    if ( mtag_b != REORDERED )
    {
        // Packed-on-the-fly and unpacked B are not supported.
        return;
    }

    dim_t NC = lcntx->blksz.NC;
    dim_t KC = lcntx->blksz.KC;
    dim_t MC = lcntx->blksz.MC;
    dim_t NR = lcntx->blksz.NR;

    // Group size should always be <= KC to make sure that an entire group
    // is processed within one micro-kernel call. If group size is greater
    // than KC, KC is raised to the group size. The same adjustment is done
    // in the reorder function to keep reorder and execution consistent.
    if ( grp_post_op_list->group_size > KC )
    {
        KC = grp_post_op_list->group_size;
    }

    // Strides are updated based on matrix packing/reordering.
    int8_t* a_use = ( int8_t* )a;
    inc_t rs_a_use = rs_a;
    inc_t cs_a_use = cs_a;

    int8_t* b_use = ( int8_t* )b;
    dim_t rs_b_use = rs_b;
    inc_t cs_b_use = cs_b;

    float *c_use = NULL;

    lpgemm_post_op_attr post_ops_attr;
    post_ops_attr.c_stor_type = c_downscale;
    if ( c_downscale < F32 )
    {
        post_ops_attr.buf_downscale = c;
    }
    else
    {
        post_ops_attr.buf_downscale = NULL;
    }
    // Each kernel call below is handed the complete k extent, so there is
    // a single k block: mark it as both first and last instead of leaving
    // the flags indeterminate.
    post_ops_attr.is_first_k = TRUE;
    post_ops_attr.is_last_k = TRUE;

    // Pack buffer for A (only used when mtag_a == PACK).
    siz_t mem_a_size_req = 0;
    mem_t mem_a = BLIS_MEM_INITIALIZER;
    int8_t* pack_a_buffer_s8s8s32os32;

    // Initialize group post ops attributes from the caller supplied list.
    lpgemm_grp_post_op_attr grp_post_ops_attr;
    dim_t group_size = grp_post_op_list->group_size;

    grp_post_ops_attr.a_scale_factor = grp_post_op_list->a_scale_factor;
    grp_post_ops_attr.a_scale_factor_len = grp_post_op_list->a_scale_factor_len;
    grp_post_ops_attr.b_scale_factor = grp_post_op_list->b_scale_factor;
    grp_post_ops_attr.b_scale_factor_len = grp_post_op_list->b_scale_factor_len;
    grp_post_ops_attr.a_zp = grp_post_op_list->a_zp;
    grp_post_ops_attr.b_zp = grp_post_op_list->b_zp;
    grp_post_ops_attr.a_zp_len = grp_post_op_list->a_zp_len;
    grp_post_ops_attr.b_zp_len = grp_post_op_list->b_zp_len;
    grp_post_ops_attr.group_size = group_size;
    grp_post_ops_attr.sf_stor_type = grp_post_op_list->sf_stor_type;
    grp_post_ops_attr.zp_stor_type = grp_post_op_list->zp_stor_type;

    // Number of quantization groups along k (rounded up).
    dim_t num_groups = ( k + group_size - 1 ) / group_size;
    grp_post_ops_attr.grp_post_op_lda = num_groups;
    grp_post_ops_attr.grp_post_op_ldb = n;

    // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
    thrinfo_t thread_jc;
    thrinfo_t thread_ic;
    lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );

    if ( n == 1 )
    {
        // Increased MR from 6 to 16 to make use of 32 ZMM registers.
        dim_t MR = 16;

        // For the n == 1 reordered layout the column sum is read from the
        // location k bytes past the start of B.
        post_ops_attr.b_col_sum_vec = ( int32_t* )( b + k );

        // Compute the IC loop thread range for the current thread. When the
        // ic way count is 1, all threads are used to split m.
        dim_t ic_start, ic_end;
        thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
                          ( thread->n_threads ) : ( thread_ic.n_way );
        thread_ic.work_id = thread->tid;
        bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );

        grp_post_ops_attr.grp_post_op_k = 0;

        for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
        {
            grp_post_ops_attr.grp_post_op_i = ic;

            dim_t mc0 = bli_min( ( ic_end - ic ), MC );

            // A block for this iteration; replaced by the pack buffer below
            // when packing is requested.
            const int8_t* a_blk = a + ic * rs_a;
            c_use = c + ic * rs_c;

            post_ops_attr.post_op_c_i = ic;
            post_ops_attr.post_op_c_j = 0;
            post_ops_attr.rs_c_downscale = rs_c;

            if ( mtag_a == PACK )
            {
                mem_a_size_req = sizeof( int8_t ) * mc0 * k;

                lpgemm_alloc_mem_panel
                (
                  mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
                  &mem_a, rntm
                );

                pack_a_buffer_s8s8s32os32 = ( int8_t* )bli_mem_buffer( &mem_a );

                // The pack routine reports the packed strides through
                // rs_a_use / cs_a_use.
                ( ( packa_s32 ) lcntx->packa_fun_ptr )
                (
                  ( uint8_t* ) pack_a_buffer_s8s8s32os32,
                  ( uint8_t* )( a + ( rs_a * ic ) ), rs_a, cs_a,
                  mc0, k,
                  &rs_a_use, &cs_a_use
                );
                a_blk = pack_a_buffer_s8s8s32os32;
            }

            // Call lpgemv_n_one kernel.
            lpgemv_n_one_s8s8s32os32_sym_quant
            (
              mc0, k,
              a_blk, rs_a_use, cs_a_use, mtag_a,
              b_use, rs_b_use, cs_b_use, mtag_b,
              c_use, rs_c, cs_c,
              alpha, beta,
              MR, KC,
              grp_post_ops_attr,
              post_op_list,
              &post_ops_attr
            );
        }
    }
    else
    {
        dim_t gemm_MR = lcntx->blksz.MR;

        // Compute the JC loop thread range for the current thread. When the
        // jc way count is 1, all threads are used to split n.
        dim_t jc_start, jc_end;
        thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
                          ( thread->n_threads ) : ( thread_jc.n_way );
        thread_jc.work_id = thread->tid;
        bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );

        dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR();

        // kc needs to be a multiple of 4 so that it can be used with the
        // vpdpbusd instruction. Padding is added in cases this condition is
        // not satisfied, and therefore the k offset used for the
        // packed/reordered buffer needs to be updated.
        dim_t k_updated = make_multiple_of_n( k, 4 );
        dim_t n_updated = make_multiple_of_n( n, 16 );

        rs_a_use = rs_a;
        cs_a_use = 4;

        if ( mtag_a == PACK )
        {
            mem_a_size_req = sizeof( uint8_t ) * k;

            lpgemm_alloc_mem_panel
            (
              mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
              &mem_a, rntm
            );

            pack_a_buffer_s8s8s32os32 =
                    ( int8_t* ) bli_mem_buffer( &mem_a );

            ( ( packa_s32 )lcntx->packa_fun_ptr )
            (
              ( uint8_t* )pack_a_buffer_s8s8s32os32,
              ( uint8_t* )a, rs_a, cs_a,
              1, k,
              &rs_a_use, &cs_a_use
            );

            get_packa_strides_mfringe_u8s8s32os32
            (
              rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
            );

            a_use = pack_a_buffer_s8s8s32os32;
        }

        grp_post_ops_attr.grp_post_op_k = 0;

        for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
        {
            grp_post_ops_attr.grp_post_op_j = jc;

            dim_t nc0 = bli_min( ( jc_end - jc ), NC );
            c_use = c + jc;

            dim_t jc_cur_loop = jc;
            dim_t jc_cur_loop_rem = 0;
            dim_t n_sub_updated = 0;
            dim_t kc0_updated = make_multiple_of_n( k, 4 );

            // Locate this thread's start position inside the reordered B
            // panel; the helper may also adjust nc0.
            get_B_panel_reordered_start_offset_width
            (
              jc, n, NC, packb_min_NR,
              &jc_cur_loop, &jc_cur_loop_rem,
              &nc0, &n_sub_updated
            );

            b_use = ( int8_t* ) ( b +
                                  ( jc_cur_loop * k_updated ) +
                                  ( jc_cur_loop_rem * kc0_updated )
                                );

            lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );

            // Column sums are read from past the k_updated x n_updated
            // int8 region of the reordered buffer.
            post_ops_attr.b_col_sum_vec = ( ( int32_t* )( b +
                                          ( k_updated * n_updated ) ) ) +
                                          jc;
            grp_post_ops_attr.grp_post_op_sum_ld = n_updated;

            post_ops_attr.post_op_c_i = 0;
            post_ops_attr.post_op_c_j = jc;
            post_ops_attr.rs_c_downscale = rs_c;
            post_ops_attr.b_sum_offset = 0;

            // Call lpgemv_m_one kernel.
            lpgemv_m_one_s8s8s32os32_sym_quant
            (
              nc0, k,
              a_use, rs_a_use, cs_a_use, mtag_a,
              b_use, rs_b_use, cs_b_use, mtag_b,
              c_use, rs_c, cs_c,
              alpha, beta,
              NR, KC,
              n_sub_updated,
              jc_cur_loop_rem,
              grp_post_ops_attr,
              post_op_list,
              &post_ops_attr
            );

            adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
        } // jc loop
    }

    // Release the A pack buffer (shared exit for both shapes).
    if ( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
    {
        bli_pba_release( rntm, &mem_a );
    }
}
#endif
// B should always be packed.
LPGEMM_5LOOP2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
{
    /*
     * 5-loop GEMM driver for int8 x int8 with group-level symmetric
     * quantization.
     *
     * Loop nest: jc (n, step NC) -> pc (k, step KC) -> ic (m, step MC)
     *            -> jr (n within panel, step NR) -> micro-kernel.
     *
     * B must be either packed on the fly (mtag_b == PACK) or pre-reordered
     * (mtag_b == REORDERED); unpacked B returns immediately. With a
     * reordered B and m == 1 or n == 1 the work is forwarded to the GEMV
     * driver (ZEN4 builds only).
     */
    dim_t NC = lcntx->blksz.NC;
    dim_t KC = lcntx->blksz.KC;
    dim_t MC = lcntx->blksz.MC;
    dim_t NR = lcntx->blksz.NR;
    dim_t MR = lcntx->blksz.MR;

    // Group size should always be <= KC to make sure that an entire group
    // is processed within one micro-kernel call. If group size is greater
    // than KC, KC is raised to the group size. The same adjustment is done
    // in the reorder function to keep reorder and GEMM execution consistent.
    if ( grp_post_op_list->group_size > KC )
    {
        KC = grp_post_op_list->group_size;
    }

    if ( mtag_b == UNPACKED )
    {
        // Error: can only work with packed B now.
        return;
    }

#ifdef BLIS_KERNELS_ZEN4
    // Invoke gemv kernels for m = 1 or n = 1.
    if ( ( ( m == 1 ) || ( n == 1 ) ) && ( mtag_b == REORDERED) )
    {
        if ( ( k % grp_post_op_list->group_size != 0 ) ||
             ( KC % grp_post_op_list->group_size != 0 ) )
        {
            bli_print_msg( "Quantized GEMV is only supported only when k and KC are "
                           "divisible by group_size." , __FILE__, __LINE__ );
            return; // Error
        }

        lpgemv_rowvar_s8s8s32o32_sym_quant
        (
          m, n, k,
          a, rs_a, cs_a, mtag_a,
          b, rs_b, cs_b, mtag_b,
          c, rs_c, cs_c,
          alpha,
          beta,
          rntm,
          thread,
          lcntx,
          grp_post_op_list,
          post_op_list,
          c_downscale
        );
        return;
    }
#endif

    // Strides are updated based on matrix packing/reordering.
    const int8_t* a_use = NULL;
    dim_t rs_a_use = rs_a;
    dim_t cs_a_use = cs_a;
    dim_t a_block_stride = 0;

    const int8_t* b_use = NULL;
    dim_t rs_b_use = rs_b;
    dim_t cs_b_use = cs_b;

    float* c_use_jc = NULL;
    float* c_use_ic = NULL;
    dim_t rs_c_use = rs_c;
    dim_t rs_c_downscale = rs_c;

    // Pack buffer for A.
    int8_t* pack_a_buffer_s8s8s32o32;
    mem_t mem_a = BLIS_MEM_INITIALIZER;
    siz_t mem_a_size_req = 0;

    // Pack buffer for B.
    int8_t* pack_b_buffer_s8s8s32o32;
    mem_t mem_b = BLIS_MEM_INITIALIZER;
    siz_t mem_b_size_req = 0;
    dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR();

    // Temporary buffer for C accumulation when downscaling is required.
    float* temp_scal_c_buffer_s8s8s32o32;
    mem_t mem_scale_c = BLIS_MEM_INITIALIZER;
    siz_t mem_scale_c_size_req = 0;

    // kc needs to be a multiple of 4 so that it can be used with the
    // vpdpbusd instruction. Padding is added in cases this condition is not
    // satisfied, and therefore the k offset used for the packed/reordered
    // buffer needs to be updated.
    dim_t k_updated = make_multiple_of_n( k, 4 );
    dim_t n_updated = make_multiple_of_n( n, 16 );

    // To decide whether to apply post ops or not.
    bool is_last_k = FALSE;

    // To decide whether to use original C or temp buffer for beta scale.
    bool is_first_k = FALSE;

    lpgemm_post_op_attr post_ops_attr;
    lpgemm_grp_post_op_attr grp_post_ops_attr;

    post_ops_attr.c_stor_type = c_downscale;
    if ( c_downscale < F32 )
    {
        post_ops_attr.buf_downscale = c;
    }
    else
    {
        post_ops_attr.buf_downscale = NULL;
    }

    dim_t group_size = grp_post_op_list->group_size;

    // Initialize group post ops attributes.
    grp_post_ops_attr.a_scale_factor = grp_post_op_list->a_scale_factor;
    grp_post_ops_attr.a_scale_factor_len = grp_post_op_list->a_scale_factor_len;
    grp_post_ops_attr.b_scale_factor = grp_post_op_list->b_scale_factor;
    grp_post_ops_attr.b_scale_factor_len = grp_post_op_list->b_scale_factor_len;
    grp_post_ops_attr.a_zp = grp_post_op_list->a_zp;
    grp_post_ops_attr.b_zp = grp_post_op_list->b_zp;
    grp_post_ops_attr.a_zp_len = grp_post_op_list->a_zp_len;
    grp_post_ops_attr.b_zp_len = grp_post_op_list->b_zp_len;
    grp_post_ops_attr.group_size = group_size;
    grp_post_ops_attr.sf_stor_type = grp_post_op_list->sf_stor_type;
    grp_post_ops_attr.zp_stor_type = grp_post_op_list->zp_stor_type;

    // Number of quantization groups along k (rounded up).
    dim_t num_groups = ( k + group_size - 1 ) / group_size;
    grp_post_ops_attr.grp_post_op_lda = num_groups;
    grp_post_ops_attr.grp_post_op_ldb = n;

    // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
    thrinfo_t thread_jc;
    thrinfo_t thread_ic;
    lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );

    // Compute the JC, IC loop thread range for the current thread.
    dim_t jc_start, jc_end;
    bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );

    dim_t ic_start, ic_end;
    bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );

    for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
    {
        dim_t nc0 = bli_min( ( jc_end - jc ), NC );

        dim_t jc_cur_loop = jc;
        dim_t jc_cur_loop_rem = 0;
        dim_t n_sub_updated = 0;

        if ( mtag_b == REORDERED )
        {
            get_B_panel_reordered_start_offset_width
            (
              jc, n, NC, packb_min_NR,
              &jc_cur_loop, &jc_cur_loop_rem,
              &nc0, &n_sub_updated
            );
        }

        if ( c_downscale == F32 )
        {
            c_use_jc = c + jc;
        }
        // Temp accumulation buffer for C allocation.
        else if ( c_downscale < F32 )
        {
            // Buffer memory is only required if output needs to be
            // persisted across iterations of the pc/KC loop.
            // It was observed that the locks used while checking out
            // a buffer from memory pool had an impact on performance
            // and is better to not checkout if k <= KC.
            if ( k > KC )
            {
                mem_scale_c_size_req = sizeof( float ) * nc0 * ( ic_end - ic_start );

                lpgemm_alloc_mem_panel
                (
                  mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE,
                  &mem_scale_c, rntm
                );

                temp_scal_c_buffer_s8s8s32o32 = bli_mem_buffer( &mem_scale_c );

                c_use_jc = ( float* )temp_scal_c_buffer_s8s8s32o32;
            }

            // The temp c buffer stride is modified as opposed to original C matrix.
            rs_c_use = nc0;
        }

        // Start of the per-group column-sum area inside the B pack buffer.
        // Set on the first pc iteration and reused by later ones.
        int32_t* pack_b_column_sum = NULL;

        for ( dim_t pc = 0; pc < k; pc += KC )
        {
            // beta is applied only on the first k block; later blocks
            // accumulate onto the partial result.
            int32_t beta0 = ( pc == 0 ) ? beta : 1;
            dim_t kc0 = bli_min( ( k - pc ), KC );

            grp_post_ops_attr.grp_post_op_k = pc;

            // kc0 needs to be a multiple of 4 so that it can be
            // used with the vpdpbusd instruction. Padding is added in
            // cases this condition is not satisfied, and therefore
            // the kc0 offsets used for packed/reordered buffers
            // need to be updated.
            dim_t kc0_updated = make_multiple_of_n( kc0, 4 );

            // No parallelization in k dim, k always starts at 0.
            is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE );
            post_ops_attr.is_first_k = is_first_k;

            is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
            post_ops_attr.is_last_k = is_last_k;

            if ( mtag_b == PACK )
            {
                // Pack B chunks are based on jc work id.
                dim_t jc_work_id = bli_thread_work_id( &thread_jc );

                // nc0 needs to be a multiple of 16 since this gives maximum
                // vectorization. Packing B always results in buffers with
                // width which is a multiple of 16. Subsequently the nc0
                // offsets used for packed/reordered buffers need to be
                // updated.
                dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );

                // Range of quantization groups touched by this KC block.
                dim_t group_start = pc / group_size;
                dim_t group_end = ( pc + kc0 - 1 ) / group_size;
                dim_t total_groups = ( k + group_size - 1 ) / group_size;

                // Using child thrinfo (thread_ic) tid to decide chief thread
                // per B matrix chunk (jc work id group).
                if ( bli_thread_am_ochief( &thread_ic ) )
                {
                    // Layout: [ packed int8 B : nc0_updated x kc0_updated ]
                    //         [ int32 column sums : one row of nc0_updated
                    //           per group ].
                    // The column-sum area is indexed by absolute group
                    // number (see the zeroing loop and the b_sum_ptr
                    // offsets below), so it needs room for every group
                    // along k, not only those inside the current KC block.
                    mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated
                                + ( total_groups * nc0_updated * sizeof( int32_t ) );

                    lpgemm_alloc_mem_panel
                    (
                      mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
                      &mem_b, rntm
                    );

                    thread->comm[jc_work_id].sent_object = bli_mem_buffer( &mem_b );
                }

                // All threads in work group should wait till chief thread has
                // finished allocating the packing buffers.
                bli_thrcomm_barrier
                (
                  bli_thread_ocomm_id( &thread_ic ),
                  &thread->comm[jc_work_id]
                );

                pack_b_buffer_s8s8s32o32 =
                        ( int8_t* ) thread->comm[jc_work_id].sent_object;

                // Compute the B panel per thread loop range for parallel
                // packing using ic_ways number of threads. Since at most only
                // ic_ways threads can be used, the thread_ic attributes are
                // used to split the loop range.
                dim_t jc_packb_start, jc_packb_end;
                bli_thread_range_sub
                (
                  &thread_ic, nc0, NR, FALSE,
                  &jc_packb_start, &jc_packb_end
                );

                if ( pc == 0 )
                {
                    pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32o32
                                        + ( sizeof( int8_t ) * nc0_updated
                                        * kc0_updated ) );
                }

                // Ensure thread ranges are valid, especially cases where no:
                // of threads available for parallelization are greater than
                // no: of B panel NR chunks.
                if ( ( jc_packb_end > jc_packb_start ) &&
                     ( jc_packb_start < ( jc + nc0 ) ) )
                {
                    dim_t nc0_pack = jc_packb_end - jc_packb_start;

                    // Clear this thread's slice of every group's column
                    // sums once, before the first KC block is packed.
                    if ( pc == 0 )
                    {
                        for ( dim_t group = 0; group < total_groups; group++ )
                        {
                            for ( dim_t idx = jc_packb_start; idx < jc_packb_end; idx++ )
                            {
                                *( pack_b_column_sum + ( group * nc0_updated ) + idx ) = 0;
                            }
                        }
                    }

                    // Packing kernels are designed assuming that the entire
                    // KCxNC block is packed at once and strides are set
                    // based on the KC value. Here the kernel is called with
                    // blocks of group_size x NC, so it would assume KC is
                    // group_size and set strides accordingly. To avoid this,
                    // the kernel is called with blocks of group_size x NR so
                    // that pointer movement across the pack buffer is
                    // handled in the framework itself.
                    for ( dim_t jr = 0; jr < nc0_pack; jr += NR )
                    {
                        dim_t nr0 = bli_min( ( nc0_pack - jr ), NR );

                        int8_t* b_dst_jr = pack_b_buffer_s8s8s32o32 +
                                ( ( jc_packb_start + jr ) * kc0_updated );
                        int32_t* b_sum_ptr = pack_b_column_sum + ( jc_packb_start + jr );
                        int8_t* b_src_jr = ( int8_t* )b +
                                ( cs_b * ( jc + jc_packb_start + jr ) );

                        if ( nr0 < NR )
                        {
                            // n fringe: split into a multiple-of-16 part and
                            // a remainder that is packed 16 wide.
                            dim_t nr_mult_16 = ( nr0 / 16 ) * 16;
                            dim_t nr0_rem = nr0 % 16;

                            if ( nr_mult_16 > 0 )
                            {
                                // group loop
                                for ( dim_t group = group_start; group <= group_end; group++ )
                                {
                                    dim_t k_start = bli_max( group * group_size, pc );
                                    dim_t k_end = bli_min( ( ( group + 1 ) * group_size - 1 ),
                                                           pc + kc0 - 1 );
                                    dim_t kg0 = k_end - k_start + 1;

                                    // Destination offset is relative to
                                    // the first k row actually packed, so
                                    // it stays non-negative even when a
                                    // group begins before pc.
                                    ( ( packb_s32_s8 )lcntx->packb_fun_ptr )
                                    (
                                      b_dst_jr + ( k_start - pc ) * nr_mult_16,
                                      b_sum_ptr + ( group * nc0_updated ),
                                      b_src_jr + ( rs_b * k_start ),
                                      rs_b, cs_b, nr_mult_16, kg0, &rs_b_use, &cs_b_use
                                    );
                                }

                                b_dst_jr += nr_mult_16 * kc0_updated;
                                b_sum_ptr += nr_mult_16;
                                b_src_jr += nr_mult_16 * cs_b;
                            }

                            if ( nr0_rem > 0 )
                            {
                                // Remainder columns occupy a 16 wide panel.
                                dim_t nr_rem_width = 16;

                                // group loop
                                for ( dim_t group = group_start; group <= group_end; group++ )
                                {
                                    dim_t k_start = bli_max( group * group_size, pc );
                                    dim_t k_end = bli_min( ( ( group + 1 ) * group_size - 1 ),
                                                           pc + kc0 - 1 );
                                    dim_t kg0 = k_end - k_start + 1;

                                    ( ( packb_s32_s8 )lcntx->packb_fun_ptr )
                                    (
                                      b_dst_jr + ( k_start - pc ) * nr_rem_width,
                                      b_sum_ptr + ( group * nc0_updated ),
                                      b_src_jr + ( rs_b * k_start ),
                                      rs_b, cs_b, nr0_rem, kg0, &rs_b_use, &cs_b_use
                                    );
                                }
                            }

                            // no fringe after this point
                            continue;
                        }

                        // nr0 == NR
                        for ( dim_t group = group_start; group <= group_end; group++ )
                        {
                            dim_t k_start = bli_max( group * group_size, pc );
                            dim_t k_end = bli_min( ( ( group + 1 ) * group_size - 1 ),
                                                   pc + kc0 - 1 );
                            dim_t kg0 = k_end - k_start + 1;

                            ( ( packb_s32_s8 )lcntx->packb_fun_ptr )
                            (
                              b_dst_jr + ( k_start - pc ) * NR,
                              b_sum_ptr + ( group * nc0_updated ),
                              b_src_jr + ( rs_b * k_start ),
                              rs_b, cs_b, NR, kg0, &rs_b_use, &cs_b_use
                            );
                        }
                    }

                    // The pack calls above were made per group, so the
                    // strides are set here for the full panel.
                    rs_b_use = NR * 4;
                    cs_b_use = NR;
                }
                else
                {
                    lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
                }

                // All threads in work group should wait till B matrix packing
                // is completed by the participating threads.
                bli_thrcomm_barrier
                (
                  bli_thread_ocomm_id( &thread_ic ),
                  &thread->comm[jc_work_id]
                );

                b_use = pack_b_buffer_s8s8s32o32;
                post_ops_attr.b_col_sum_vec = pack_b_column_sum;
                grp_post_ops_attr.grp_post_op_sum_ld = nc0_updated;
            }
            else if ( mtag_b == REORDERED )
            {
                // In multi-threaded scenarios, an extra offset into a given
                // packed B panel is required, since the jc loop split can
                // result in per thread start offset inside the panel, instead
                // of panel boundaries.
                b_use = b + ( jc_cur_loop * k_updated ) +
                        ( n_sub_updated * pc ) +
                        ( jc_cur_loop_rem * kc0_updated );

                lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );

                post_ops_attr.b_col_sum_vec =
                        ( ( int32_t* )( b + ( k_updated * n_updated ) ) ) + jc;
                grp_post_ops_attr.grp_post_op_sum_ld = n_updated;
            }
            else
            {
                // Unpacked B not supported.
                return;
            }

            for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
            {
                dim_t mc0 = bli_min( ( ic_end - ic ), MC );

                grp_post_ops_attr.grp_post_op_i = ic;

                // Only per thread C matrix is stored in temp buffer, so both
                // per thread jc and ic start should be normalized to zero.
                if ( c_downscale < F32 )
                {
                    c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) );
                }
                else
                {
                    c_use_ic = c_use_jc + ( rs_c_use * ic );
                }

                // Matrix A packed and reordered code path is not triggered
                // currently since we do not support it yet.
                if ( mtag_a == PACK )
                {
                    mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated;

                    lpgemm_alloc_mem_panel
                    (
                      mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK,
                      &mem_a, rntm
                    );

                    pack_a_buffer_s8s8s32o32 = ( int8_t* )bli_mem_buffer( &mem_a );

                    ( ( packa_s32 )lcntx->packa_fun_ptr )
                    (
                      ( uint8_t* )pack_a_buffer_s8s8s32o32,
                      ( uint8_t* )( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a,
                      mc0, kc0,
                      &rs_a_use, &cs_a_use
                    );
                    a_use = pack_a_buffer_s8s8s32o32;

                    if ( cs_a == 1 )
                    {
                        a_block_stride = kc0_updated;
                    }
                    else
                    {
                        a_block_stride = rs_a_use;
                    }
                }
                else
                {
                    a_use = a + ( rs_a * ic ) + ( cs_a * pc );

                    // Int8 kernel reads 4 elements, totalling 4 bytes in a
                    // single broadcast for use in vnni instruction.
                    // Non vnni based kernel requires update to this code.
                    cs_a_use = 4;
                    a_block_stride = rs_a;
                }

                post_ops_attr.b_sum_offset = 0;

                for ( dim_t jr = 0; jr < nc0; jr += NR )
                {
                    dim_t nr0 = bli_min( ( nc0 - jr ), NR );

                    // Post ops meta attributes.
                    post_ops_attr.post_op_c_i = ic;
                    post_ops_attr.post_op_c_j = ( jc + jr );
                    post_ops_attr.rs_c_downscale = rs_c_downscale;

                    grp_post_ops_attr.grp_post_op_j = jc + jr;

                    // The kernels are defined in zen4 folder.
#ifdef BLIS_KERNELS_ZEN4
                    // Reorder/Packed B, Reorder/Packed/Unpacked A call.
                    lpgemm_rowvar_s8s8s32os32_6x64m_sym_quant
                    (
                      mc0, nr0, kc0,
                      a_use, rs_a_use, cs_a_use, a_block_stride,
                      ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use,
                      ( c_use_ic + jr ), rs_c_use, 1,
                      alpha, beta0, grp_post_ops_attr,
                      post_op_list, post_ops_attr
                    );
#endif
                    post_ops_attr.b_sum_offset += NR;
                }
            }
        }

        if ( mtag_b == REORDERED )
        {
            adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
        }
    }

    // Release pack buffers.
    if ( mtag_b == PACK )
    {
        // All threads in work group should wait till B matrix usage is
        // completed by the participating threads.
        bli_thrcomm_barrier
        (
          bli_thread_ocomm_id( &thread_jc ),
          &thread->comm[bli_thread_work_id( &thread_jc)]
        );

        if ( bli_thread_am_ochief( &thread_ic ) )
        {
            if ( bli_mem_is_alloc( &mem_b ) )
            {
                bli_pba_release( rntm, &mem_b );
            }
        }
    }

    if ( mtag_a == PACK )
    {
        if ( bli_mem_is_alloc( &mem_a ) )
        {
            bli_pba_release( rntm, &mem_a );
        }
    }

    if ( c_downscale < F32 )
    {
        if ( bli_mem_is_alloc( &mem_scale_c ) )
        {
            bli_pba_release( rntm, &mem_scale_c );
        }
    }
}