mirror of
https://github.com/amd/blis.git
synced 2026-05-04 14:31:12 +00:00
Introduced support for GEMV operations with group-level symmetric quantization for the S8S8S32OS32 API. Framework Changes: - Added macro definitions and function prototypes for GEMV with symmetric quantization in lpgemm_5loop_interface_apis.h and lpgemm_kernels.h. - LPGEMV_M_EQ1_KERN2 for the lpgemv_m_one_s8s8s32os32_sym_quant kernel, and - LPGEMV_N_EQ1_KERN2 for the lpgemv_n_one_s8s8s32os32_sym_quant kernel. - Implemented the main GEMV framework for symmetric quantization in lpgemm_s8s8s32_sym_quant.c. Kernel Changes: - lpgemv_m_one_s8s8s32os32_sym_quant for handling the case where M = 1 and implemented in lpgemv_m_kernel_s8_grp_amd512vnni.c. - lpgemv_n_one_s8s8s32os32_sym_quant for handling the case where N = 1 and implemented in lpgemv_n_kernel_s8_grp_amd512vnni.c. - Updated the buffer reordering logic for group quantization for N=1 cases in aocl_gemm_s8s8s32os32_utils.c. Notes - Ensure that group_size is a factor of both K (and KC when K > KC). - The B matrix must be provided in reordered format (mtag_b == REORDERED). AMD-Internal: [SWLCSG-3604]
905 lines
26 KiB
C
905 lines
26 KiB
C
/*
|
|
|
|
BLIS
|
|
An object-based framework for developing high-performance BLAS-like
|
|
libraries.
|
|
|
|
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are
|
|
met:
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
- Neither the name(s) of the copyright holder(s) nor the names of its
|
|
contributors may be used to endorse or promote products derived
|
|
from this software without specific prior written permission.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*/
|
|
|
|
#include "blis.h"
|
|
#include "lpgemm_5loop_interface_apis.h"
|
|
#include "lpgemm_packa_s8.h"
|
|
#include "lpgemm_packb_s8.h"
|
|
#include "lpgemm_kernels.h"
|
|
#include "lpgemm_utils_s8.h"
|
|
#include "lpgemm_thrinfo_utils.h"
|
|
#include "lpgemm_config.h"
|
|
#include "lpgemm_packa.h"
|
|
|
|
// NOTE
|
|
// 1. Mandatory for matrix B to be reordered, i.e., mtag_b == REORDERED.
|
|
// 2. K should be divisible by group_size.
|
|
#ifdef BLIS_KERNELS_ZEN4
|
|
LPGEMV2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
|
|
{
|
|
dim_t NC = lcntx->blksz.NC;
|
|
dim_t KC = lcntx->blksz.KC;
|
|
dim_t MC = lcntx->blksz.MC;
|
|
dim_t NR = lcntx->blksz.NR;
|
|
|
|
// Group size should always be <= KC to make sure that entire group is processed
|
|
// within one micro-kernel call.
|
|
// If group size is greater than KC, then KC will be updated to group size.
|
|
// This same change is done in reorder function to maintain consistency between
|
|
// reorder and GEMM execution.
|
|
if( grp_post_op_list->group_size > KC )
|
|
{
|
|
KC = grp_post_op_list->group_size;
|
|
}
|
|
|
|
// Strides are updated based on matrix packing/reordering.
|
|
int8_t* a_use = ( int8_t* )a;
|
|
inc_t rs_a_use = rs_a;
|
|
inc_t cs_a_use = cs_a;
|
|
|
|
int8_t* b_use = ( int8_t* )b;
|
|
dim_t rs_b_use = rs_b;
|
|
inc_t cs_b_use = cs_b;
|
|
|
|
float *c_use = NULL;
|
|
|
|
lpgemm_post_op_attr post_ops_attr;
|
|
|
|
post_ops_attr.c_stor_type = c_downscale;
|
|
if ( c_downscale < F32 )
|
|
{
|
|
post_ops_attr.buf_downscale = c;
|
|
}
|
|
else
|
|
{
|
|
post_ops_attr.buf_downscale = NULL;
|
|
}
|
|
|
|
siz_t mem_a_size_req = 0;
|
|
|
|
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
|
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
|
|
|
int8_t* pack_a_buffer_s8s8s32os32;
|
|
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr;
|
|
|
|
dim_t group_size = grp_post_op_list->group_size;
|
|
|
|
// Initialize group post ops attributes.
|
|
grp_post_ops_attr.a_scale_factor = grp_post_op_list->a_scale_factor;
|
|
grp_post_ops_attr.a_scale_factor_len = grp_post_op_list->a_scale_factor_len;
|
|
grp_post_ops_attr.b_scale_factor = grp_post_op_list->b_scale_factor;
|
|
grp_post_ops_attr.b_scale_factor_len = grp_post_op_list->b_scale_factor_len;
|
|
grp_post_ops_attr.a_zp = grp_post_op_list->a_zp;
|
|
grp_post_ops_attr.b_zp = grp_post_op_list->b_zp;
|
|
grp_post_ops_attr.a_zp_len = grp_post_op_list->a_zp_len;
|
|
grp_post_ops_attr.b_zp_len = grp_post_op_list->b_zp_len;
|
|
grp_post_ops_attr.group_size = group_size;
|
|
grp_post_ops_attr.sf_stor_type = grp_post_op_list->sf_stor_type;
|
|
grp_post_ops_attr.zp_stor_type = grp_post_op_list->zp_stor_type;
|
|
|
|
dim_t num_groups = ( k + group_size - 1 ) / group_size;
|
|
grp_post_ops_attr.grp_post_op_lda = num_groups;
|
|
grp_post_ops_attr.grp_post_op_ldb = n;
|
|
|
|
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
|
thrinfo_t thread_jc;
|
|
thrinfo_t thread_ic;
|
|
|
|
lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );
|
|
|
|
if( n == 1 )
|
|
{
|
|
// Increased MR from 6 to 16 to make use of 32 ZMM registers
|
|
dim_t MR = 16;
|
|
|
|
if( mtag_b == REORDERED )
|
|
{
|
|
post_ops_attr.b_col_sum_vec = ( int32_t* )( b + k );
|
|
}
|
|
else if( mtag_b == PACK )
|
|
{
|
|
// Unreordered B not supported.
|
|
return;
|
|
}
|
|
else
|
|
{
|
|
// Unpacked B not supported.
|
|
return;
|
|
}
|
|
|
|
// Compute the IC loop thread range for the current thread.
|
|
dim_t ic_start, ic_end;
|
|
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
|
|
( thread->n_threads ) : ( thread_ic.n_way );
|
|
thread_ic.work_id = thread->tid;
|
|
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
|
|
|
grp_post_ops_attr.grp_post_op_k = 0;
|
|
for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
|
|
{
|
|
grp_post_ops_attr.grp_post_op_i = ic;
|
|
|
|
dim_t mc0 = bli_min( ( ic_end - ic ), MC );
|
|
|
|
const int8_t *a_use = a + ic * rs_a;
|
|
c_use = c + ic * rs_c;
|
|
|
|
post_ops_attr.post_op_c_i = ic;
|
|
post_ops_attr.post_op_c_j = 0;
|
|
post_ops_attr.rs_c_downscale = rs_c;
|
|
|
|
if( mtag_a == PACK )
|
|
{
|
|
mem_a_size_req = sizeof( int8_t ) * mc0 * k;
|
|
|
|
lpgemm_alloc_mem_panel
|
|
(
|
|
mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
|
&mem_a, rntm
|
|
);
|
|
|
|
pack_a_buffer_s8s8s32os32 = ( int8_t* )bli_mem_buffer( &mem_a );
|
|
|
|
( ( packa_s32 ) lcntx->packa_fun_ptr )
|
|
(
|
|
( uint8_t* ) pack_a_buffer_s8s8s32os32,
|
|
( uint8_t* )( a + ( rs_a * ic )), rs_a, cs_a,
|
|
mc0, k,
|
|
&rs_a_use, &cs_a_use
|
|
);
|
|
a_use = pack_a_buffer_s8s8s32os32;
|
|
}
|
|
|
|
// Call lpgemv_n_one kernel
|
|
lpgemv_n_one_s8s8s32os32_sym_quant
|
|
(
|
|
mc0, k,
|
|
a_use, rs_a_use, cs_a_use, mtag_a,
|
|
b_use, rs_b_use, cs_b_use, mtag_b,
|
|
c_use, rs_c, cs_c,
|
|
alpha, beta,
|
|
MR, KC,
|
|
grp_post_ops_attr,
|
|
post_op_list,
|
|
&post_ops_attr
|
|
);
|
|
}
|
|
|
|
// Release pack buffers
|
|
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
|
|
{
|
|
bli_pba_release( rntm, &mem_a );
|
|
}
|
|
if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
|
|
{
|
|
bli_pba_release( rntm, &mem_b );
|
|
}
|
|
}
|
|
else
|
|
{
|
|
dim_t gemm_MR = lcntx->blksz.MR;
|
|
|
|
dim_t jc_start, jc_end;
|
|
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
|
( thread->n_threads ) : ( thread_jc.n_way );
|
|
thread_jc.work_id = thread->tid;
|
|
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
|
|
|
|
dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR();
|
|
|
|
// kc needs to be a multiple of 4 so that it can be used with vpdpbusd
|
|
// instruction. Padding is added in cases this condition is not
|
|
// satisfied, and therefore the k offset used for packed/reordered
|
|
// buffer needs to be updated.
|
|
dim_t k_updated = make_multiple_of_n( k, 4 );
|
|
dim_t n_updated = make_multiple_of_n( n, 16 );
|
|
|
|
rs_a_use = rs_a;
|
|
cs_a_use = 4;
|
|
|
|
if ( mtag_a == PACK )
|
|
{
|
|
mem_a_size_req = sizeof( uint8_t ) * k;
|
|
|
|
lpgemm_alloc_mem_panel
|
|
(
|
|
mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
|
&mem_a, rntm
|
|
);
|
|
|
|
pack_a_buffer_s8s8s32os32 =
|
|
( int8_t* ) bli_mem_buffer( &mem_a );
|
|
|
|
( ( packa_s32 )lcntx->packa_fun_ptr )
|
|
(
|
|
( uint8_t* )pack_a_buffer_s8s8s32os32,
|
|
( uint8_t* )a, rs_a, cs_a,
|
|
1, k,
|
|
&rs_a_use, &cs_a_use
|
|
);
|
|
|
|
get_packa_strides_mfringe_u8s8s32os32
|
|
(
|
|
rs_a, cs_a, &rs_a_use, &cs_a_use, gemm_MR, 1
|
|
);
|
|
|
|
a_use = pack_a_buffer_s8s8s32os32;
|
|
}
|
|
|
|
grp_post_ops_attr.grp_post_op_k = 0;
|
|
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
|
|
{
|
|
grp_post_ops_attr.grp_post_op_j = jc;
|
|
|
|
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
|
|
c_use = c + jc;
|
|
|
|
dim_t jc_cur_loop = jc;
|
|
dim_t jc_cur_loop_rem = 0;
|
|
dim_t n_sub_updated = 0;
|
|
|
|
dim_t kc0_updated = make_multiple_of_n( k, 4 );
|
|
|
|
if ( mtag_b == REORDERED )
|
|
{
|
|
get_B_panel_reordered_start_offset_width(
|
|
jc, n, NC, packb_min_NR,
|
|
&jc_cur_loop, &jc_cur_loop_rem,
|
|
&nc0, &n_sub_updated );
|
|
|
|
b_use = ( int8_t* ) ( b +
|
|
( jc_cur_loop * k_updated ) +
|
|
( jc_cur_loop_rem * kc0_updated )
|
|
);
|
|
|
|
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
|
|
|
post_ops_attr.b_col_sum_vec = ( ( int32_t* )( b +
|
|
( k_updated * n_updated ) ) ) +
|
|
jc;
|
|
|
|
grp_post_ops_attr.grp_post_op_sum_ld = n_updated;
|
|
}
|
|
else if( mtag_b == PACK )
|
|
{
|
|
// Unreordered B not supported.
|
|
return;
|
|
}
|
|
else
|
|
{
|
|
// Unpacked B not supported.
|
|
return;
|
|
}
|
|
|
|
post_ops_attr.post_op_c_i = 0;
|
|
post_ops_attr.post_op_c_j = jc;
|
|
post_ops_attr.rs_c_downscale = rs_c;
|
|
post_ops_attr.b_sum_offset = 0;
|
|
|
|
lpgemv_m_one_s8s8s32os32_sym_quant
|
|
(
|
|
nc0, k,
|
|
a_use, rs_a_use, cs_a_use, mtag_a,
|
|
b_use, rs_b_use, cs_b_use, mtag_b,
|
|
c_use, rs_c, cs_c,
|
|
alpha, beta,
|
|
NR, KC,
|
|
n_sub_updated,
|
|
jc_cur_loop_rem,
|
|
grp_post_ops_attr,
|
|
post_op_list,
|
|
&post_ops_attr
|
|
);
|
|
|
|
if ( mtag_b == REORDERED )
|
|
{
|
|
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
|
|
}
|
|
} // jc loop
|
|
|
|
// Release pack buffers.
|
|
if ( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
|
|
{
|
|
bli_pba_release( rntm, &mem_b );
|
|
}
|
|
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
|
|
{
|
|
bli_pba_release(rntm, &mem_a);
|
|
}
|
|
}
|
|
}
|
|
#endif
|
|
|
|
// B should always be packed.
|
|
LPGEMM_5LOOP2(int8_t,int8_t,int32_t,s8s8s32o32_sym_quant)
|
|
{
|
|
dim_t NC = lcntx->blksz.NC;
|
|
dim_t KC = lcntx->blksz.KC;
|
|
dim_t MC = lcntx->blksz.MC;
|
|
dim_t NR = lcntx->blksz.NR;
|
|
dim_t MR = lcntx->blksz.MR;
|
|
|
|
// Group size should always be <= KC to make sure that entire group is processed
|
|
// within one micro-kernel call.
|
|
// If group size is greater than KC, then KC will be updated to group size.
|
|
// This same change is done in reorder function to maintain consistency between
|
|
// reorder and GEMM execution.
|
|
if( grp_post_op_list->group_size > KC )
|
|
{
|
|
KC = grp_post_op_list->group_size;
|
|
}
|
|
if ( mtag_b == UNPACKED )
|
|
{
|
|
// Error: can only work with packed B now.
|
|
return;
|
|
}
|
|
|
|
#ifdef BLIS_KERNELS_ZEN4
|
|
// Invoke gemv kernels for m = 1 or n = 1.
|
|
if ( ( ( m == 1 ) || ( n == 1 ) ) && ( mtag_b == REORDERED) )
|
|
{
|
|
if ( ( k % grp_post_op_list->group_size != 0 ) ||
|
|
( KC % grp_post_op_list->group_size != 0 ) )
|
|
{
|
|
bli_print_msg( "Quantized GEMV is only supported only when k and KC are "
|
|
"divisible by group_size." , __FILE__, __LINE__ );
|
|
return; // Error
|
|
}
|
|
|
|
lpgemv_rowvar_s8s8s32o32_sym_quant
|
|
(
|
|
m, n, k,
|
|
a, rs_a, cs_a, mtag_a,
|
|
b, rs_b, cs_b, mtag_b,
|
|
c, rs_c, cs_c,
|
|
alpha,
|
|
beta,
|
|
rntm,
|
|
thread,
|
|
lcntx,
|
|
grp_post_op_list,
|
|
post_op_list,
|
|
c_downscale
|
|
);
|
|
|
|
return;
|
|
}
|
|
|
|
#endif
|
|
|
|
// Strides are updated based on matrix packing/reordering.
|
|
const int8_t* a_use = NULL;
|
|
dim_t rs_a_use = rs_a;
|
|
dim_t cs_a_use = cs_a;
|
|
dim_t a_block_stride = 0;
|
|
|
|
const int8_t* b_use = NULL;
|
|
dim_t rs_b_use = rs_b;
|
|
dim_t cs_b_use = cs_b;
|
|
|
|
float* c_use_jc = NULL;
|
|
float* c_use_ic = NULL;
|
|
dim_t rs_c_use = rs_c;
|
|
dim_t rs_c_downscale = rs_c;
|
|
|
|
// Pack buffer for A.
|
|
int8_t* pack_a_buffer_s8s8s32o32;
|
|
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
|
siz_t mem_a_size_req = 0;
|
|
|
|
// Pack buffer for B.
|
|
int8_t* pack_b_buffer_s8s8s32o32;
|
|
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
|
siz_t mem_b_size_req = 0;
|
|
dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR();
|
|
|
|
// Temporary buffer for C accumulation when downscaling is required.
|
|
float* temp_scal_c_buffer_s8s8s32o32;
|
|
mem_t mem_scale_c = BLIS_MEM_INITIALIZER;
|
|
siz_t mem_scale_c_size_req = 0;
|
|
|
|
// kc needs to be a multiple of 4 so that it can be used with vpdpbusd
|
|
// instruction. Padding is added in cases this condition is not
|
|
// satisfied, and therefore the k offset used for packed/reordered
|
|
// buffer needs to be updated.
|
|
dim_t k_updated = make_multiple_of_n( k, 4 );
|
|
dim_t n_updated = make_multiple_of_n( n, 16 );
|
|
|
|
// To decide whether to apply post ops or not.
|
|
bool is_last_k = FALSE;
|
|
|
|
// To decide whether to use original s8 C or temp buffer for beta scale.
|
|
bool is_first_k = FALSE;
|
|
|
|
lpgemm_post_op_attr post_ops_attr;
|
|
lpgemm_grp_post_op_attr grp_post_ops_attr;
|
|
|
|
post_ops_attr.c_stor_type = c_downscale;
|
|
if ( c_downscale < F32 )
|
|
{
|
|
post_ops_attr.buf_downscale = c;
|
|
}
|
|
else
|
|
{
|
|
post_ops_attr.buf_downscale = NULL;
|
|
}
|
|
|
|
dim_t group_size = grp_post_op_list->group_size;
|
|
|
|
// Initialize group post ops attributes.
|
|
grp_post_ops_attr.a_scale_factor = grp_post_op_list->a_scale_factor;
|
|
grp_post_ops_attr.a_scale_factor_len = grp_post_op_list->a_scale_factor_len;
|
|
grp_post_ops_attr.b_scale_factor = grp_post_op_list->b_scale_factor;
|
|
grp_post_ops_attr.b_scale_factor_len = grp_post_op_list->b_scale_factor_len;
|
|
grp_post_ops_attr.a_zp = grp_post_op_list->a_zp;
|
|
grp_post_ops_attr.b_zp = grp_post_op_list->b_zp;
|
|
grp_post_ops_attr.a_zp_len = grp_post_op_list->a_zp_len;
|
|
grp_post_ops_attr.b_zp_len = grp_post_op_list->b_zp_len;
|
|
grp_post_ops_attr.group_size = group_size;
|
|
grp_post_ops_attr.sf_stor_type = grp_post_op_list->sf_stor_type;
|
|
grp_post_ops_attr.zp_stor_type = grp_post_op_list->zp_stor_type;
|
|
|
|
dim_t num_groups = ( k + group_size - 1 ) / group_size;
|
|
grp_post_ops_attr.grp_post_op_lda = num_groups;
|
|
grp_post_ops_attr.grp_post_op_ldb = n;
|
|
|
|
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
|
thrinfo_t thread_jc;
|
|
thrinfo_t thread_ic;
|
|
|
|
lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );
|
|
|
|
// Compute the JC, IC loop thread range for the current thread.
|
|
dim_t jc_start, jc_end;
|
|
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
|
|
|
|
dim_t ic_start, ic_end;
|
|
bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end );
|
|
|
|
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
|
|
{
|
|
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
|
|
|
|
dim_t jc_cur_loop = jc;
|
|
dim_t jc_cur_loop_rem = 0;
|
|
dim_t n_sub_updated = 0;
|
|
|
|
if ( mtag_b == REORDERED )
|
|
{
|
|
get_B_panel_reordered_start_offset_width
|
|
(
|
|
jc, n, NC, packb_min_NR,
|
|
&jc_cur_loop, &jc_cur_loop_rem,
|
|
&nc0, &n_sub_updated
|
|
);
|
|
}
|
|
|
|
if ( c_downscale == F32 )
|
|
{
|
|
c_use_jc = c + jc;
|
|
}
|
|
// Temp accumulaton buffer for C allocation.
|
|
else if ( c_downscale < F32 )
|
|
{
|
|
// Buffer memory is only required if output needs to be
|
|
// persisted across iterations of the pc/KC loop.
|
|
// It was observed that the locks used while checking out
|
|
// a buffer from memory pool had an impact on performance
|
|
// and is better to not checkout if k <= KC.
|
|
if ( k > KC )
|
|
{
|
|
mem_scale_c_size_req = sizeof( float ) * nc0 * ( ic_end - ic_start );
|
|
|
|
lpgemm_alloc_mem_panel
|
|
(
|
|
mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
|
&mem_scale_c, rntm
|
|
);
|
|
|
|
temp_scal_c_buffer_s8s8s32o32 = bli_mem_buffer( &mem_scale_c );
|
|
|
|
c_use_jc = ( float* )temp_scal_c_buffer_s8s8s32o32;
|
|
|
|
}
|
|
|
|
// The temp c buffer stride is modified as opposed to original C matrix.
|
|
rs_c_use = nc0;
|
|
}
|
|
|
|
int32_t* pack_b_column_sum = NULL;
|
|
|
|
for ( dim_t pc = 0; pc < k; pc += KC )
|
|
{
|
|
int32_t beta0 = ( pc == 0 ) ? beta : 1;
|
|
dim_t kc0 = bli_min( ( k - pc ), KC );
|
|
|
|
grp_post_ops_attr.grp_post_op_k = pc;
|
|
|
|
// kc0 needs to be a multiple of 4 so that it can be
|
|
// used with vpdpbusd instruction. Padding is added in
|
|
// cases this condition is not satisfied, and therefore
|
|
// the kc0 offsets used for packed/reordered buffers
|
|
// needs to be updated.
|
|
dim_t kc0_updated = make_multiple_of_n( kc0, 4 );
|
|
|
|
// No parallelization in k dim, k always starts at 0.
|
|
is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE );
|
|
post_ops_attr.is_first_k = is_first_k;
|
|
|
|
is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
|
|
post_ops_attr.is_last_k = is_last_k;
|
|
|
|
if ( mtag_b == PACK )
|
|
{
|
|
// Pack B chunks are based on jc work id.
|
|
dim_t jc_work_id = bli_thread_work_id( &thread_jc );
|
|
|
|
// Using child thrinfo (thread_ic) tid to decide chief thread
|
|
// per B matrix chunk (jc work id group)
|
|
dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR );
|
|
|
|
dim_t group_start = pc / group_size;
|
|
dim_t group_end = ( pc + kc0 - 1 ) / group_size;
|
|
|
|
|
|
dim_t total_groups = ( k + group_size - 1 ) / group_size;
|
|
dim_t n_groups_per_kc = group_end - group_start + 1;
|
|
|
|
if ( bli_thread_am_ochief( &thread_ic ) )
|
|
{
|
|
// nc0 needs to be a multiple of 16 since this gives maximum
|
|
// vectorization. Packing B always results in buffers with width
|
|
// which is a multiple of 16. Subsequently the nc0 offsets used
|
|
// for packed/reordered buffers needs to be updated.pack
|
|
|
|
mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated
|
|
+ ( n_groups_per_kc * nc0_updated * sizeof( int32_t ) );
|
|
|
|
lpgemm_alloc_mem_panel
|
|
(
|
|
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
|
|
&mem_b, rntm
|
|
);
|
|
|
|
thread->comm[jc_work_id].sent_object = bli_mem_buffer( &mem_b );
|
|
}
|
|
|
|
// All threads in work group should wait till chief thread has
|
|
// finished allocating the packing buffers.
|
|
bli_thrcomm_barrier
|
|
(
|
|
bli_thread_ocomm_id( &thread_ic ),
|
|
&thread->comm[jc_work_id]
|
|
);
|
|
|
|
pack_b_buffer_s8s8s32o32 =
|
|
( int8_t* ) thread->comm[jc_work_id].sent_object;
|
|
|
|
// Compute the B panel per thread loop range for parallel
|
|
// packing using ic_ways number of threads. Since atmost only
|
|
// ic_ways threads can be used, the thread_ic attributes are
|
|
// used to split the loop range.
|
|
dim_t jc_packb_start, jc_packb_end;
|
|
bli_thread_range_sub
|
|
(
|
|
&thread_ic, nc0, NR, FALSE,
|
|
&jc_packb_start, &jc_packb_end
|
|
);
|
|
|
|
if ( pc == 0)
|
|
{
|
|
pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32o32
|
|
+ ( sizeof( int8_t ) * nc0_updated
|
|
* kc0_updated ) );
|
|
}
|
|
|
|
// Ensure thread ranges are valid, especially cases where no:
|
|
// of threads available for parallelization are greater than
|
|
// no: of B panel NR chunks.
|
|
if ( ( jc_packb_end > jc_packb_start ) &&
|
|
( jc_packb_start < ( jc + nc0 ) ) )
|
|
{
|
|
dim_t nc0_pack = jc_packb_end - jc_packb_start;
|
|
if ( pc == 0 )
|
|
{
|
|
for( dim_t group = 0; group < total_groups; group++ )
|
|
{
|
|
for (dim_t idx = jc_packb_start; idx < jc_packb_end; idx++ )
|
|
{
|
|
*( pack_b_column_sum + (group * nc0_updated ) + idx ) = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
// packing kernels are designed in such a way assuming that entire KCxNC
|
|
// block is packed at once and strides are set based on KC value.
|
|
// In current scenario, we call kernel with blocks of group_size x NC
|
|
// so kernel assumes that KC is group_size and strides are set based on group_size.
|
|
// To avoid this, we are calling kernel with blocks of group_size x NR, so that
|
|
// we can take care of the pointer movement across the reorder buffer in the framework
|
|
// itself.
|
|
for( dim_t jr = 0; jr < nc0_pack; jr += NR )
|
|
{
|
|
dim_t nr0 = bli_min( ( nc0_pack - jr ), NR );
|
|
|
|
int8_t* b_dst_jr = pack_b_buffer_s8s8s32o32 + ( ( jc_packb_start + jr ) * kc0_updated );
|
|
int32_t* b_sum_ptr = pack_b_column_sum + ( jc_packb_start + jr );
|
|
int8_t* b_src_jr = (int8_t*)b + ( cs_b * ( jc + jc_packb_start + jr ) );
|
|
|
|
if( nr0 < NR )
|
|
{
|
|
dim_t nr_mult_16 = (nr0 / 16) * 16;
|
|
dim_t nr0_rem = nr0 % 16;
|
|
dim_t nr0_updated = nr_mult_16;
|
|
|
|
if( nr_mult_16 > 0 )
|
|
{
|
|
// group loop
|
|
for( dim_t group = group_start; group <= group_end; group++ )
|
|
{
|
|
dim_t k_start = bli_max( group * group_size, pc );
|
|
dim_t k_end = bli_min( ( ( group + 1 ) * group_size - 1 ),
|
|
pc + kc0 - 1);
|
|
dim_t kg0 = k_end - k_start + 1;
|
|
|
|
( ( packb_s32_s8 )lcntx->packb_fun_ptr )
|
|
( b_dst_jr + ( (group * group_size) - pc) * nr0_updated,
|
|
b_sum_ptr + (group * nc0_updated),
|
|
b_src_jr + (rs_b * k_start),
|
|
rs_b, cs_b, nr_mult_16, kg0, &rs_b_use, &cs_b_use
|
|
);
|
|
}
|
|
b_dst_jr += nr_mult_16 * kc0_updated;
|
|
b_sum_ptr += nr_mult_16;
|
|
b_src_jr += nr_mult_16 * cs_b;
|
|
}
|
|
|
|
if( nr0_rem > 0 )
|
|
{
|
|
dim_t nr0_updated = 16;
|
|
// group loop
|
|
for( dim_t group = group_start; group <= group_end; group++ )
|
|
{
|
|
dim_t k_start = bli_max( group * group_size, pc );
|
|
dim_t k_end = bli_min( ( ( group + 1 ) * group_size - 1 ),
|
|
pc + kc0 - 1);
|
|
dim_t kg0 = k_end - k_start + 1;
|
|
|
|
( ( packb_s32_s8 )lcntx->packb_fun_ptr )
|
|
( b_dst_jr + ( (group * group_size) - pc) * nr0_updated,
|
|
b_sum_ptr + (group * nc0_updated),
|
|
b_src_jr + (rs_b * k_start),
|
|
rs_b, cs_b, nr0_rem, kg0, &rs_b_use, &cs_b_use
|
|
);
|
|
}
|
|
}
|
|
// no fringe after this point
|
|
continue;
|
|
}
|
|
|
|
dim_t nr0_updated = NR;
|
|
// nr0 == NR
|
|
for( dim_t group = group_start; group <= group_end; group++ )
|
|
{
|
|
dim_t k_start = bli_max( group * group_size, pc );
|
|
dim_t k_end = bli_min( ( ( group + 1 ) * group_size - 1 ),
|
|
pc + kc0 - 1);
|
|
dim_t kg0 = k_end - k_start + 1;
|
|
|
|
( ( packb_s32_s8 )lcntx->packb_fun_ptr )
|
|
( b_dst_jr + ( (group * group_size) - pc) * nr0_updated,
|
|
b_sum_ptr + (group * nc0_updated),
|
|
b_src_jr + (rs_b * k_start),
|
|
rs_b, cs_b, NR, kg0, &rs_b_use, &cs_b_use
|
|
);
|
|
|
|
}
|
|
}
|
|
|
|
rs_b_use = NR * 4;
|
|
cs_b_use = NR;
|
|
}
|
|
else
|
|
{
|
|
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
|
}
|
|
|
|
// All threads in work group should wait till B matrix packing
|
|
// is completed by the participating threads.
|
|
bli_thrcomm_barrier
|
|
(
|
|
bli_thread_ocomm_id( &thread_ic ),
|
|
&thread->comm[jc_work_id]
|
|
);
|
|
b_use = pack_b_buffer_s8s8s32o32;
|
|
|
|
post_ops_attr.b_col_sum_vec = pack_b_column_sum;
|
|
grp_post_ops_attr.grp_post_op_sum_ld = nc0_updated;
|
|
|
|
}
|
|
else if ( mtag_b == REORDERED )
|
|
{
|
|
// In multi-threaded scenarios, an extra offset into a given
|
|
// packed B panel is required, since the jc loop split can
|
|
// result in per thread start offset inside the panel, instead
|
|
// of panel boundaries.
|
|
b_use = b + ( jc_cur_loop * k_updated ) +
|
|
( n_sub_updated * pc ) +
|
|
( jc_cur_loop_rem * kc0_updated );
|
|
|
|
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
|
|
|
post_ops_attr.b_col_sum_vec = ( ( int32_t* )( b + ( k_updated * n_updated ) ) ) + jc;
|
|
|
|
grp_post_ops_attr.grp_post_op_sum_ld = n_updated;
|
|
}
|
|
else
|
|
{
|
|
//Unpacked B not supported.
|
|
return;
|
|
}
|
|
|
|
for ( dim_t ic = ic_start; ic < ic_end; ic += MC )
|
|
{
|
|
dim_t mc0 = bli_min( ( ic_end - ic ), MC );
|
|
|
|
grp_post_ops_attr.grp_post_op_i = ic;
|
|
|
|
// Only per thread C matrix is stored in temp buffer, so both
|
|
// per thread jc and ic start should be normalized to zero.
|
|
if ( c_downscale < F32 )
|
|
{
|
|
c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) );
|
|
}
|
|
else
|
|
{
|
|
c_use_ic = c_use_jc + ( rs_c_use * ic );
|
|
}
|
|
|
|
// Matrix A packed and reordered code path is not triggerred
|
|
// currently since we do not support it yet.
|
|
if ( mtag_a == PACK )
|
|
{
|
|
mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated;
|
|
|
|
lpgemm_alloc_mem_panel
|
|
(
|
|
mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK,
|
|
&mem_a, rntm
|
|
);
|
|
pack_a_buffer_s8s8s32o32 = ( int8_t* )bli_mem_buffer( &mem_a );
|
|
|
|
( ( packa_s32 )lcntx->packa_fun_ptr )
|
|
(
|
|
( uint8_t* )pack_a_buffer_s8s8s32o32,
|
|
( uint8_t* )( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a,
|
|
mc0, kc0,
|
|
&rs_a_use, &cs_a_use
|
|
);
|
|
a_use = pack_a_buffer_s8s8s32o32;
|
|
|
|
if( cs_a == 1 )
|
|
{
|
|
a_block_stride = kc0_updated;
|
|
}
|
|
|
|
else
|
|
{
|
|
a_block_stride = rs_a_use;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
a_use = a + ( rs_a * ic ) + ( cs_a * pc );
|
|
|
|
// Int8 kernel reads 4 elements, totalling 4 bytes in a
|
|
// single broadcast for use in vnni instruction.
|
|
// Non vnni based kernel requires update to this code.
|
|
cs_a_use = 4;
|
|
a_block_stride = rs_a;
|
|
}
|
|
|
|
post_ops_attr.b_sum_offset = 0;
|
|
|
|
for ( dim_t jr = 0; jr < nc0; jr += NR )
|
|
{
|
|
dim_t nr0 = bli_min( ( nc0 - jr ), NR );
|
|
|
|
// Post ops meta attributes.
|
|
post_ops_attr.post_op_c_i = ic;
|
|
post_ops_attr.post_op_c_j = ( jc + jr );
|
|
post_ops_attr.rs_c_downscale = rs_c_downscale;
|
|
|
|
grp_post_ops_attr.grp_post_op_j = jc + jr;
|
|
|
|
// The kernels are defined in zen4 folder
|
|
#ifdef BLIS_KERNELS_ZEN4
|
|
// Reorder/Packed B, Reorder/Packed/Unpacked A call.
|
|
lpgemm_rowvar_s8s8s32os32_6x64m_sym_quant
|
|
(
|
|
mc0, nr0, kc0,
|
|
a_use, rs_a_use, cs_a_use, a_block_stride,
|
|
( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use,
|
|
( c_use_ic + jr ), rs_c_use, 1,
|
|
alpha, beta0, grp_post_ops_attr,
|
|
post_op_list, post_ops_attr
|
|
);
|
|
#endif
|
|
post_ops_attr.b_sum_offset += NR;
|
|
}
|
|
}
|
|
}
|
|
if ( mtag_b == REORDERED )
|
|
{
|
|
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
|
|
}
|
|
}
|
|
|
|
// Release pack buffers.
|
|
if ( mtag_b == PACK )
|
|
{
|
|
// All threads in work group should wait till B matrix usage is
|
|
// completed by the participating threads.
|
|
bli_thrcomm_barrier
|
|
(
|
|
bli_thread_ocomm_id( &thread_jc ),
|
|
&thread->comm[bli_thread_work_id( &thread_jc)]
|
|
);
|
|
|
|
if ( bli_thread_am_ochief( &thread_ic ) )
|
|
{
|
|
if ( bli_mem_is_alloc( &mem_b ) )
|
|
{
|
|
bli_pba_release( rntm, &mem_b );
|
|
}
|
|
}
|
|
}
|
|
if ( mtag_a == PACK )
|
|
{
|
|
if ( bli_mem_is_alloc( &mem_a ) )
|
|
{
|
|
bli_pba_release( rntm, &mem_a );
|
|
}
|
|
}
|
|
if ( c_downscale < F32 )
|
|
{
|
|
if ( bli_mem_is_alloc( &mem_scale_c ) )
|
|
{
|
|
bli_pba_release( rntm, &mem_scale_c );
|
|
}
|
|
}
|
|
}
|