mirror of
https://github.com/amd/blis.git
synced 2026-04-19 15:18:52 +00:00
Deprecate S16 LPGEMM APIs.
-The following S16 APIs are removed: 1. aocl_gemm_u8s8s16os16 2. aocl_gemm_u8s8s16os8 3. aocl_gemm_u8s8s16ou8 4. aocl_gemm_s8s8s16os16 5. aocl_gemm_s8s8s16os8 along with the associated reorder APIs and corresponding framework elements. AMD-Internal: [CPUPL-6412] Change-Id: I251f8b02a4cba5110615ddeb977d86f5c949363b
This commit is contained in:
@@ -45,13 +45,10 @@
|
||||
#include "lpgemm_eltwise_ops_kernels.h"
|
||||
#include "lpgemm_utils_kernels.h"
|
||||
#include "lpgemm_pack_bf16.h"
|
||||
#include "lpgemm_packb_s16.h"
|
||||
#include "lpgemm_packa_s16.h"
|
||||
#include "lpgemm_packa.h"
|
||||
#include "lpgemm_packb.h"
|
||||
#include "lpgemm_packa_s8.h"
|
||||
#include "lpgemm_packb_s8.h"
|
||||
#include "lpgemm_packb_s8s16.h"
|
||||
#include "lpgemm_pack_f32.h"
|
||||
#include "lpgemm_jit_typedefs.h"
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
|
||||
@@ -51,10 +51,8 @@ BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \
|
||||
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s4s32os32);
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16s4f32of32);
|
||||
|
||||
@@ -76,11 +74,9 @@ BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \
|
||||
|
||||
AOCL_GEMM_REORDER(float,f32f32f32of32);
|
||||
AOCL_GEMM_REORDER(int8_t,u8s8s32os32);
|
||||
AOCL_GEMM_REORDER(int8_t,u8s8s16os16);
|
||||
AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32);
|
||||
AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32_reference);
|
||||
AOCL_GEMM_REORDER(int8_t,s8s8s32os32);
|
||||
AOCL_GEMM_REORDER(int8_t,s8s8s16os16);
|
||||
AOCL_GEMM_REORDER(int8_t,u8s4s32os32);
|
||||
AOCL_GEMM_REORDER(int8_t, bf16s4f32of32);
|
||||
|
||||
@@ -136,7 +132,6 @@ BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32);
|
||||
@@ -148,11 +143,6 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);
|
||||
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16);
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8);
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8);
|
||||
|
||||
AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
|
||||
AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
|
||||
AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32);
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_utils_s8.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s16os16", \
|
||||
order, transa, transb, \
|
||||
m, n, k, \
|
||||
( ( float ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
AOCL_GEMM_CHECK
|
||||
(
|
||||
"s8s8s16os16",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc,
|
||||
err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
|
||||
/* Perform BLAS parameter checking. */
|
||||
// Transpose not supported.
|
||||
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
|
||||
{
|
||||
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
if ( ( order != 'r' ) && ( order != 'R' ) )
|
||||
{
|
||||
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
|
||||
// Pack is enabled for row major storage when trans A is true.
|
||||
// Pack tranforms column major matrix to row-major storage as kernel
|
||||
// expects A matrix to be in row-major format.
|
||||
if ( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
mtag_a = PACK;
|
||||
}
|
||||
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in VNNI instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if (mtag_b == UNPACKED)
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// Only unpacked A supported now for row-major A matrix.
|
||||
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
|
||||
{
|
||||
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed, post_op_list,
|
||||
( void* )c, ( void* )( &order ),
|
||||
m, n
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
lpgemm_s8s8s16o16_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S16
|
||||
);
|
||||
#else
|
||||
lpgemm_s8s8s16o16_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils_s8.h"
|
||||
#include "lpgemm_reorder_s8s16.h"
|
||||
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16)
|
||||
{
|
||||
if ((k <= 0) || (n <= 0))
|
||||
{
|
||||
return 0; // Error.
|
||||
}
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
return 0; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
AOCL_MATRIX_TYPE input_mat_type;
|
||||
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
|
||||
|
||||
if (input_mat_type == A_MATRIX)
|
||||
{
|
||||
return 0; // A reorder not supported.
|
||||
}
|
||||
|
||||
|
||||
dim_t n_reorder;
|
||||
if( n == 1 )
|
||||
{
|
||||
n_reorder = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
n_reorder = make_multiple_of_n( n, 16 );
|
||||
|
||||
}
|
||||
|
||||
// Extra space since packing does length in multiples of 4.
|
||||
dim_t k_reorder;
|
||||
if( n == 1 )
|
||||
{
|
||||
k_reorder = k;
|
||||
}
|
||||
else
|
||||
{
|
||||
k_reorder = make_multiple_of_n( k, 4 );
|
||||
}
|
||||
|
||||
// Extra memory of n_reorder * sizeof( int16_t )
|
||||
// to store sum of every column of B matrix buffer
|
||||
siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder
|
||||
+ ( n_reorder * sizeof( int16_t ));
|
||||
|
||||
return size_req;
|
||||
}
|
||||
|
||||
AOCL_GEMM_REORDER(int8_t,s8s8s16os16)
|
||||
{
|
||||
if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) ||
|
||||
(k <= 0) || (n <= 0) || (ldb < n))
|
||||
{
|
||||
return; // Error.
|
||||
}
|
||||
|
||||
trans_t blis_trans;
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(trans, &blis_trans);
|
||||
|
||||
if( bli_is_trans( blis_trans ) )
|
||||
{
|
||||
bli_print_msg(" Transpose of matrix is not supported in "
|
||||
"s8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
return; // Error.
|
||||
}
|
||||
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
return; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
AOCL_MATRIX_TYPE input_mat_type;
|
||||
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
|
||||
|
||||
if (input_mat_type == A_MATRIX)
|
||||
{
|
||||
return; // A reorder not supported.
|
||||
}
|
||||
|
||||
if( n == 1 )
|
||||
{
|
||||
int16_t* pack_b_column_sum = ( int16_t* ) ( reorder_buf_addr +
|
||||
( sizeof( int8_t ) * n * k ));
|
||||
|
||||
*pack_b_column_sum = 0;
|
||||
|
||||
for( dim_t k0 = 0; k0 < k; k0++ )
|
||||
{
|
||||
reorder_buf_addr[k0] = input_buf_addr[ k0 * ldb ];
|
||||
*pack_b_column_sum += reorder_buf_addr[k0];
|
||||
}
|
||||
*pack_b_column_sum *= 128;
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 );
|
||||
|
||||
// Create dummy b_reorder obj.
|
||||
lpgemm_obj_t b_reorder;
|
||||
b_reorder.storage.aligned_buffer = reorder_buf_addr;
|
||||
|
||||
// Create dummy original b obj;
|
||||
lpgemm_obj_t b;
|
||||
b.storage.aligned_buffer = (void *)input_buf_addr;
|
||||
b.rs = ldb;
|
||||
b.width = n;
|
||||
b.length = k;
|
||||
|
||||
aocl_reorderb_nr32_s8s8s16o16( &b, &b_reorder, &rntm_g, lcntx_g );
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_utils_s8.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s16os8", \
|
||||
order, transa, transb, \
|
||||
m, n, k, \
|
||||
( ( float ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
AOCL_GEMM_CHECK
|
||||
(
|
||||
"s8s8s16os8",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc,
|
||||
err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
|
||||
/* Perform BLAS parameter checking. */
|
||||
// Transpose not supported.
|
||||
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
|
||||
{
|
||||
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
if ( ( order != 'r' ) && ( order != 'R' ) )
|
||||
{
|
||||
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
|
||||
// Pack is enabled for row major storage when trans A is true.
|
||||
// Pack tranforms column major matrix to row-major storage as kernel
|
||||
// expects A matrix to be in row-major format.
|
||||
if ( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
mtag_a = PACK;
|
||||
}
|
||||
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in VNNI instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if (mtag_b == UNPACKED)
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// Only unpacked A supported now for row-major A matrix.
|
||||
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
|
||||
{
|
||||
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed, post_op_list,
|
||||
( void* )c, ( void* )( &order ),
|
||||
m, n
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
lpgemm_s8s8s16o16_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
( int16_t* )c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S8
|
||||
);
|
||||
#else
|
||||
lpgemm_s8s8s16o16_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
( int16_t* )c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S8
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"u8s8s16os16", \
|
||||
order, transa, transb, \
|
||||
m, n, k, \
|
||||
( ( float ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
AOCL_GEMM_CHECK
|
||||
(
|
||||
"u8s8s16os16",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc,
|
||||
err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
|
||||
/* Perform BLAS parameter checking. */
|
||||
// Transpose not supported.
|
||||
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
|
||||
{
|
||||
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
if ( ( order != 'r' ) && ( order != 'R' ) )
|
||||
{
|
||||
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
|
||||
// Pack is enabled for row major storage when trans A is true.
|
||||
// Pack tranforms column major matrix to row-major storage as kernel
|
||||
// expects A matrix to be in row-major format.
|
||||
if ( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
mtag_a = PACK;
|
||||
}
|
||||
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in VNNI instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if (mtag_b == UNPACKED)
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// Only unpacked A supported now for row-major A matrix.
|
||||
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
|
||||
{
|
||||
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed, post_op_list,
|
||||
( void* )c, ( void* )( &order ),
|
||||
m, n
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
lpgemm_u8s8s16o16_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S16
|
||||
);
|
||||
#else
|
||||
lpgemm_u8s8s16o16_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_reorder_s16.h"
|
||||
|
||||
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16)
|
||||
{
|
||||
if ((k <= 0) || (n <= 0))
|
||||
{
|
||||
return 0; // Error.
|
||||
}
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
|
||||
return 0; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
AOCL_MATRIX_TYPE input_mat_type;
|
||||
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
|
||||
|
||||
if (input_mat_type == A_MATRIX)
|
||||
{
|
||||
return 0; // A reorder not supported.
|
||||
}
|
||||
|
||||
// Extra space since packing does width in multiples of 16. The vpmaddubsw
|
||||
// instruction can be used as long as at least one ymm register can be fully
|
||||
// loaded; and since k_dim needs to be at least 2, having n_dim at least 16
|
||||
// should give 2x16=32 elements, enough for 1 ymm register.The padding is
|
||||
// not rounded to NR (=16), since that would result in memory wastage.
|
||||
|
||||
dim_t n_reorder;
|
||||
if( n == 1 )
|
||||
{
|
||||
n_reorder = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
n_reorder = make_multiple_of_n( n, 16 );
|
||||
}
|
||||
|
||||
dim_t k_reorder;
|
||||
if( n == 1 )
|
||||
{
|
||||
k_reorder = k;
|
||||
}
|
||||
else
|
||||
{
|
||||
k_reorder = make_multiple_of_n( k, 2 );
|
||||
}
|
||||
|
||||
siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder;
|
||||
|
||||
return size_req;
|
||||
}
|
||||
|
||||
// Reorders the int8_t B matrix (k x n, leading dim ldb) into the packed
// layout consumed by the u8s8s16os16 kernels. On any validation failure the
// function returns without writing to reorder_buf_addr.
// Parameters (mat_type, input_buf_addr, reorder_buf_addr, k, n, ldb) are
// supplied by the AOCL_GEMM_REORDER macro declared in the interface header.
AOCL_GEMM_REORDER(int8_t,u8s8s16os16)
{
	// Validate pointers and dimensions before touching any memory.
	if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) ||
	     ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) )
	{
		return; // Error.
	}

	// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
	if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
	{
		bli_print_msg(" AVX2 ISA not supported by processor, "
				"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
		return; // Error.
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	// Only a B-matrix reorder is supported by this API.
	AOCL_MATRIX_TYPE input_mat_type;
	bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type );
	if ( input_mat_type == A_MATRIX )
	{
		return; // A reorder not supported.
	}

	// GEMV case: the "reorder" degenerates to copying the single column of
	// B, contiguously when ldb == 1 and element-by-element otherwise.
	if ( n == 1 )
	{
		if ( ldb == 1 )
		{
			memcpy( reorder_buf_addr, input_buf_addr,
			        ( k * sizeof( int8_t ) ) );
		}
		else
		{
			for ( dim_t kr = 0; kr < k; kr++ )
			{
				reorder_buf_addr[kr] = input_buf_addr[kr * ldb];
			}
		}
		return;
	}

	// Initialize a local runtime with global settings if necessary. Note
	// that in the case that a runtime is passed in, we make a local copy.
	rntm_t rntm_l;
	bli_rntm_init_from_global( &rntm_l );
	bli_pba_rntm_set_pba( &rntm_l );

	lpgemm_cntx_t* lcntx = lpgemm_get_global_cntx_obj( U8S8S16OS16 );

	// Wrap the raw source buffer in a dummy lpgemm object carrying the
	// matrix geometry the reorder routine expects.
	lpgemm_obj_t b_src;
	b_src.storage.aligned_buffer = ( void* )input_buf_addr;
	b_src.rs = ldb;
	b_src.width = n;
	b_src.length = k;

	// Destination object only needs the output buffer pointer.
	lpgemm_obj_t b_dst;
	b_dst.storage.aligned_buffer = reorder_buf_addr;

	aocl_reorderb_nr32_u8s8s16o16( &b_src, &b_dst, &rntm_l, lcntx );
}
|
||||
@@ -1,199 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
// LPGEMM entry point: C := beta*C + alpha*(A*B) with uint8_t A, int8_t B,
// int16_t accumulation, and the result downscaled to int8_t output (the S8
// tag passed to the thread decorator selects the downconversion).
// The parameter list (order, transa, transb, m, n, k, alpha, a, lda,
// mem_format_a, b, ldb, mem_format_b, beta, c, ldc, post_op_unparsed) is
// generated by the AOCL_GEMM_MATMUL macro — see aocl_gemm_interface_apis.h.
// All error paths jump to err_hndl so the logger is always stopped.
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
{
	LPGEMM_START_LOGGER();
	LPGEMM_WRITE_LOGGER \
	(
	  "u8s8s16os8", \
	  order, transa, transb, \
	  m, n, k, \
	  ( ( float ) alpha ), \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  ( ( float ) beta ), \
	  ldc, post_op_unparsed \
	);
	trans_t blis_transa;
	trans_t blis_transb;

	// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
	if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
	{
		bli_print_msg(" AVX2 ISA not supported by processor, "
				"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	// check for validity of params.
	int err_no = 0;
	AOCL_GEMM_CHECK
	(
	  "u8s8s16os8",
	  order, transa, transb,
	  m, n, k,
	  a, lda, mem_format_a,
	  b, ldb, mem_format_b,
	  c, ldc,
	  err_no
	);
	if ( err_no != 0 )
	{
		goto err_hndl;
	}

	/* Map BLAS chars to their corresponding BLIS enumerated type value. */
	bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
	bli_param_map_netlib_to_blis_trans(transb, &blis_transb);

	/* Perform BLAS parameter checking. */
	// Transpose not supported.
	if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
	{
		bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	if ( ( order != 'r' ) && ( order != 'R' ) )
	{
		bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	// Row-major strides; column strides are 1 since B transpose and
	// column-major order were rejected above.
	inc_t rs_a = lda;
	inc_t cs_a = 1;
	inc_t rs_b = ldb;
	inc_t cs_b = 1;
	const inc_t rs_c = ldc;
	const inc_t cs_c = 1;

	AOCL_MEMORY_TAG mtag_a;
	AOCL_MEMORY_TAG mtag_b;

	bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
	bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);

	// Pack is enabled for row major storage when trans A is true.
	// Pack tranforms column major matrix to row-major storage as kernel
	// expects A matrix to be in row-major format.
	if ( bli_is_trans( blis_transa ) )
	{
		rs_a = 1;
		cs_a = lda;
		mtag_a = PACK;
	}

	// B matrix needs to be packed in a certain format in order to be loaded
	// and used in VNNI instrution. As such the mtag_b always needs to be either
	// packed or reordered. B matrix as it is (unpacked) cannot be used, and
	// the mtag_b is set to packed to enable runtime packing.
	if (mtag_b == UNPACKED)
	{
		mtag_b = PACK;
	}

	// Only unpacked A supported now for row-major A matrix.
	if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
	{
		bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	// Convert post op struct to post op linked list format.
	lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
	err_t err = lpgemm_translate_to_post_ops_list
	(
	  post_op_unparsed, post_op_list,
	  ( void* )c, ( void* )( &order ),
	  m, n
	);

	if( err != BLIS_SUCCESS )
	{
		goto err_hndl;
	}

	// Initialize a local runtime with global settings if necessary. Note
	// that in the case that a runtime is passed in, we make a local copy.
	rntm_t rntm_g;
	bli_rntm_init_from_global(&rntm_g);
	bli_pba_rntm_set_pba(&rntm_g);

	// os8 reuses the s16-output context; only the downscale tag differs.
	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );

	// Dispatch the 5-loop GEMM; the S8 tag requests int8_t output downscale.
#ifdef BLIS_ENABLE_OPENMP
	lpgemm_u8s8s16o16_openmp_thread_decorator
	(
	  m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  ( int16_t* )c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, S8
	);
#else
	lpgemm_u8s8s16o16_thread_decorator
	(
	  m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  ( int16_t* )c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, S8
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}
|
||||
@@ -1,199 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
// LPGEMM entry point: C := beta*C + alpha*(A*B) with uint8_t A, int8_t B,
// int16_t accumulation, and the result downscaled to uint8_t output (the U8
// tag passed to the thread decorator selects the downconversion).
// The parameter list (order, transa, transb, m, n, k, alpha, a, lda,
// mem_format_a, b, ldb, mem_format_b, beta, c, ldc, post_op_unparsed) is
// generated by the AOCL_GEMM_MATMUL macro — see aocl_gemm_interface_apis.h.
// All error paths jump to err_hndl so the logger is always stopped.
AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
{
	LPGEMM_START_LOGGER();
	LPGEMM_WRITE_LOGGER \
	(
	  "u8s8s16ou8", \
	  order, transa, transb, \
	  m, n, k, \
	  ( ( float ) alpha ), \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  ( ( float ) beta ), \
	  ldc, post_op_unparsed \
	);
	trans_t blis_transa;
	trans_t blis_transb;

	// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
	if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
	{
		bli_print_msg(" AVX2 ISA not supported by processor, "
				"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	// check for validity of params.
	int err_no = 0;
	AOCL_GEMM_CHECK
	(
	  "u8s8s16ou8",
	  order, transa, transb,
	  m, n, k,
	  a, lda, mem_format_a,
	  b, ldb, mem_format_b,
	  c, ldc,
	  err_no
	);
	if ( err_no != 0 )
	{
		goto err_hndl;
	}

	/* Map BLAS chars to their corresponding BLIS enumerated type value. */
	bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
	bli_param_map_netlib_to_blis_trans(transb, &blis_transb);

	/* Perform BLAS parameter checking. */
	// Transpose not supported.
	if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
	{
		bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	if ( ( order != 'r' ) && ( order != 'R' ) )
	{
		bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	// Row-major strides; column strides are 1 since B transpose and
	// column-major order were rejected above.
	inc_t rs_a = lda;
	inc_t cs_a = 1;
	inc_t rs_b = ldb;
	inc_t cs_b = 1;
	const inc_t rs_c = ldc;
	const inc_t cs_c = 1;

	AOCL_MEMORY_TAG mtag_a;
	AOCL_MEMORY_TAG mtag_b;

	bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
	bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);

	// Pack is enabled for row major storage when trans A is true.
	// Pack tranforms column major matrix to row-major storage as kernel
	// expects A matrix to be in row-major format.
	if ( bli_is_trans( blis_transa ) )
	{
		rs_a = 1;
		cs_a = lda;
		mtag_a = PACK;
	}

	// B matrix needs to be packed in a certain format in order to be loaded
	// and used in VNNI instrution. As such the mtag_b always needs to be either
	// packed or reordered. B matrix as it is (unpacked) cannot be used, and
	// the mtag_b is set to packed to enable runtime packing.
	if (mtag_b == UNPACKED)
	{
		mtag_b = PACK;
	}

	// Only unpacked A supported now for row-major A matrix.
	if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
	{
		bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	// Convert post op struct to post op linked list format.
	lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
	err_t err = lpgemm_translate_to_post_ops_list
	(
	  post_op_unparsed, post_op_list,
	  ( void* )c, ( void* )( &order ),
	  m, n
	);

	if( err != BLIS_SUCCESS )
	{
		goto err_hndl;
	}

	// Initialize a local runtime with global settings if necessary. Note
	// that in the case that a runtime is passed in, we make a local copy.
	rntm_t rntm_g;
	bli_rntm_init_from_global(&rntm_g);
	bli_pba_rntm_set_pba(&rntm_g);

	// ou8 reuses the s16-output context; only the downscale tag differs.
	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );

	// Dispatch the 5-loop GEMM; the U8 tag requests uint8_t output downscale.
#ifdef BLIS_ENABLE_OPENMP
	lpgemm_u8s8s16o16_openmp_thread_decorator
	(
	  m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  ( int16_t* )c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, U8
	);
#else
	lpgemm_u8s8s16o16_thread_decorator
	(
	  m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  ( int16_t* )c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, U8
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}
|
||||
@@ -39,23 +39,19 @@
|
||||
// ID = One of the AOCL_OPERATION_TYPE enum.
|
||||
|
||||
#define LPGEMM_BLKSZ_MAP_ZEN4 \
|
||||
XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(F32F32F32OF32, 192, 8064, 512, 6, 64, 1, 6, 64, 1) \
|
||||
XMACRO(BF16BF16F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(BF16S4F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(F32OBF16, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
|
||||
#define LPGEMM_BLKSZ_MAP_ZEN \
|
||||
XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
|
||||
XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
|
||||
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
|
||||
XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
XMACRO(F32OBF16, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
|
||||
|
||||
@@ -38,13 +38,10 @@
|
||||
#include "lpgemm_blksz_map.h"
|
||||
#include "lpgemm_kernels.h"
|
||||
#include "lpgemm_pack_bf16.h"
|
||||
#include "lpgemm_packb_s16.h"
|
||||
#include "lpgemm_packa_s16.h"
|
||||
#include "lpgemm_packa.h"
|
||||
#include "lpgemm_packb.h"
|
||||
#include "lpgemm_packa_s8.h"
|
||||
#include "lpgemm_packb_s8.h"
|
||||
#include "lpgemm_packb_s8s16.h"
|
||||
#include "lpgemm_pack_f32.h"
|
||||
#include "lpgemm_logger.h"
|
||||
#include "lpgemm_thread_utils.h"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -46,32 +46,26 @@
|
||||
|
||||
// AVX512 + VNNI + BF16
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \
|
||||
KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
|
||||
KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
|
||||
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
|
||||
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
|
||||
PBMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
|
||||
|
||||
@@ -85,11 +79,9 @@
|
||||
UBMACRO(BF16BF16F32OF32, unpackb_nr64_bf16bf16f32of32)
|
||||
|
||||
#define LPGEMM_PACKSCLB_FUNC_MAP_AVX512_VNNI_BF16 \
|
||||
PBSMACRO(U8S8S16OS16, NULL) \
|
||||
PBSMACRO(U8S8S32OS32, NULL) \
|
||||
PBSMACRO(BF16BF16F32OF32, NULL) \
|
||||
PBSMACRO(S8S8S32OS32, NULL) \
|
||||
PBSMACRO(S8S8S16OS16, NULL) \
|
||||
PBSMACRO(U8S4S32OS32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, packsclb_nr64_bf16s4f32of32) \
|
||||
|
||||
@@ -105,35 +97,29 @@
|
||||
|
||||
// AVX512 + VNNI
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \
|
||||
KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
|
||||
KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
|
||||
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
|
||||
|
||||
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI \
|
||||
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
|
||||
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
|
||||
PBSMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
|
||||
|
||||
@@ -148,32 +134,26 @@
|
||||
|
||||
// AVX512
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX512 \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \
|
||||
KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
|
||||
KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
|
||||
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_TO_AVX2 \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
|
||||
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX512 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, NULL) \
|
||||
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
|
||||
PBMACRO(BF16S4F32OF32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, NULL)
|
||||
@@ -189,30 +169,24 @@
|
||||
|
||||
// AVX2
|
||||
#define LPGEMM_KERN_FUNC_MAP_AVX2 \
|
||||
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
|
||||
KMACRO(U8S8S32OS32, NULL) \
|
||||
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
|
||||
KMACRO(BF16BF16F32OF32, NULL) \
|
||||
KMACRO(BF16S4F32OF32, NULL) \
|
||||
KMACRO(S8S8S32OS32, NULL) \
|
||||
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
|
||||
|
||||
#define LPGEMM_PACKA_FUNC_MAP_AVX2 \
|
||||
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
|
||||
PAMACRO(U8S8S32OS32, NULL) \
|
||||
PAMACRO(BF16BF16F32OF32, NULL) \
|
||||
KMACRO(BF16S4F32OF32, NULL) \
|
||||
PAMACRO(S8S8S32OS32, NULL) \
|
||||
PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \
|
||||
|
||||
#define LPGEMM_PACKB_FUNC_MAP_AVX2 \
|
||||
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
|
||||
PBMACRO(U8S8S32OS32, NULL) \
|
||||
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
|
||||
PBMACRO(BF16BF16F32OF32, NULL) \
|
||||
KMACRO(BF16S4F32OF32, NULL) \
|
||||
PBMACRO(S8S8S32OS32, NULL) \
|
||||
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
|
||||
PBMACRO(U8S4S32OS32, NULL) \
|
||||
PBSMACRO(BF16S4F32OF32, NULL) \
|
||||
|
||||
|
||||
@@ -93,11 +93,9 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32);
|
||||
LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16);
|
||||
LPGEMM_5LOOP(float,float,float,f32f32f32of32);
|
||||
LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32);
|
||||
LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16);
|
||||
|
||||
#define LPGEMM_5LOOP1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
#include "blis.h"
|
||||
#include "lpgemm_utils_s8.h"
|
||||
#include "lpgemm_reorder_s8s16.h"
|
||||
#include "lpgemm_packb_s8s16.h"
|
||||
#include "lpgemm_config.h"
|
||||
|
||||
// Multi-threaded reorder of the int8_t B matrix into the packed layout used
// by the s8s8s16o16 kernels. Each thread packs a disjoint range of NR-wide
// column panels; a trailing row of int16_t column sums is appended after the
// packed int8_t data (needed for the s8s8 zero-point correction).
// NOTE(review): rs_b_reorder/cs_b_reorder are written inside the parallel
// region by every thread and read after it — presumably all threads store
// the same values; confirm. They are also only assigned if k > 0.
void aocl_reorderb_nr32_s8s8s16o16
     (
       lpgemm_obj_t* b,
       lpgemm_obj_t* b_reorder,
       rntm_t* rntm,
       lpgemm_cntx_t* lcntx
     )
{
	// Cache blocking sizes from the lpgemm context.
	dim_t NC = lcntx->blksz.NC;
	dim_t KC = lcntx->blksz.KC;
	dim_t NR = lcntx->blksz.NR;

	// Extracting the matrix properties from the lpgemm object
	dim_t rs_b = b->rs;
	dim_t n = b->width;
	dim_t k = b->length;

	// Clamp NC/KC for the actual problem size (m is unused here, hence 0).
	lpgemm_mod_block_size_s16(0, n, k, NULL, &NC, &KC);

	// Output strides, filled in by the packb kernel below.
	dim_t rs_b_reorder;
	dim_t cs_b_reorder;

	dim_t k_updated = k;

	// Making multiple of 2 to suit k in vpmaddubsw
	k_updated += (k_updated & 0x1);

	// Width padded to the 16-column packing granularity.
	dim_t n_updated = make_multiple_of_n( n, 16 );

	dim_t n_threads = bli_rntm_num_threads( rntm );
	n_threads = ( n_threads > 0 ) ? n_threads : 1;

	// To access the last row of B matrix - Column sum of B matrix
	// NOTE(review): arithmetic on the void* aligned_buffer relies on the
	// GNU extension treating sizeof(void) as 1 — confirm compiler flags.
	int16_t* pack_b_column_sum = ( int16_t* ) ( b_reorder->storage.aligned_buffer + ( sizeof( int8_t ) * n_updated * k_updated ));
	// Zero the column-sum row; the packb kernel accumulates into it.
	for (dim_t idx = 0; idx < n_updated; idx++ )
	{
		*( pack_b_column_sum + idx ) = 0;
	}

	// The OpenMP and serial builds share one loop body; the #ifdef only
	// swaps the parallel region header and the thrinfo initialization.
#ifdef BLIS_ENABLE_OPENMP
	_Pragma( "omp parallel num_threads(n_threads)" )
	{
		// Initialise a local thrinfo obj for work split across threads.
		thrinfo_t thread_jc;
		bli_thrinfo_set_n_way( n_threads, &thread_jc );
		bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc );
#else
	{
		// Initialise a local thrinfo obj for work split across threads.
		thrinfo_t thread_jc;
		bli_thrinfo_set_n_way( 1, &thread_jc );
		bli_thrinfo_set_work_id( 0, &thread_jc );
#endif
		// Compute the JC loop thread range for the current thread.
		dim_t jc_start, jc_end;
		bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );

		for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
		{
			dim_t nc0 = bli_min( ( jc_end - jc ), NC );

			dim_t jc_cur_loop = jc;
			dim_t jc_cur_loop_rem = 0;
			dim_t n_sub_updated;

			// Align this thread's panel start to the KCxNC reorder-panel
			// grid (16-column granularity) and get the effective widths.
			get_B_panel_reordered_start_offset_width
			(
			  jc, n, NC, 16,
			  &jc_cur_loop, &jc_cur_loop_rem,
			  &nc0, &n_sub_updated
			);

			for ( dim_t pc = 0; pc < k; pc += KC )
			{
				dim_t kc0 = bli_min( ( k - pc ), KC );

				// kc0 needs to be a multiple of 2 so that it can be used with
				// vmaddubsw instruction. Padding is added in cases this
				// condition is not satisfied, and therefore the kc0 offsets
				// used for packed/reordered buffers needs to be updated.
				dim_t kc0_updated = make_multiple_of_n( kc0, 2 );

				// The offsets are calculated in such a way that it resembles
				// the reorder buffer traversal in single threaded reordering.
				// The panel boundaries (KCxNC) remain as it is accessed in
				// single thread, and as a consequence a thread with jc_start
				// inside the panel cannot consider NC range for reorder. It
				// has to work with NC' < NC, and the offset is calulated using
				// prev NC panels spanning k dim + cur NC panel spaning pc loop
				// cur iteration + (NC - NC') spanning current kc0 (<= KC).
				//
				//Eg: Consider the following reordered buffer diagram:
				//          t1              t2
				//          |               |
				//          |           |..NC..|
				//          |           |      |
				//          |.NC. |.NC. |NC'|NC"
				//     pc=0-+-----+-----+---+--+
				//        KC|     |     |   |  |
				//          | 1   | 3   | 5 |  |
				//    pc=KC-+-----+-----+---st-+
				//        KC|     |     |   |  |
				//          | 2   | 4   | 6 | 7|
				// pc=k=2KC-+-----+-----+---+--+
				//          |jc=0 |jc=NC|jc=2NC|
				//
				// The numbers 1,2..6,7 denotes the order in which reordered
				// KCxNC blocks are stored in memory, ie: block 1 followed by 2
				// followed by 3, etc. Given two threads t1 and t2, and t2 needs
				// to acces point st in the reorder buffer to write the data:
				// The offset calulation logic will be:
				// jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC,
				// n_sub_updated = NC, k = 2KC, kc0_updated = KC
				//
				// st = ( jc_cur_loop * k )    <traverse blocks 1,2,3,4>
				//    + ( n_sub_updated * pc ) <traverse block 5>
				//    + ( NC' * kc0_updated)   <traverse block 6>
				( ( packb_s16_s8 )lcntx->packb_fun_ptr )
				(
				  ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) +
				    ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
				    ( jc_cur_loop_rem * kc0_updated ) ),
				  pack_b_column_sum + jc,
				  ( ( ( int8_t* )b->storage.aligned_buffer ) +
				    ( rs_b * pc ) + jc ),
				  rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder
				);
			}

			// Skip past the 16-column panels this thread just covered.
			adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
		}
	}

	// Changing the packed matrix properties in the packed matrix object
	b_reorder->rs = rs_b_reorder;
	b_reorder->cs = cs_b_reorder;
	b_reorder->mtag = REORDERED;
}
|
||||
@@ -1,47 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
#ifndef LPGEMM_REORDER_S8S16_H
|
||||
#define LPGEMM_REORDER_S8S16_H
|
||||
|
||||
#include "lpgemm_types.h"
|
||||
|
||||
void aocl_reorderb_nr32_s8s8s16o16
|
||||
(
|
||||
lpgemm_obj_t* b,
|
||||
lpgemm_obj_t* b_reorder,
|
||||
rntm_t* rntm,
|
||||
lpgemm_cntx_t* lcntx
|
||||
);
|
||||
|
||||
#endif // LPGEMM_REORDER_S8S16_H
|
||||
@@ -1,609 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_packb_s8s16.h"
|
||||
#include "lpgemm_kernels.h"
|
||||
#include "lpgemm_utils_s8.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_thrinfo_utils.h"
|
||||
#include "lpgemm_packa_s16.h"
|
||||
|
||||
// Kernel function prototypes
|
||||
typedef void (*lpgemm_rowvar_s16_s8)
|
||||
(
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const int8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const int8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
int16_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const int16_t,
|
||||
const int16_t,
|
||||
lpgemm_post_op*,
|
||||
lpgemm_post_op_attr
|
||||
);
|
||||
|
||||
|
||||
|
||||
LPGEMV(int8_t,int8_t,int16_t,s8s8s16os16)
|
||||
{
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t MC = lcntx->blksz.MC;
|
||||
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
int8_t* a_use = ( int8_t* )a;
|
||||
inc_t rs_a_use = rs_a;
|
||||
inc_t cs_a_use = cs_a;
|
||||
|
||||
int8_t* b_use = ( int8_t* )b;
|
||||
inc_t rs_b_use = rs_b;
|
||||
inc_t cs_b_use = cs_b;
|
||||
|
||||
int16_t *c_use = NULL;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if (c_downscale < S16) post_ops_attr.buf_downscale = c;
|
||||
else post_ops_attr.buf_downscale = NULL;
|
||||
|
||||
siz_t mem_a_size_req = 0;
|
||||
siz_t mem_b_size_req = 0;
|
||||
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
|
||||
int8_t* pack_a_buffer;
|
||||
int8_t* pack_b_buffer;
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
thrinfo_t thread_ic;
|
||||
|
||||
lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );
|
||||
|
||||
// Increased MR from 6 to 8 to make use of 16 ymm regs
|
||||
dim_t MR = 8;
|
||||
|
||||
// Pack B matrix if rs_b > 1
|
||||
if( ( mtag_b == PACK ) )
|
||||
{
|
||||
mem_b_size_req = sizeof( int8_t ) * k + sizeof( int16_t );
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_b, rntm
|
||||
);
|
||||
|
||||
pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b );
|
||||
|
||||
int16_t* pack_b_column_sum = ( int16_t* ) ( pack_b_buffer +
|
||||
( sizeof( int8_t ) * k ));
|
||||
|
||||
*pack_b_column_sum = 0;
|
||||
|
||||
for( dim_t k0 = 0; k0 < k; k0++ )
|
||||
{
|
||||
pack_b_buffer[k0] = b[ k0*rs_b ];
|
||||
*pack_b_column_sum += pack_b_buffer[k0];
|
||||
}
|
||||
*pack_b_column_sum *= 128;
|
||||
post_ops_attr.b_col_sum_vec_s16 = pack_b_column_sum;
|
||||
|
||||
b_use = pack_b_buffer;
|
||||
rs_b_use = 1;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
else if ( mtag_b == REORDERED )
|
||||
{
|
||||
post_ops_attr.b_col_sum_vec_s16 = ( int16_t* ) ( b + k );
|
||||
}
|
||||
|
||||
// Compute the IC loop thread range for the current thread.
|
||||
dim_t ic_start, ic_end;
|
||||
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_ic.n_way );
|
||||
thread_ic.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
||||
|
||||
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
|
||||
{
|
||||
dim_t mc0 = bli_min((ic_end - ic), MC);
|
||||
|
||||
a_use = (int8_t*)a + ic * rs_a;
|
||||
|
||||
c_use = c + ic * rs_c;
|
||||
|
||||
post_ops_attr.post_op_c_i = ic;
|
||||
post_ops_attr.post_op_c_j = 0;
|
||||
post_ops_attr.rs_c_downscale = rs_c;
|
||||
|
||||
if( mtag_a == PACK )
|
||||
{
|
||||
mem_a_size_req = sizeof( int8_t ) * mc0 * k;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_a, rntm
|
||||
);
|
||||
|
||||
pack_a_buffer = ( int8_t* ) bli_mem_buffer( &mem_a );
|
||||
|
||||
( ( packa_s16 ) lcntx->packa_fun_ptr )
|
||||
(
|
||||
( uint8_t* )pack_a_buffer,
|
||||
( uint8_t* )( a + ( rs_a * ic )), rs_a, cs_a,
|
||||
mc0, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
a_use = pack_a_buffer;
|
||||
}
|
||||
|
||||
// Call lpgemv_n_one kernel
|
||||
lpgemv_n_one_s8s8s16os16
|
||||
(
|
||||
mc0, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b,
|
||||
c_use, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
MR, KC,
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
);
|
||||
}
|
||||
|
||||
// Release pack buffers
|
||||
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
|
||||
{
|
||||
bli_pba_release(rntm, &mem_a);
|
||||
}
|
||||
if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
|
||||
{
|
||||
bli_pba_release(rntm, &mem_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// B should always be packed.
|
||||
LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16)
|
||||
{
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
const dim_t MR = lcntx->blksz.MR;
|
||||
|
||||
lpgemm_mod_block_size_s16(m, n, k, &MC, &NC, &KC);
|
||||
|
||||
if (mtag_b == UNPACKED)
|
||||
{
|
||||
// Error: can only work with packed B now.
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
if( n == 1 )
|
||||
{
|
||||
lpgemv_rowvar_s8s8s16os16( m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha,
|
||||
beta,
|
||||
rntm,
|
||||
thread,
|
||||
lcntx,
|
||||
post_op_list,
|
||||
c_downscale );
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const int8_t *b_use;
|
||||
const int8_t *a_use;
|
||||
dim_t rs_a_use = rs_a;
|
||||
dim_t cs_a_use = cs_a;
|
||||
dim_t a_block_stride = 0;
|
||||
|
||||
dim_t rs_b_use = rs_b;
|
||||
dim_t cs_b_use = cs_b;
|
||||
|
||||
int16_t *c_use_jc = NULL;
|
||||
int16_t *c_use_ic = NULL;
|
||||
dim_t rs_c_use = rs_c;
|
||||
dim_t rs_c_downscale = rs_c;
|
||||
|
||||
// Pack buffer for A.
|
||||
int8_t* pack_a_buffer_s8s8s16o16;
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_a_size_req = 0;
|
||||
|
||||
// Pack buffer for B.
|
||||
int8_t *pack_b_buffer_s8s8s16o16;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
dim_t packb_min_NR = 16;
|
||||
siz_t mem_b_size_req = 0;
|
||||
|
||||
// Temporary buffer for C accumulation when downscaling is required.
|
||||
int16_t* temp_scal_c_buffer_s8s8s16o16;
|
||||
mem_t mem_scale_c = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_scale_c_size_req = 0;
|
||||
|
||||
// Making multiple of 2 to suit k in vpmaddubsw
|
||||
dim_t k_updated = make_multiple_of_n( k, 2 );
|
||||
|
||||
// Making multiple of 16
|
||||
dim_t n_updated = make_multiple_of_n( n, 16 );
|
||||
|
||||
// To decide whether to apply post ops or not.
|
||||
bool is_last_k = FALSE;
|
||||
|
||||
// To decide whether to use original s8 C or temp buffer for beta scale.
|
||||
bool is_first_k = FALSE;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if ( c_downscale < S16 )
|
||||
{
|
||||
post_ops_attr.buf_downscale = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
post_ops_attr.buf_downscale = NULL;
|
||||
}
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
thrinfo_t thread_ic;
|
||||
|
||||
lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic);
|
||||
|
||||
// Compute the JC, IC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
|
||||
|
||||
dim_t ic_start, ic_end;
|
||||
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
||||
|
||||
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
|
||||
{
|
||||
dim_t nc0 = bli_min((jc_end - jc), NC);
|
||||
|
||||
dim_t jc_cur_loop = jc;
|
||||
dim_t jc_cur_loop_rem = 0;
|
||||
dim_t n_sub_updated = 0;
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
{
|
||||
get_B_panel_reordered_start_offset_width
|
||||
(
|
||||
jc, n, NC, packb_min_NR,
|
||||
&jc_cur_loop, &jc_cur_loop_rem,
|
||||
&nc0, &n_sub_updated
|
||||
);
|
||||
}
|
||||
|
||||
if ( c_downscale == S16 )
|
||||
{
|
||||
c_use_jc = c + jc;
|
||||
}
|
||||
// Temp accumulaton buffer for C allocation.
|
||||
else if ( c_downscale < S16 )
|
||||
{
|
||||
// Buffer memory is only required if output needs to be
|
||||
// persisted across iterations of the pc/KC loop.
|
||||
// It was observed that the locks used while checking out
|
||||
// a buffer from memory pool had an impact on performance
|
||||
// and is better to not checkout if k <= KC.
|
||||
if ( k > KC )
|
||||
{
|
||||
mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start );
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_scale_c, rntm
|
||||
);
|
||||
|
||||
temp_scal_c_buffer_s8s8s16o16 = bli_mem_buffer( &mem_scale_c );
|
||||
|
||||
c_use_jc = ( int16_t* )temp_scal_c_buffer_s8s8s16o16;
|
||||
}
|
||||
|
||||
// The temp c buffer stride is modified as opposed to original C matrix.
|
||||
rs_c_use = nc0;
|
||||
}
|
||||
|
||||
int16_t* pack_b_column_sum = NULL;
|
||||
|
||||
for (dim_t pc = 0; pc < k; pc += KC)
|
||||
{
|
||||
int16_t beta0 = (pc == 0) ? beta : 1;
|
||||
dim_t kc0 = bli_min((k - pc), KC);
|
||||
|
||||
// No parallelization in k dim, k always starts at 0.
|
||||
is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE );
|
||||
post_ops_attr.is_first_k = is_first_k;
|
||||
|
||||
is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
|
||||
post_ops_attr.is_last_k = is_last_k;
|
||||
|
||||
// kc0 needs to be a multiple of 2 so that it can be
|
||||
// used with vpmaddubsw instruction. Padding is added in
|
||||
// cases this condition is not satisfied, and therefore
|
||||
// the kc0 offsets used for packed/reordered buffers
|
||||
// needs to be updated.
|
||||
dim_t kc0_updated = make_multiple_of_n(kc0, 2);
|
||||
|
||||
if (mtag_b == PACK)
|
||||
{
|
||||
// Pack B chunks are based on jc work id.
|
||||
dim_t jc_work_id = bli_thread_work_id(&thread_jc);
|
||||
|
||||
// Using child thrinfo (thread_ic) tid to decide chief thread
|
||||
// per B matrix chunk (jc work id group)
|
||||
|
||||
// nc0 needs to be a multiple of 16 since this gives maximum
|
||||
// vectorization. Packing B always results in buffers with width
|
||||
// which is a multiple of 16. Subsequently the nc0 offsets used
|
||||
// for packed/reordered buffers needs to be updated.
|
||||
dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR);
|
||||
|
||||
if (bli_thread_am_ochief(&thread_ic))
|
||||
{
|
||||
mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated + ( nc0_updated * sizeof( int16_t ) );
|
||||
|
||||
lpgemm_alloc_mem_panel(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
|
||||
&mem_b, rntm);
|
||||
|
||||
thread->comm[jc_work_id].sent_object =
|
||||
bli_mem_buffer(&mem_b);
|
||||
}
|
||||
|
||||
// All threads in work group should wait till chief thread has
|
||||
// finished allocating the packing buffers.
|
||||
bli_thrcomm_barrier
|
||||
(
|
||||
bli_thread_ocomm_id(&thread_ic),
|
||||
&thread->comm[jc_work_id]
|
||||
);
|
||||
|
||||
pack_b_buffer_s8s8s16o16 =
|
||||
(int8_t *)thread->comm[jc_work_id].sent_object;
|
||||
|
||||
// Compute the B panel per thread loop range for parallel
|
||||
// packing using ic_ways number of threads. Since atmost only
|
||||
// ic_ways threads can be used, the thread_ic attributes are
|
||||
// used to split the loop range.
|
||||
dim_t jc_packb_start, jc_packb_end;
|
||||
bli_thread_range_sub
|
||||
(
|
||||
&thread_ic, nc0, NR, FALSE,
|
||||
&jc_packb_start, &jc_packb_end
|
||||
);
|
||||
|
||||
if ( pc == 0)
|
||||
{
|
||||
pack_b_column_sum = ( int16_t* )( pack_b_buffer_s8s8s16o16 + ( sizeof( int8_t ) * nc0_updated * kc0_updated ) );
|
||||
}
|
||||
|
||||
// Ensure thread ranges are valid, especially cases where no:
|
||||
// of threads available for parallelization are greater than
|
||||
// no: of B panel NR chunks.
|
||||
if ((jc_packb_end > jc_packb_start) &&
|
||||
(jc_packb_start < (jc + nc0)))
|
||||
{
|
||||
if ( pc == 0 )
|
||||
{
|
||||
for (int idx = jc_packb_start; idx < jc_packb_end; idx++ )
|
||||
{
|
||||
*( pack_b_column_sum + idx ) = 0;
|
||||
}
|
||||
}
|
||||
|
||||
( ( packb_s16_s8 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
pack_b_buffer_s8s8s16o16 +
|
||||
(jc_packb_start * kc0_updated),
|
||||
pack_b_column_sum + ( cs_b * jc_packb_start ),
|
||||
(b + (rs_b * pc) + (cs_b * jc) +
|
||||
(cs_b * jc_packb_start)),
|
||||
rs_b,
|
||||
(jc_packb_end - jc_packb_start), kc0,
|
||||
&rs_b_use, &cs_b_use
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
||||
}
|
||||
|
||||
// All threads in work group should wait till B matrix packing
|
||||
// is completed by the participating threads.
|
||||
bli_thrcomm_barrier
|
||||
(
|
||||
bli_thread_ocomm_id(&thread_ic),
|
||||
&thread->comm[jc_work_id]
|
||||
);
|
||||
|
||||
b_use = pack_b_buffer_s8s8s16o16;
|
||||
post_ops_attr.b_col_sum_vec_s16 = pack_b_column_sum;
|
||||
}
|
||||
else if (mtag_b == REORDERED)
|
||||
{
|
||||
// In multi-threaded scenarios, an extra offset into a given
|
||||
// packed B panel is required, since the jc loop split can
|
||||
// result in per thread start offset inside the panel, instead
|
||||
// of panel boundaries.
|
||||
b_use = b + (jc_cur_loop * k_updated) +
|
||||
(n_sub_updated * pc) +
|
||||
(jc_cur_loop_rem * kc0_updated);
|
||||
|
||||
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
||||
|
||||
post_ops_attr.b_col_sum_vec_s16 = ( ( int16_t* )( b + ( k_updated * n_updated ) ) ) + jc;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Unpacked B not supported.
|
||||
return;
|
||||
}
|
||||
|
||||
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
|
||||
{
|
||||
dim_t mc0 = bli_min((ic_end - ic), MC);
|
||||
|
||||
// Only per thread C matrix is stored in temp buffer, so both
|
||||
// per thread jc and ic start should be normalized to zero.
|
||||
if ( c_downscale < S16 )
|
||||
{
|
||||
c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
c_use_ic = c_use_jc + ( rs_c_use * ic );
|
||||
}
|
||||
|
||||
// Matrix A packed and reordered code path is not triggerred
|
||||
// currently for row-major inputs since we do not support it yet.
|
||||
// Pack is enabled for column-major inputs to transform into
|
||||
// row-major inputs as kernel expects row storage format.
|
||||
if ( mtag_a == PACK )
|
||||
{
|
||||
mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK,
|
||||
&mem_a, rntm
|
||||
);
|
||||
pack_a_buffer_s8s8s16o16 = ( int8_t* )bli_mem_buffer( &mem_a );
|
||||
|
||||
( ( packa_s16 )lcntx->packa_fun_ptr )
|
||||
(
|
||||
( uint8_t* )pack_a_buffer_s8s8s16o16,
|
||||
( uint8_t* )( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a,
|
||||
mc0, kc0,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
a_use = pack_a_buffer_s8s8s16o16;
|
||||
|
||||
if( cs_a == 1 )
|
||||
{
|
||||
a_block_stride = kc0_updated;
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
a_block_stride = rs_a_use;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
a_use = a + ( rs_a * ic ) + ( cs_a * pc );
|
||||
cs_a_use = 1;
|
||||
a_block_stride = rs_a;
|
||||
}
|
||||
|
||||
post_ops_attr.b_sum_offset = 0;
|
||||
|
||||
for (dim_t jr = 0; jr < nc0; jr += NR)
|
||||
{
|
||||
dim_t nr0 = bli_min((nc0 - jr), NR);
|
||||
|
||||
// Post ops meta attributes.
|
||||
post_ops_attr.post_op_c_i = ic;
|
||||
post_ops_attr.post_op_c_j = ( jc + jr );
|
||||
post_ops_attr.rs_c_downscale = rs_c_downscale;
|
||||
|
||||
// Calls for reorder B
|
||||
( ( lpgemm_rowvar_s16_s8 )lcntx->kern_fun_ptr )
|
||||
(
|
||||
mc0, nr0, kc0,
|
||||
a_use, rs_a_use, cs_a_use, a_block_stride,
|
||||
(b_use + (jr * kc0_updated)), rs_b_use, cs_b_use,
|
||||
(c_use_ic + jr), rs_c_use, 1,
|
||||
alpha, beta0,
|
||||
post_op_list, post_ops_attr
|
||||
);
|
||||
post_ops_attr.b_sum_offset += NR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
{
|
||||
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
|
||||
}
|
||||
}
|
||||
|
||||
// Release pack buffers.
|
||||
if (mtag_b == PACK)
|
||||
{
|
||||
// All threads in work group should wait till B matrix usage is
|
||||
// completed by the participating threads.
|
||||
bli_thrcomm_barrier(
|
||||
bli_thread_ocomm_id(&thread_jc),
|
||||
&thread->comm[bli_thread_work_id(&thread_jc)]);
|
||||
|
||||
if (bli_thread_am_ochief(&thread_ic))
|
||||
{
|
||||
if (bli_mem_is_alloc(&mem_b))
|
||||
{
|
||||
bli_pba_release(rntm, &mem_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
if ( c_downscale < S16 )
|
||||
{
|
||||
if ( bli_mem_is_alloc( &mem_scale_c ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_scale_c );
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,105 +293,6 @@ BLIS_INLINE void lpgemm_adjust_ic_jc_ways
|
||||
}
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_s16o16_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g,
|
||||
AOCL_OPERATION_TYPE op_type
|
||||
)
|
||||
{
|
||||
*n_threads = bli_rntm_num_threads( rntm_g );
|
||||
*jc_ways = bli_rntm_jc_ways( rntm_g );
|
||||
*ic_ways = bli_rntm_ic_ways( rntm_g );
|
||||
|
||||
if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) )
|
||||
{
|
||||
// If BLIS_IC_NT or JC_NT are set.
|
||||
// Default cases.
|
||||
*ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1;
|
||||
*jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1;
|
||||
|
||||
*n_threads = ( *jc_ways ) * ( *ic_ways );
|
||||
}
|
||||
else if ( ( *n_threads ) > 1 )
|
||||
{
|
||||
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( op_type );
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type );
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
( *ic_ways ) = ( *n_threads );
|
||||
( *jc_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( m <= MR )
|
||||
{
|
||||
( *jc_ways ) = ( *n_threads );
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Setting all the values to 1 in case n_threads <= 1. This ensures
|
||||
// the threading parameters are valid.
|
||||
*n_threads = 1;
|
||||
*jc_ways = 1;
|
||||
*ic_ways = 1;
|
||||
}
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_u8s8s16o16_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
lpgemm_s16o16_get_threading
|
||||
(
|
||||
n_threads,
|
||||
ic_ways, jc_ways,
|
||||
m, n, k, rntm_g,
|
||||
U8S8S16OS16
|
||||
);
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_s8s8s16o16_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
lpgemm_s16o16_get_threading
|
||||
(
|
||||
n_threads,
|
||||
ic_ways, jc_ways,
|
||||
m, n, k, rntm_g,
|
||||
S8S8S16OS16
|
||||
);
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_s32o32_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
@@ -1162,12 +1063,10 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
} \
|
||||
} \
|
||||
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16)
|
||||
|
||||
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
@@ -1807,12 +1706,10 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
); \
|
||||
} \
|
||||
|
||||
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32)
|
||||
GEN_LPGEMM_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
GEN_LPGEMM_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16)
|
||||
|
||||
#define GEN_LPGEMM_DECORATOR1(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
|
||||
@@ -66,12 +66,10 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16)
|
||||
|
||||
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
@@ -211,12 +209,10 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32)
|
||||
GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16)
|
||||
|
||||
|
||||
#define GEN_BATCH_LPGEMM_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
#include "blis.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_reorder_s16.h"
|
||||
#include "lpgemm_packb_s16.h"
|
||||
#include "lpgemm_config.h"
|
||||
|
||||
void aocl_reorderb_nr32_u8s8s16o16
|
||||
(
|
||||
lpgemm_obj_t* b,
|
||||
lpgemm_obj_t* b_reorder,
|
||||
rntm_t* rntm,
|
||||
lpgemm_cntx_t* lcntx
|
||||
)
|
||||
{
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t NR = lcntx->blksz.NR;
|
||||
|
||||
// Extracting the matrix properties from the lpgemm object
|
||||
dim_t rs_b = b->rs;
|
||||
dim_t n = b->width;
|
||||
dim_t k = b->length;
|
||||
|
||||
lpgemm_mod_block_size_s16(0, n, k, NULL, &NC, &KC);
|
||||
|
||||
dim_t rs_b_reorder;
|
||||
dim_t cs_b_reorder;
|
||||
|
||||
dim_t k_updated = k;
|
||||
|
||||
// Making multiple of 2 to suit k in vpmaddubsw
|
||||
k_updated += (k_updated & 0x1);
|
||||
|
||||
dim_t n_threads = bli_rntm_num_threads( rntm );
|
||||
n_threads = ( n_threads > 0 ) ? n_threads : 1;
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
// Initialise a local thrinfo obj for work split across threads.
|
||||
thrinfo_t thread_jc;
|
||||
bli_thrinfo_set_n_way( n_threads, &thread_jc );
|
||||
bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc );
|
||||
#else
|
||||
{
|
||||
// Initialise a local thrinfo obj for work split across threads.
|
||||
thrinfo_t thread_jc;
|
||||
bli_thrinfo_set_n_way( 1, &thread_jc );
|
||||
bli_thrinfo_set_work_id( 0, &thread_jc );
|
||||
#endif
|
||||
// Compute the JC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
|
||||
|
||||
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
|
||||
{
|
||||
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
|
||||
|
||||
dim_t jc_cur_loop = jc;
|
||||
dim_t jc_cur_loop_rem = 0;
|
||||
dim_t n_sub_updated;
|
||||
|
||||
get_B_panel_reordered_start_offset_width
|
||||
(
|
||||
jc, n, NC, 16,
|
||||
&jc_cur_loop, &jc_cur_loop_rem,
|
||||
&nc0, &n_sub_updated
|
||||
);
|
||||
|
||||
for ( dim_t pc = 0; pc < k; pc += KC )
|
||||
{
|
||||
dim_t kc0 = bli_min( ( k - pc ), KC );
|
||||
|
||||
// kc0 needs to be a multiple of 2 so that it can be used with
|
||||
// vmaddubsw instruction. Padding is added in cases this
|
||||
// condition is not satisfied, and therefore the kc0 offsets
|
||||
// used for packed/reordered buffers needs to be updated.
|
||||
dim_t kc0_updated = make_multiple_of_n( kc0, 2 );
|
||||
|
||||
// The offsets are calculated in such a way that it resembles
|
||||
// the reorder buffer traversal in single threaded reordering.
|
||||
// The panel boundaries (KCxNC) remain as it is accessed in
|
||||
// single thread, and as a consequence a thread with jc_start
|
||||
// inside the panel cannot consider NC range for reorder. It
|
||||
// has to work with NC' < NC, and the offset is calulated using
|
||||
// prev NC panels spanning k dim + cur NC panel spaning pc loop
|
||||
// cur iteration + (NC - NC') spanning current kc0 (<= KC).
|
||||
//
|
||||
//Eg: Consider the following reordered buffer diagram:
|
||||
// t1 t2
|
||||
// | |
|
||||
// | |..NC..|
|
||||
// | | |
|
||||
// |.NC. |.NC. |NC'|NC"
|
||||
// pc=0-+-----+-----+---+--+
|
||||
// KC| | | | |
|
||||
// | 1 | 3 | 5 |
|
||||
// pc=KC-+-----+-----+---st-+
|
||||
// KC| | | | |
|
||||
// | 2 | 4 | 6 | 7|
|
||||
// pc=k=2KC-+-----+-----+---+--+
|
||||
// |jc=0 |jc=NC|jc=2NC|
|
||||
//
|
||||
// The numbers 1,2..6,7 denotes the order in which reordered
|
||||
// KCxNC blocks are stored in memory, ie: block 1 followed by 2
|
||||
// followed by 3, etc. Given two threads t1 and t2, and t2 needs
|
||||
// to acces point st in the reorder buffer to write the data:
|
||||
// The offset calulation logic will be:
|
||||
// jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC,
|
||||
// n_sub_updated = NC, k = 2KC, kc0_updated = KC
|
||||
//
|
||||
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
|
||||
// + ( n_sub_updated * pc ) <traverse block 5>
|
||||
// + ( NC' * kc0_updated) <traverse block 6>
|
||||
( ( packb_s16 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
( ( ( int8_t* )b_reorder->storage.aligned_buffer ) +
|
||||
( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
|
||||
( jc_cur_loop_rem * kc0_updated ) ),
|
||||
( ( ( int8_t* )b->storage.aligned_buffer ) +
|
||||
( rs_b * pc ) + jc ),
|
||||
rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder
|
||||
);
|
||||
}
|
||||
|
||||
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
|
||||
}
|
||||
}
|
||||
|
||||
// Changing the packed matrix properties in the packed matrix object
|
||||
b_reorder->rs = rs_b_reorder;
|
||||
b_reorder->cs = cs_b_reorder;
|
||||
b_reorder->mtag = REORDERED;
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
#ifndef LPGEMM_REORDER_S16_H
|
||||
#define LPGEMM_REORDER_S16_H
|
||||
|
||||
#include "lpgemm_types.h"
|
||||
|
||||
void aocl_reorderb_nr32_u8s8s16o16
|
||||
(
|
||||
lpgemm_obj_t* b,
|
||||
lpgemm_obj_t* b_reorder,
|
||||
rntm_t* rntm,
|
||||
lpgemm_cntx_t* lcntx
|
||||
);
|
||||
|
||||
#endif // LPGEMM_REORDER_H
|
||||
@@ -1,573 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_packb_s16.h"
|
||||
#include "lpgemm_packa_s16.h"
|
||||
#include "lpgemm_kernels.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_thrinfo_utils.h"
|
||||
|
||||
// Kernel function prototypes
|
||||
typedef void (*lpgemm_rowvar_s16)
|
||||
(
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const uint8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const int8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
int16_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const int16_t,
|
||||
const int16_t,
|
||||
lpgemm_post_op*,
|
||||
lpgemm_post_op_attr
|
||||
);
|
||||
|
||||
|
||||
|
||||
LPGEMV(uint8_t,int8_t,int16_t,u8s8s16os16)
|
||||
{
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t MC = lcntx->blksz.MC;
|
||||
|
||||
// Strides are updated based on matrix packing/reordering.
|
||||
uint8_t* a_use = ( uint8_t* )a;
|
||||
inc_t rs_a_use = rs_a;
|
||||
inc_t cs_a_use = cs_a;
|
||||
|
||||
int8_t* b_use = ( int8_t* )b;
|
||||
inc_t rs_b_use = rs_b;
|
||||
inc_t cs_b_use = cs_b;
|
||||
|
||||
int16_t *c_use = NULL;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if (c_downscale < S16) post_ops_attr.buf_downscale = c;
|
||||
else post_ops_attr.buf_downscale = NULL;
|
||||
|
||||
siz_t mem_a_size_req = 0;
|
||||
siz_t mem_b_size_req = 0;
|
||||
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
|
||||
uint8_t* pack_a_buffer;
|
||||
int8_t* pack_b_buffer;
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
thrinfo_t thread_ic;
|
||||
|
||||
lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );
|
||||
|
||||
// Increased MR from 6 to 8 to make use of 16 ymm regs
|
||||
dim_t MR = 8;
|
||||
|
||||
// Pack B matrix if rs_b > 1
|
||||
if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
|
||||
{
|
||||
mem_b_size_req = sizeof( int8_t ) * k;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_b, rntm
|
||||
);
|
||||
|
||||
pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b );
|
||||
|
||||
for( dim_t k0 = 0; k0 < k; k0++ )
|
||||
{
|
||||
pack_b_buffer[k0] = b[ k0*rs_b ];
|
||||
}
|
||||
|
||||
b_use = pack_b_buffer;
|
||||
rs_b_use = 1;
|
||||
cs_b_use = 1;
|
||||
}
|
||||
|
||||
// Compute the IC loop thread range for the current thread.
|
||||
dim_t ic_start, ic_end;
|
||||
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_ic.n_way );
|
||||
thread_ic.work_id = thread->tid;
|
||||
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
||||
|
||||
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
|
||||
{
|
||||
dim_t mc0 = bli_min((ic_end - ic), MC);
|
||||
|
||||
a_use = (uint8_t*)a + ic * rs_a;
|
||||
|
||||
c_use = c + ic * rs_c;
|
||||
|
||||
post_ops_attr.post_op_c_i = ic;
|
||||
post_ops_attr.post_op_c_j = 0;
|
||||
post_ops_attr.rs_c_downscale = rs_c;
|
||||
|
||||
if( mtag_a == PACK )
|
||||
{
|
||||
mem_a_size_req = sizeof( uint8_t ) * mc0 * k;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_a, rntm
|
||||
);
|
||||
|
||||
pack_a_buffer = ( uint8_t* ) bli_mem_buffer( &mem_a );
|
||||
|
||||
( ( packa_s16 ) lcntx->packa_fun_ptr )
|
||||
(
|
||||
pack_a_buffer,
|
||||
( a + ( rs_a * ic )), rs_a, cs_a,
|
||||
mc0, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
a_use = pack_a_buffer;
|
||||
}
|
||||
|
||||
// Call lpgemv_n_one kernel
|
||||
lpgemv_n_one_u8s8s16os16
|
||||
(
|
||||
mc0, k,
|
||||
a_use, rs_a_use, cs_a_use, mtag_a,
|
||||
b_use, rs_b_use, cs_b_use, mtag_b,
|
||||
c_use, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
MR, KC,
|
||||
post_op_list,
|
||||
&post_ops_attr
|
||||
);
|
||||
}
|
||||
|
||||
// Release pack buffers
|
||||
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
|
||||
{
|
||||
bli_pba_release(rntm, &mem_a);
|
||||
}
|
||||
if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
|
||||
{
|
||||
bli_pba_release(rntm, &mem_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// B should always be packed.
|
||||
LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16)
|
||||
{
|
||||
dim_t NC = lcntx->blksz.NC;
|
||||
dim_t KC = lcntx->blksz.KC;
|
||||
dim_t MC = lcntx->blksz.MC;
|
||||
const dim_t NR = lcntx->blksz.NR;
|
||||
const dim_t MR = lcntx->blksz.MR;
|
||||
|
||||
lpgemm_mod_block_size_s16(m, n, k, &MC, &NC, &KC);
|
||||
|
||||
if (mtag_b == UNPACKED)
|
||||
{
|
||||
// Error: can only work with packed B now.
|
||||
return;
|
||||
}
|
||||
|
||||
if( n == 1 )
|
||||
{
|
||||
lpgemv_rowvar_u8s8s16os16( m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha,
|
||||
beta,
|
||||
rntm,
|
||||
thread,
|
||||
lcntx,
|
||||
post_op_list,
|
||||
c_downscale );
|
||||
return;
|
||||
}
|
||||
|
||||
const int8_t *b_use;
|
||||
const uint8_t *a_use;
|
||||
dim_t rs_a_use = rs_a;
|
||||
dim_t cs_a_use = cs_a;
|
||||
dim_t a_block_stride = 0;
|
||||
|
||||
dim_t rs_b_use = rs_b;
|
||||
dim_t cs_b_use = cs_b;
|
||||
|
||||
int16_t *c_use_jc = NULL;
|
||||
int16_t *c_use_ic = NULL;
|
||||
dim_t rs_c_use = rs_c;
|
||||
dim_t rs_c_downscale = rs_c;
|
||||
|
||||
// Pack buffer for A.
|
||||
uint8_t* pack_a_buffer_u8s8s16o16;
|
||||
mem_t mem_a = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_a_size_req = 0;
|
||||
|
||||
// Pack buffer for B.
|
||||
int8_t *pack_b_buffer_u8s8s16o16;
|
||||
mem_t mem_b = BLIS_MEM_INITIALIZER;
|
||||
dim_t packb_min_NR = 16;
|
||||
siz_t mem_b_size_req = 0;
|
||||
|
||||
// Temporary buffer for C accumulation when downscaling is required.
|
||||
int16_t* temp_scal_c_buffer_u8s8s16o16;
|
||||
mem_t mem_scale_c = BLIS_MEM_INITIALIZER;
|
||||
siz_t mem_scale_c_size_req = 0;
|
||||
|
||||
// Making multiple of 2 to suit k in vpmaddubsw
|
||||
dim_t k_updated = make_multiple_of_n( k, 2 );
|
||||
|
||||
// To decide whether to apply post ops or not.
|
||||
bool is_last_k = FALSE;
|
||||
|
||||
// To decide whether to use original s8 C or temp buffer for beta scale.
|
||||
bool is_first_k = FALSE;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr;
|
||||
post_ops_attr.c_stor_type = c_downscale;
|
||||
if ( c_downscale < S16 )
|
||||
{
|
||||
post_ops_attr.buf_downscale = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
post_ops_attr.buf_downscale = NULL;
|
||||
}
|
||||
|
||||
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
|
||||
thrinfo_t thread_jc;
|
||||
thrinfo_t thread_ic;
|
||||
|
||||
lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic);
|
||||
|
||||
// Compute the JC, IC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
|
||||
|
||||
dim_t ic_start, ic_end;
|
||||
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
|
||||
|
||||
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
|
||||
{
|
||||
dim_t nc0 = bli_min((jc_end - jc), NC);
|
||||
|
||||
dim_t jc_cur_loop = jc;
|
||||
dim_t jc_cur_loop_rem = 0;
|
||||
dim_t n_sub_updated = 0;
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
{
|
||||
get_B_panel_reordered_start_offset_width
|
||||
(
|
||||
jc, n, NC, packb_min_NR,
|
||||
&jc_cur_loop, &jc_cur_loop_rem,
|
||||
&nc0, &n_sub_updated
|
||||
);
|
||||
}
|
||||
|
||||
if ( c_downscale == S16 )
|
||||
{
|
||||
c_use_jc = c + jc;
|
||||
}
|
||||
// Temp accumulaton buffer for C allocation.
|
||||
else if ( c_downscale < S16 )
|
||||
{
|
||||
// Buffer memory is only required if output needs to be
|
||||
// persisted across iterations of the pc/KC loop.
|
||||
// It was observed that the locks used while checking out
|
||||
// a buffer from memory pool had an impact on performance
|
||||
// and is better to not checkout if k <= KC.
|
||||
if ( k > KC )
|
||||
{
|
||||
mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start );
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE,
|
||||
&mem_scale_c, rntm
|
||||
);
|
||||
|
||||
temp_scal_c_buffer_u8s8s16o16 = bli_mem_buffer( &mem_scale_c );
|
||||
|
||||
c_use_jc = ( int16_t* )temp_scal_c_buffer_u8s8s16o16;
|
||||
}
|
||||
|
||||
// The temp c buffer stride is modified as opposed to original C matrix.
|
||||
rs_c_use = nc0;
|
||||
}
|
||||
|
||||
for (dim_t pc = 0; pc < k; pc += KC)
|
||||
{
|
||||
int16_t beta0 = (pc == 0) ? beta : 1;
|
||||
dim_t kc0 = bli_min((k - pc), KC);
|
||||
|
||||
// No parallelization in k dim, k always starts at 0.
|
||||
is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE );
|
||||
post_ops_attr.is_first_k = is_first_k;
|
||||
|
||||
is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
|
||||
post_ops_attr.is_last_k = is_last_k;
|
||||
|
||||
// kc0 needs to be a multiple of 2 so that it can be
|
||||
// used with vpmaddubsw instruction. Padding is added in
|
||||
// cases this condition is not satisfied, and therefore
|
||||
// the kc0 offsets used for packed/reordered buffers
|
||||
// needs to be updated.
|
||||
dim_t kc0_updated = make_multiple_of_n(kc0, 2);
|
||||
|
||||
if (mtag_b == PACK)
|
||||
{
|
||||
// Pack B chunks are based on jc work id.
|
||||
dim_t jc_work_id = bli_thread_work_id(&thread_jc);
|
||||
|
||||
// Using child thrinfo (thread_ic) tid to decide chief thread
|
||||
// per B matrix chunk (jc work id group)
|
||||
if (bli_thread_am_ochief(&thread_ic))
|
||||
{
|
||||
// nc0 needs to be a multiple of 16 since this gives maximum
|
||||
// vectorization. Packing B always results in buffers with width
|
||||
// which is a multiple of 16. Subsequently the nc0 offsets used
|
||||
// for packed/reordered buffers needs to be updated.
|
||||
dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR);
|
||||
mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated;
|
||||
|
||||
lpgemm_alloc_mem_panel(
|
||||
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
|
||||
&mem_b, rntm);
|
||||
|
||||
thread->comm[jc_work_id].sent_object =
|
||||
bli_mem_buffer(&mem_b);
|
||||
}
|
||||
|
||||
// All threads in work group should wait till chief thread has
|
||||
// finished allocating the packing buffers.
|
||||
bli_thrcomm_barrier
|
||||
(
|
||||
bli_thread_ocomm_id(&thread_ic),
|
||||
&thread->comm[jc_work_id]
|
||||
);
|
||||
|
||||
pack_b_buffer_u8s8s16o16 =
|
||||
(int8_t *)thread->comm[jc_work_id].sent_object;
|
||||
|
||||
// Compute the B panel per thread loop range for parallel
|
||||
// packing using ic_ways number of threads. Since atmost only
|
||||
// ic_ways threads can be used, the thread_ic attributes are
|
||||
// used to split the loop range.
|
||||
dim_t jc_packb_start, jc_packb_end;
|
||||
bli_thread_range_sub
|
||||
(
|
||||
&thread_ic, nc0, NR, FALSE,
|
||||
&jc_packb_start, &jc_packb_end
|
||||
);
|
||||
|
||||
// Ensure thread ranges are valid, especially cases where no:
|
||||
// of threads available for parallelization are greater than
|
||||
// no: of B panel NR chunks.
|
||||
if ((jc_packb_end > jc_packb_start) &&
|
||||
(jc_packb_start < (jc + nc0)))
|
||||
{
|
||||
( ( packb_s16 )lcntx->packb_fun_ptr )
|
||||
(
|
||||
pack_b_buffer_u8s8s16o16 +
|
||||
(jc_packb_start * kc0_updated),
|
||||
(b + (rs_b * pc) + (cs_b * jc) +
|
||||
(cs_b * jc_packb_start)),
|
||||
rs_b,
|
||||
(jc_packb_end - jc_packb_start), kc0,
|
||||
&rs_b_use, &cs_b_use
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
||||
}
|
||||
|
||||
// All threads in work group should wait till B matrix packing
|
||||
// is completed by the participating threads.
|
||||
bli_thrcomm_barrier
|
||||
(
|
||||
bli_thread_ocomm_id(&thread_ic),
|
||||
&thread->comm[jc_work_id]
|
||||
);
|
||||
|
||||
b_use = pack_b_buffer_u8s8s16o16;
|
||||
}
|
||||
else if (mtag_b == REORDERED)
|
||||
{
|
||||
// In multi-threaded scenarios, an extra offset into a given
|
||||
// packed B panel is required, since the jc loop split can
|
||||
// result in per thread start offset inside the panel, instead
|
||||
// of panel boundaries.
|
||||
b_use = b + (jc_cur_loop * k_updated) +
|
||||
(n_sub_updated * pc) +
|
||||
(jc_cur_loop_rem * kc0_updated);
|
||||
|
||||
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Unpacked B not supported.
|
||||
return;
|
||||
}
|
||||
|
||||
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
|
||||
{
|
||||
dim_t mc0 = bli_min((ic_end - ic), MC);
|
||||
|
||||
// Only per thread C matrix is stored in temp buffer, so both
|
||||
// per thread jc and ic start should be normalized to zero.
|
||||
if ( c_downscale < S16 )
|
||||
{
|
||||
c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) );
|
||||
}
|
||||
else
|
||||
{
|
||||
c_use_ic = c_use_jc + ( rs_c_use * ic );
|
||||
}
|
||||
|
||||
// Matrix A packed and reordered code path is not triggerred
|
||||
// currently for row-major inputs since we do not support it yet.
|
||||
// Pack is enabled for column-major inputs to transform into
|
||||
// row-major inputs as kernel expects row storage format.
|
||||
if ( mtag_a == PACK )
|
||||
{
|
||||
mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated;
|
||||
|
||||
lpgemm_alloc_mem_panel
|
||||
(
|
||||
mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK,
|
||||
&mem_a, rntm
|
||||
);
|
||||
pack_a_buffer_u8s8s16o16 = ( uint8_t* )bli_mem_buffer( &mem_a );
|
||||
|
||||
( ( packa_s16 )lcntx->packa_fun_ptr )
|
||||
(
|
||||
pack_a_buffer_u8s8s16o16,
|
||||
( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a,
|
||||
mc0, kc0,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
a_use = pack_a_buffer_u8s8s16o16;
|
||||
|
||||
if( cs_a == 1 )
|
||||
{
|
||||
a_block_stride = kc0_updated;
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
a_block_stride = rs_a_use;
|
||||
}
|
||||
|
||||
}
|
||||
else if ( mtag_a == REORDERED )
|
||||
{
|
||||
lpgemm_get_packa_strides( lcntx, &rs_a_use, &cs_a_use );
|
||||
a_use = a + ( pc * m ) + ( kc0_updated * ic );
|
||||
a_block_stride = kc0_updated;
|
||||
}
|
||||
else
|
||||
{
|
||||
a_use = a + ( rs_a * ic ) + ( cs_a * pc );
|
||||
cs_a_use = 1;
|
||||
a_block_stride = rs_a;
|
||||
}
|
||||
|
||||
for (dim_t jr = 0; jr < nc0; jr += NR)
|
||||
{
|
||||
dim_t nr0 = bli_min((nc0 - jr), NR);
|
||||
|
||||
// Post ops meta attributes.
|
||||
post_ops_attr.post_op_c_i = ic;
|
||||
post_ops_attr.post_op_c_j = ( jc + jr );
|
||||
post_ops_attr.rs_c_downscale = rs_c_downscale;
|
||||
|
||||
// Calls for reorder B
|
||||
( ( lpgemm_rowvar_s16 )lcntx->kern_fun_ptr )
|
||||
(
|
||||
mc0, nr0, kc0,
|
||||
a_use, rs_a_use, cs_a_use, a_block_stride,
|
||||
(b_use + (jr * kc0_updated)), rs_b_use, cs_b_use,
|
||||
(c_use_ic + jr), rs_c_use, 1,
|
||||
alpha, beta0,
|
||||
post_op_list, post_ops_attr
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mtag_b == REORDERED)
|
||||
{
|
||||
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
|
||||
}
|
||||
}
|
||||
|
||||
// Release pack buffers.
|
||||
if (mtag_b == PACK)
|
||||
{
|
||||
// All threads in work group should wait till B matrix usage is
|
||||
// completed by the participating threads.
|
||||
bli_thrcomm_barrier(
|
||||
bli_thread_ocomm_id(&thread_jc),
|
||||
&thread->comm[bli_thread_work_id(&thread_jc)]);
|
||||
|
||||
if (bli_thread_am_ochief(&thread_ic))
|
||||
{
|
||||
if (bli_mem_is_alloc(&mem_b))
|
||||
{
|
||||
bli_pba_release(rntm, &mem_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
if ( c_downscale < S16 )
|
||||
{
|
||||
if ( bli_mem_is_alloc( &mem_scale_c ) )
|
||||
{
|
||||
bli_pba_release( rntm, &mem_scale_c );
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -87,12 +87,10 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
) \
|
||||
|
||||
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
|
||||
LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32);
|
||||
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m);
|
||||
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
|
||||
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
|
||||
LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32);
|
||||
|
||||
|
||||
#define LPGEMM_MAIN_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
@@ -144,10 +142,6 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64);
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64);
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64);
|
||||
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32);
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32);
|
||||
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32);
|
||||
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64);
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64);
|
||||
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64);
|
||||
@@ -201,10 +195,6 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64);
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64);
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64);
|
||||
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32);
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32);
|
||||
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32);
|
||||
|
||||
|
||||
#define LPGEMM_M_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
@@ -258,8 +248,6 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32);
|
||||
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32);
|
||||
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48);
|
||||
|
||||
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16);
|
||||
|
||||
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16);
|
||||
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32);
|
||||
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
|
||||
@@ -275,8 +263,6 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16);
|
||||
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32);
|
||||
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48);
|
||||
|
||||
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16);
|
||||
|
||||
|
||||
#define LPGEMM_N_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
@@ -329,14 +315,10 @@ void lpgemm_rowvar_ ## LP_SFX \
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16);
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16);
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16);
|
||||
|
||||
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16);
|
||||
|
||||
#define LPGEMM_N_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -396,10 +378,6 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48);
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48);
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48);
|
||||
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16);
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16);
|
||||
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16);
|
||||
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16);
|
||||
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16);
|
||||
@@ -432,10 +410,6 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48);
|
||||
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48);
|
||||
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48);
|
||||
|
||||
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16);
|
||||
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16);
|
||||
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16);
|
||||
|
||||
#define LPGEMM_MN_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -496,10 +470,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16);
|
||||
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16);
|
||||
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16);
|
||||
@@ -512,10 +482,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16);
|
||||
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16);
|
||||
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16);
|
||||
|
||||
#define LPGEMM_MN_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
@@ -600,8 +566,6 @@ void lpgemv_n_one_ ## LP_SFX \
|
||||
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32);
|
||||
LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
|
||||
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
|
||||
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int16_t,u8s8s16os16);
|
||||
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
|
||||
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int16_t,s8s8s16os16);
|
||||
|
||||
#endif //BLIS_LPGEMM_KERN_H
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#ifndef BLIS_GEMM_S8_INT16_PACKB
|
||||
#define BLIS_GEMM_S8_INT16_PACKB
|
||||
|
||||
typedef void (*packb_s16_s8)
|
||||
(
|
||||
int8_t*,
|
||||
int16_t*,
|
||||
const int8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
dim_t*,
|
||||
dim_t*
|
||||
);
|
||||
|
||||
void packb_nr32_s8s8s16o16
|
||||
(
|
||||
int8_t *pack_b_buffer_s8s8s16o16,
|
||||
int16_t *pack_b_column_sum,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t cols,
|
||||
const dim_t rows,
|
||||
dim_t *rs_b,
|
||||
dim_t *cs_b
|
||||
);
|
||||
|
||||
#endif // BLIS_GEMM_S8_INT16_PACKB
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#ifndef BLIS_GEMM_INT8_U8S8S16_PACKA
|
||||
#define BLIS_GEMM_INT8_U8S8S16_PACKA
|
||||
|
||||
typedef void (*packa_s16)
|
||||
(
|
||||
uint8_t*,
|
||||
const uint8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
dim_t*,
|
||||
dim_t*
|
||||
);
|
||||
|
||||
void packa_u8s8s16os16
|
||||
(
|
||||
uint8_t* pack_a_buffer_u8s8s16o16,
|
||||
const uint8_t* a,
|
||||
const dim_t rs,
|
||||
const dim_t cs,
|
||||
const dim_t MC,
|
||||
const dim_t KC,
|
||||
dim_t* rs_a,
|
||||
dim_t* cs_a
|
||||
);
|
||||
|
||||
#endif //BLIS_GEMM_INT8_U8S8S16_PACKA
|
||||
@@ -1,60 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#ifndef BLIS_GEMM_INT16_PACKB
|
||||
#define BLIS_GEMM_INT16_PACKB
|
||||
|
||||
typedef void (*packb_s16)
|
||||
(
|
||||
int8_t*,
|
||||
const int8_t*,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
const dim_t,
|
||||
dim_t*,
|
||||
dim_t*
|
||||
);
|
||||
|
||||
void packb_nr32_u8s8s16o16
|
||||
(
|
||||
int8_t *pack_b_buffer_u8s8s16o16,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t cols,
|
||||
const dim_t rows,
|
||||
dim_t *rs_b,
|
||||
dim_t *cs_b
|
||||
);
|
||||
|
||||
#endif // BLIS_GEMM_INT16_PACKB
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,412 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
void packb_nrlt16_s8s8s16o16
|
||||
(
|
||||
int8_t *pack_b_buffer_s8s8s16o16,
|
||||
int16_t *pack_b_column_sum,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t rows,
|
||||
dim_t n0_partial_rem
|
||||
)
|
||||
{
|
||||
dim_t k_full_pieces_blks = rows / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = rows % 2;
|
||||
dim_t NR = 16;
|
||||
dim_t kr_new = 0;
|
||||
|
||||
int8_t buf0[16], buf1[16];
|
||||
|
||||
__m128i b_vec[2], inter_vec[2];
|
||||
|
||||
__m256i sum1;
|
||||
__m256i temp1;
|
||||
__m256 temp2, temp3;
|
||||
|
||||
//load the temp buffer to compute column sum of B matrix
|
||||
sum1 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum) );
|
||||
|
||||
for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
|
||||
{
|
||||
memcpy(buf0, (b + (ldb * (kr + 0))), (n0_partial_rem * sizeof(int8_t)));
|
||||
memcpy(buf1, (b + (ldb * (kr + 1))), (n0_partial_rem * sizeof(int8_t)));
|
||||
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
|
||||
// Read b[1,0], b[1,1], b[1,2]......., b[1,15]
|
||||
b_vec[1] = _mm_loadu_si128((__m128i *)buf1);
|
||||
|
||||
//compute sum1 to compute B matrix column sum
|
||||
temp1 =
|
||||
_mm256_add_epi16( _mm256_cvtepi8_epi16( b_vec[0] ), _mm256_cvtepi8_epi16( b_vec[1] ));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum1 = _mm256_add_epi16 (sum1, temp1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + (kr_new * NR)), inter_vec[0]);
|
||||
// Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
|
||||
// Increment to ignore the padded bits
|
||||
kr_new += 2;
|
||||
}
|
||||
|
||||
// Handle k partial cases
|
||||
if (k_partial_pieces > 0)
|
||||
{
|
||||
memcpy(buf0, (b + (ldb * (k_full_pieces + 0))), (n0_partial_rem * sizeof(int8_t)));
|
||||
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
|
||||
b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding
|
||||
|
||||
//compute sum1 to compute B matrix column sum
|
||||
temp1 = ( _mm256_cvtepi8_epi16( b_vec[0] ));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum1 = _mm256_add_epi16 (sum1, temp1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], 0, b[0,1]......., b[0,7], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
|
||||
|
||||
// Store b[0,8], 0, b[0,9]......., b[0,15], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
}
|
||||
//store the sum column
|
||||
_mm256_storeu_si256( (__m256i *)(pack_b_column_sum), sum1 );
|
||||
}
|
||||
|
||||
void packb_nr16_s8s8s16o16(
|
||||
int8_t *pack_b_buffer_s8s8s16o16,
|
||||
int16_t *pack_b_column_sum,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t rows)
|
||||
{
|
||||
dim_t k_full_pieces_blks = rows / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = rows % 2;
|
||||
dim_t NR = 16;
|
||||
dim_t kr_new = 0;
|
||||
|
||||
__m128i b_vec[2], inter_vec[2];
|
||||
|
||||
__m256i sum1;
|
||||
__m256i temp1;
|
||||
__m256 temp2, temp3;
|
||||
|
||||
//load the temp buffer to compute column sum of B matrix
|
||||
sum1 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum) );
|
||||
|
||||
for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 0))));
|
||||
|
||||
// Read b[1,0], b[1,1], b[1,2]......., b[1,15]
|
||||
b_vec[1] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 1))));
|
||||
|
||||
//compute sum1 to compute B matrix column sum
|
||||
temp1 =
|
||||
_mm256_add_epi16( _mm256_cvtepi8_epi16( b_vec[0] ), _mm256_cvtepi8_epi16( b_vec[1] ));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum1 = _mm256_add_epi16 (sum1, temp1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
|
||||
|
||||
// Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
|
||||
// Increment to ignore the padded bits
|
||||
kr_new += 2;
|
||||
}
|
||||
|
||||
if (k_partial_pieces > 0)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (k_full_pieces + 0))));
|
||||
b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding
|
||||
|
||||
//compute sum1
|
||||
temp1 = ( _mm256_cvtepi8_epi16( b_vec[0] ));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum1 = _mm256_add_epi16 (sum1, temp1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], 0, b[0,1]......., b[0,7], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
|
||||
// Store b[0,8], 0, b[0,9]......., b[0,15], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
}
|
||||
//store the sum column
|
||||
_mm256_storeu_si256( (__m256i *)(pack_b_column_sum), sum1 );
|
||||
}
|
||||
|
||||
void packb_nr32_s8s8s16o16(
|
||||
int8_t *pack_b_buffer_s8s8s16o16,
|
||||
int16_t *pack_b_column_sum,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t cols,
|
||||
const dim_t rows,
|
||||
dim_t *rs_b,
|
||||
dim_t *cs_b)
|
||||
{
|
||||
dim_t NR = 32;
|
||||
|
||||
dim_t n_full_pieces = cols / NR;
|
||||
dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
|
||||
dim_t n_partial_pieces = cols % NR;
|
||||
dim_t k_full_pieces_blks = rows / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = rows % 2;
|
||||
|
||||
dim_t KC_updated = rows;
|
||||
|
||||
// Making multiple of 2 to suit k in vpmaddubsw
|
||||
KC_updated += (KC_updated & 0x1);
|
||||
|
||||
//to compute column sum of B matrix
|
||||
__m256i sum1, sum2;
|
||||
__m256i temp1;
|
||||
__m256 temp2, temp3;
|
||||
|
||||
__m256i b_vec[2], inter_vec[2];
|
||||
|
||||
for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR)
|
||||
{
|
||||
//load the temp buffer to compute column sum of B matrix
|
||||
sum1 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum + jc) );
|
||||
sum2 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum + 16 + jc) );
|
||||
|
||||
for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,31]
|
||||
b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc));
|
||||
|
||||
// Read b[1,0], b[1,1], b[1,2]......., b[1,31]
|
||||
b_vec[1] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc));
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0)
|
||||
//compute sum1 and sum2 to compute B matrix column sum
|
||||
temp1 =
|
||||
_mm256_add_epi16( _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 0 )),
|
||||
_mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[1], 0 )));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum1 = _mm256_add_epi16 (sum1, temp1);
|
||||
|
||||
//compute sum2
|
||||
temp1 =
|
||||
_mm256_add_epi16( _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 1 )),
|
||||
_mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[1], 1 )));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum2 = _mm256_add_epi16 (sum2, temp1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20);
|
||||
b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31);
|
||||
|
||||
// Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + (kr * NR))), b_vec[0]);
|
||||
// Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + ((kr + 1) * NR))), b_vec[1]);
|
||||
}
|
||||
|
||||
if (k_partial_pieces > 0)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,31]
|
||||
b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (k_full_pieces + 0)) + jc));
|
||||
b_vec[1] = _mm256_setzero_si256(); // Initialize with zero for padding
|
||||
|
||||
//compute sum1
|
||||
temp1 = _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 0 ));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum1 = _mm256_add_epi16 (sum1, temp1);
|
||||
|
||||
//compute sum2
|
||||
temp1 = _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 1 ));
|
||||
|
||||
temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
|
||||
temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
|
||||
|
||||
temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
|
||||
temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
|
||||
|
||||
temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
|
||||
temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
|
||||
|
||||
sum2 = _mm256_add_epi16 (sum2, temp1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20);
|
||||
b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31);
|
||||
|
||||
// Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + (k_full_pieces * NR))), b_vec[0]);
|
||||
// Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + ((k_full_pieces + 1) * NR))), b_vec[1]);
|
||||
}
|
||||
//store the sum column
|
||||
_mm256_storeu_si256( (__m256i *)(pack_b_column_sum + jc), sum1 );
|
||||
_mm256_storeu_si256( (__m256i *)(pack_b_column_sum + 16 + jc), sum2 );
|
||||
}
|
||||
|
||||
// B matrix packing when n < NR
|
||||
if (n_partial_pieces > 0)
|
||||
{
|
||||
// Split into multiple smaller fringe kernels, so as to maximize
|
||||
// vectorization after packing. Any n0 < NR(32) can be expressed
|
||||
// as n0 = 16 + n`.
|
||||
dim_t n0_16 = n_partial_pieces / 16;
|
||||
dim_t n0_partial_rem = n_partial_pieces % 16;
|
||||
|
||||
dim_t n0_partial_pack = 0;
|
||||
|
||||
if (n0_16 == 1)
|
||||
{
|
||||
packb_nr16_s8s8s16o16(
|
||||
(pack_b_buffer_s8s8s16o16 +
|
||||
(n_full_pieces_loop_limit * KC_updated)),
|
||||
( pack_b_column_sum + ( n_full_pieces_loop_limit ) ),
|
||||
(b + n_full_pieces_loop_limit), ldb, rows);
|
||||
|
||||
n0_partial_pack = 16;
|
||||
}
|
||||
|
||||
if (n0_partial_rem > 0)
|
||||
{
|
||||
packb_nrlt16_s8s8s16o16(
|
||||
(pack_b_buffer_s8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) +
|
||||
(n0_partial_pack * KC_updated)),
|
||||
( pack_b_column_sum + n_full_pieces_loop_limit + n0_partial_pack ),
|
||||
(b + n_full_pieces_loop_limit + n0_partial_pack),
|
||||
ldb, rows, n0_partial_rem);
|
||||
}
|
||||
}
|
||||
|
||||
*rs_b = NR * 2;
|
||||
*cs_b = NR;
|
||||
}
|
||||
#endif
|
||||
@@ -1,877 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
#include "../u8s8s16/lpgemm_s16_kern_macros.h"
|
||||
|
||||
#define LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, paddr, stride ) \
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)paddr ); \
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); \
|
||||
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
|
||||
|
||||
#define LPGEMV_N_KERNEL_2_FMA( a_reg1, a_reg2, b_reg, \
|
||||
inter_reg1, inter_reg2, c_reg1, c_reg2 ) \
|
||||
inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
|
||||
c_reg1 = _mm256_add_epi16(inter_reg1, c_reg1); \
|
||||
inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
|
||||
c_reg2 = _mm256_add_epi16(inter_reg2, c_reg2);
|
||||
|
||||
|
||||
#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)(paddr) ); \
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)(paddr + 2 * stride) ); \
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)(paddr + 3 * stride) ); \
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); \
|
||||
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); \
|
||||
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 ); \
|
||||
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
|
||||
|
||||
#define LPGEMV_N_KERNEL_4_FMA( a_reg1, a_reg2, a_reg3, a_reg4, b_reg, \
|
||||
inter_reg1, inter_reg2, \
|
||||
inter_reg3, inter_reg4, \
|
||||
out_reg1, out_reg2, out_reg3, out_reg4 ) \
|
||||
inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
|
||||
out_reg1 = _mm256_add_epi16(inter_reg1, out_reg1); \
|
||||
inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
|
||||
out_reg2 = _mm256_add_epi16(inter_reg2, out_reg2); \
|
||||
inter_reg3 = _mm256_maddubs_epi16(a_reg3, b_reg); \
|
||||
out_reg3 = _mm256_add_epi16(inter_reg3, out_reg3); \
|
||||
inter_reg4 = _mm256_maddubs_epi16(a_reg4, b_reg); \
|
||||
out_reg4 = _mm256_add_epi16(inter_reg4, out_reg4);
|
||||
|
||||
#define LPGEMV_YMM2XMM( ymm0, ymm1, ymm2, ymm3, xmm0 ) \
|
||||
ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
|
||||
ymm1 = _mm256_hadd_epi16( ymm2, ymm3 ); \
|
||||
ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
|
||||
xmm0 = _mm_add_epi16( _mm256_extracti128_si256( ymm0, 0 ), \
|
||||
_mm256_extracti128_si256( ymm0, 1 ) );
|
||||
|
||||
|
||||
|
||||
LPGEMV_N_EQ1_KERN(int8_t, int8_t, int16_t, s8s8s16os16)
|
||||
{
|
||||
static void* post_ops_labels[] =
|
||||
{
|
||||
&&POST_OPS_DISABLE,
|
||||
&&POST_OPS_BIAS,
|
||||
&&POST_OPS_RELU,
|
||||
&&POST_OPS_RELU_SCALE,
|
||||
&&POST_OPS_GELU_TANH,
|
||||
&&POST_OPS_GELU_ERF,
|
||||
&&POST_OPS_CLIP,
|
||||
&&POST_OPS_DOWNSCALE,
|
||||
&&POST_OPS_MATRIX_ADD,
|
||||
&&POST_OPS_SWISH,
|
||||
NULL,// Virtual node for matrix_mul, else segfault
|
||||
&&POST_OPS_TANH,
|
||||
&&POST_OPS_SIGMOID
|
||||
};
|
||||
|
||||
int8_t *a_use = NULL;
|
||||
int8_t *b_use = NULL;
|
||||
int16_t *c_use = NULL;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr = *(post_op_attr);
|
||||
|
||||
// temp buffer to store output C vector
|
||||
int16_t ctemp[16];
|
||||
|
||||
// temp buffers to store a, b data in k_rem case.
|
||||
int8_t buf0[32] = {0};
|
||||
int8_t buf1[32] = {0};
|
||||
int8_t buf2[32] = {0};
|
||||
int8_t buf3[32] = {0};
|
||||
int8_t buf4[32] = {0};
|
||||
int8_t buf5[32] = {0};
|
||||
int8_t buf6[32] = {0};
|
||||
int8_t buf7[32] = {0};
|
||||
int8_t buf8[32] = {0};
|
||||
|
||||
|
||||
uint8_t cvt_uint8 = 128;
|
||||
__m256i vec_uint8;
|
||||
|
||||
int16_t* bsumptr = post_ops_attr.b_col_sum_vec_s16;
|
||||
|
||||
for ( dim_t ir = 0; ir < m0; ir += MR )
|
||||
{
|
||||
dim_t mr0 = bli_min( ( m0 - ir ), MR );
|
||||
dim_t k_iter = k / 32;
|
||||
dim_t k_rem = k % 32;
|
||||
|
||||
__m256i ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7;
|
||||
__m256i ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14;
|
||||
__m256i ymm15;
|
||||
|
||||
__m128i xmm0, xmm1;
|
||||
|
||||
/* zero the accumulator registers */
|
||||
ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 )
|
||||
ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 )
|
||||
|
||||
//update pointers
|
||||
a_use = (int8_t*)a + ir * rs_a;
|
||||
b_use = (int8_t*)b;
|
||||
c_use = (int16_t*)c + ir * rs_c;
|
||||
|
||||
if( mr0 == MR )
|
||||
{
|
||||
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
|
||||
|
||||
for (dim_t k = 0; k < k_iter; k++)
|
||||
{
|
||||
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)(b_use) );
|
||||
b_use += 32;
|
||||
|
||||
//Load 4x32 elements from row0-row3 of A
|
||||
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a )
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
|
||||
// Load 4x32 elements from row8-row11 of A
|
||||
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
|
||||
( a_use + 4 * rs_a ), rs_a
|
||||
)
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm12, ymm13, ymm14, ymm15
|
||||
)
|
||||
|
||||
a_use += 32;
|
||||
}
|
||||
|
||||
|
||||
|
||||
if( k_rem )
|
||||
{
|
||||
uint8_t buf_vec_uint8_t[32] = {0};
|
||||
int8_t* restrict a0 = (a_use);
|
||||
int8_t* restrict a1 = (a_use + rs_a );
|
||||
int8_t* restrict a2 = (a_use + 2 * rs_a );
|
||||
int8_t* restrict a3 = (a_use + 3 * rs_a );
|
||||
int8_t* restrict a4 = (a_use + 4 * rs_a );
|
||||
int8_t* restrict a5 = (a_use + 5 * rs_a );
|
||||
int8_t* restrict a6 = (a_use + 6 * rs_a );
|
||||
int8_t* restrict a7 = (a_use + 7 * rs_a );
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf1[i] = a1[i];
|
||||
buf2[i] = a2[i];
|
||||
buf3[i] = a3[i];
|
||||
buf4[i] = a4[i];
|
||||
buf5[i] = a5[i];
|
||||
buf6[i] = a6[i];
|
||||
buf7[i] = a7[i];
|
||||
buf_vec_uint8_t[i] = cvt_uint8;
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
vec_uint8 = _mm256_loadu_si256( ( __m256i const *) buf_vec_uint8_t );
|
||||
|
||||
//Load 4x32 elements from row0-row3 of A
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
|
||||
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
|
||||
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
|
||||
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 );
|
||||
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
|
||||
// Load 4x32 elements from row8-row11 of A
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf4 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf5 );
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)buf6 );
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)buf7 );
|
||||
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
|
||||
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
|
||||
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 );
|
||||
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm12, ymm13, ymm14, ymm15
|
||||
)
|
||||
|
||||
}
|
||||
//Add the registers horizantally to get one
|
||||
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
|
||||
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm1 )
|
||||
|
||||
xmm0 = _mm_hadd_epi16( xmm0, xmm1 );
|
||||
|
||||
// post ops are applied on ymm register though
|
||||
// second half of the register is filled with zeroes.
|
||||
ymm8 = _mm256_setzero_si256();
|
||||
ymm8 = _mm256_inserti128_si256( ymm8, xmm0, 0);
|
||||
|
||||
ymm0 = _mm256_set1_epi16( *bsumptr );
|
||||
ymm8 = _mm256_sub_epi16( ymm8, ymm0 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t *a_use_fringe = a_use;
|
||||
dim_t mr0_use = mr0;
|
||||
dim_t regidx = 0;
|
||||
|
||||
if( mr0_use >= 4 )
|
||||
{
|
||||
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
|
||||
|
||||
for (dim_t k = 0; k < k_iter; k++)
|
||||
{
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
|
||||
b_use += 32;
|
||||
|
||||
//Load 4x32 elements from row0-row3 of A
|
||||
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
|
||||
a_use, rs_a )
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
|
||||
a_use += 32;
|
||||
}
|
||||
|
||||
if( k_rem )
|
||||
{
|
||||
uint8_t buf_vec_uint8_t[32] = {0};
|
||||
int8_t* restrict a0 = (a_use);
|
||||
int8_t* restrict a1 = (a_use + rs_a );
|
||||
int8_t* restrict a2 = (a_use + 2 * rs_a );
|
||||
int8_t* restrict a3 = (a_use + 3 * rs_a );
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf1[i] = a1[i];
|
||||
buf2[i] = a2[i];
|
||||
buf3[i] = a3[i];
|
||||
buf_vec_uint8_t[i] = cvt_uint8;
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t );
|
||||
//Load 4xk_rem elements from row0-row3 of A
|
||||
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
|
||||
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
|
||||
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
|
||||
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 );
|
||||
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
}
|
||||
|
||||
//update pointers
|
||||
mr0_use -= 4;
|
||||
a_use = a_use_fringe + 4 * rs_a;
|
||||
a_use_fringe = a_use;
|
||||
b_use = (int8_t*)b;
|
||||
|
||||
//Add the registers horizantally to get one
|
||||
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
|
||||
|
||||
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
|
||||
|
||||
int64_t data = _mm_extract_epi64( xmm0, 0);
|
||||
//insert xmm outputs into final output reg based on regidx
|
||||
ymm8 = _mm256_setzero_si256();
|
||||
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
|
||||
regidx++;
|
||||
}
|
||||
|
||||
// Dot product for <= 3
|
||||
if ( mr0_use )
|
||||
{
|
||||
// Dot product for m = 2
|
||||
if ( mr0_use >= 2 )
|
||||
{
|
||||
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
|
||||
|
||||
for ( dim_t k = 0; k < k_iter; k++ )
|
||||
{
|
||||
// Load 0-31 in b[k+0 - k+31]
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
|
||||
|
||||
LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, a_use, rs_a);
|
||||
|
||||
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, ymm4,
|
||||
ymm5, ymm12, ymm13);
|
||||
b_use += 32; // move b pointer to next 32 elements
|
||||
a_use += 32;
|
||||
}
|
||||
if ( k_rem )
|
||||
{
|
||||
uint8_t buf_vec_uint8_t[32] = {0};
|
||||
int8_t* restrict a0 = (a_use);
|
||||
int8_t* restrict a1 = (a_use + rs_a );
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf1[i] = a1[i];
|
||||
buf_vec_uint8_t[i] = cvt_uint8;
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t );
|
||||
//Load 2xk_rem elements from row0-row3 of A
|
||||
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
|
||||
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
|
||||
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
|
||||
|
||||
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6,
|
||||
ymm4, ymm5, ymm12, ymm13 );
|
||||
}
|
||||
|
||||
mr0_use -= 2;
|
||||
a_use = a_use_fringe + 2 * rs_a;
|
||||
a_use_fringe = a_use;
|
||||
b_use = (int8_t*)b;
|
||||
}
|
||||
|
||||
// Dot product for m = 1
|
||||
if ( mr0_use == 1 )
|
||||
{
|
||||
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
|
||||
|
||||
for ( dim_t k = 0; k < k_iter; k++ )
|
||||
{
|
||||
// Load 0-31 in b[k+0 - k+31]
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
|
||||
|
||||
// Load 1x32 elements from row0-row1 of A
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)a_use );
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
|
||||
|
||||
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
|
||||
ymm14 = _mm256_add_epi16(ymm4, ymm14);
|
||||
|
||||
b_use += 32; // move b pointer to next 32 elements
|
||||
a_use += 32;
|
||||
}
|
||||
if ( k_rem )
|
||||
{
|
||||
uint8_t buf_vec_uint8_t[32] = {0};
|
||||
int8_t* restrict a0 = (a_use);
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf_vec_uint8_t[i] = cvt_uint8;
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t );
|
||||
|
||||
//Load 1xk_rem elements from row0-row3 of A
|
||||
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
|
||||
|
||||
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
|
||||
ymm14 = _mm256_add_epi16(ymm4, ymm14);
|
||||
}
|
||||
|
||||
// When only fringe 1,
|
||||
// update the registers to store in order
|
||||
if ( !( mr0 & 0x2 ) ) ymm12 = ymm14;
|
||||
}
|
||||
|
||||
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm0)
|
||||
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
|
||||
|
||||
int64_t data = _mm_extract_epi64( xmm0, 0);
|
||||
//insert xmm outputs into final output reg based on regidx
|
||||
|
||||
if( regidx == 0 )
|
||||
{
|
||||
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
|
||||
}
|
||||
else
|
||||
{
|
||||
ymm8 = _mm256_insert_epi64( ymm8, data, 1 );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int16_t buf_vec_int16_t[16] = {0};
|
||||
for( dim_t i = 0; i < mr0; i++)
|
||||
buf_vec_int16_t[i] = *bsumptr;
|
||||
ymm0 = _mm256_loadu_si256( ( __m256i const *) buf_vec_int16_t);
|
||||
ymm8 = _mm256_sub_epi16( ymm8, ymm0 );
|
||||
}
|
||||
|
||||
// Load alpha and beta
|
||||
__m256i selector1 = _mm256_set1_epi16(alpha);
|
||||
__m256i selector2 = _mm256_set1_epi16(beta);
|
||||
|
||||
// Scale by alpha
|
||||
ymm8 = _mm256_mullo_epi16(selector1, ymm8);
|
||||
|
||||
if( beta != 0 )
|
||||
{
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
if( post_ops_attr.rs_c_downscale == 1 )
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
|
||||
|
||||
S8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes );
|
||||
|
||||
S8_S16_BETA_OP_NLT16( ymm8, ctemp,
|
||||
selector1, selector2 )
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
|
||||
|
||||
U8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes );
|
||||
|
||||
U8_S16_BETA_OP_NLT16( ymm8, ctemp,
|
||||
selector1, selector2 )
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( (int8_t*)post_ops_attr.buf_downscale
|
||||
+ ( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepi8_epi32
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
S16_BETA_FMA( ymm8, selector1, selector2 );
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( (uint8_t*)post_ops_attr.buf_downscale
|
||||
+ ( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepu8_epi32
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
S16_BETA_FMA( ymm8, selector1, selector2 );
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( rs_c == 1 )
|
||||
{
|
||||
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
|
||||
memcpy( ctemp, c_use, m0_rem_bytes );
|
||||
S16_S16_BETA_OP_NLT16( ymm8, ctemp,
|
||||
selector1, selector2 )
|
||||
}
|
||||
else
|
||||
{
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = c_use[ i * rs_c ];
|
||||
}
|
||||
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
|
||||
S16_BETA_FMA( ymm8, selector1, selector2 );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Post Ops
|
||||
lpgemm_post_op * post_ops_list_temp = post_op;
|
||||
|
||||
post_ops_attr.is_last_k = TRUE;
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP
|
||||
|
||||
|
||||
POST_OPS_BIAS:
|
||||
{
|
||||
|
||||
|
||||
selector1 =
|
||||
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args1) );
|
||||
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_RELU:
|
||||
{
|
||||
selector1 = _mm256_setzero_si256();
|
||||
|
||||
ymm8 = _mm256_max_epi16( selector1, ymm8 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_RELU_SCALE:
|
||||
{
|
||||
__m256i b0;
|
||||
selector1 = _mm256_setzero_si256();
|
||||
selector2 = _mm256_set1_epi16(
|
||||
*( ( int16_t* )post_ops_list_temp->op_args2 ) );
|
||||
|
||||
RELU_SCALE_OP_S16_AVX2( ymm8 )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_GELU_TANH:
|
||||
{
|
||||
__m256 dn, z, x, r2, r, y1, y2, x_tanh;
|
||||
__m256i q;
|
||||
|
||||
GELU_TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, x_tanh, q )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_GELU_ERF:
|
||||
{
|
||||
__m256 x, r, y1, y2, x_erf;
|
||||
|
||||
GELU_ERF_S16_AVX2(ymm8, y1, y2, r, x, x_erf)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_CLIP:
|
||||
{
|
||||
__m256i min = _mm256_set1_epi16(
|
||||
*( int16_t* )post_ops_list_temp->op_args2 );
|
||||
__m256i max = _mm256_set1_epi16(
|
||||
*( int16_t* )post_ops_list_temp->op_args3 );
|
||||
|
||||
CLIP_S16_AVX2(ymm8, min, max)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_DOWNSCALE:
|
||||
{
|
||||
__m128i temp[2];
|
||||
__m256i temp_32[2];
|
||||
__m256 temp_float[2];
|
||||
__m256 scale_1 = _mm256_setzero_ps();
|
||||
__m256 scale_2 = _mm256_setzero_ps();
|
||||
__m128i _zero_point_0 = _mm_setzero_si128();
|
||||
__m256i zero_point_0 = _mm256_setzero_si256();
|
||||
__m256 res_1, res_2;
|
||||
|
||||
scale_1 =
|
||||
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
|
||||
|
||||
scale_2 =
|
||||
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
|
||||
|
||||
_zero_point_0 = _mm_set1_epi8(
|
||||
*( ( int8_t* )post_ops_list_temp->op_args1 ) );
|
||||
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
|
||||
}
|
||||
|
||||
// Scale first 16 columns of the 2 rows.
|
||||
CVT_MULRND_CVT16(ymm8, scale_1, scale_2, zero_point_0)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
|
||||
POST_OPS_MATRIX_ADD:
|
||||
{
|
||||
dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;
|
||||
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1;
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
memcpy
|
||||
(
|
||||
( int8_t* )ctemp,
|
||||
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
( mr0 ) * sizeof(int8_t)
|
||||
);
|
||||
selector1 = _mm256_cvtepi8_epi16(
|
||||
_mm_loadu_si128( ( __m128i const* )ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepi8_epi16
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1;
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
memcpy
|
||||
(
|
||||
( uint8_t* )ctemp,
|
||||
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
( mr0 ) * sizeof(uint8_t)
|
||||
);
|
||||
selector1 = _mm256_cvtepu8_epi16(
|
||||
_mm_loadu_si128( ( __m128i const* )ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepu8_epi16
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1;
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
memcpy
|
||||
(
|
||||
( int16_t* )ctemp,
|
||||
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
( mr0 ) * sizeof(int16_t)
|
||||
);
|
||||
|
||||
selector1 = _mm256_loadu_si256( ( __m256i const* )ctemp );
|
||||
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int16_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
}
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_SWISH:
|
||||
{
|
||||
selector1 =
|
||||
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) );
|
||||
__m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
|
||||
_mm256_extractf128_si256( selector1, 0 ) ) );
|
||||
|
||||
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
|
||||
__m256i ex_out;
|
||||
|
||||
SWISH_S16_AVX2( ymm8, al, al_in, tmp_reg1,
|
||||
tmp_reg2, r, r2, z, dn, ex_out );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_TANH:
|
||||
{
|
||||
__m256 dn, z, x, r2, r, y1, y2;
|
||||
__m256i q;
|
||||
|
||||
TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, q )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_SIGMOID:
|
||||
{
|
||||
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
|
||||
__m256i ex_out;
|
||||
|
||||
SIGMOID_S16_AVX2( ymm8, al_in, tmp_reg1,
|
||||
tmp_reg2, r, r2, z, dn, ex_out );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_DISABLE:
|
||||
{
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
__m128i temp[2];
|
||||
__m256i zero_reg = _mm256_setzero_si256();
|
||||
if( post_ops_attr.rs_c_downscale == 1 )
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
// Store the results in downscaled type
|
||||
// (int8 instead of int16).
|
||||
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
|
||||
|
||||
CVT_STORE_S16_S8_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes);
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
// Store the results in downscaled type (uint8 instead of int16).
|
||||
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
|
||||
|
||||
CVT_STORE_S16_U8_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
|
||||
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
*( ( int8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
|
||||
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
*( ( uint8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( rs_c == 1 )
|
||||
{
|
||||
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
|
||||
|
||||
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
|
||||
|
||||
memcpy( c_use, ctemp, m0_rem_bytes );
|
||||
}
|
||||
else
|
||||
{
|
||||
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
|
||||
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
c_use[i * rs_c] = ctemp[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
post_ops_attr.post_op_c_i += MR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,258 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
void packb_nrlt16_u8s8s16o16
|
||||
(
|
||||
int8_t *pack_b_buffer_u8s8s16o16,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t rows,
|
||||
dim_t n0_partial_rem
|
||||
)
|
||||
{
|
||||
dim_t k_full_pieces_blks = rows / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = rows % 2;
|
||||
dim_t NR = 16;
|
||||
dim_t kr_new = 0;
|
||||
|
||||
int8_t buf0[16], buf1[16];
|
||||
|
||||
__m128i b_vec[2], inter_vec[2];
|
||||
|
||||
for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
|
||||
{
|
||||
memcpy(buf0, (b + (ldb * (kr + 0))), (n0_partial_rem * sizeof(int8_t)));
|
||||
memcpy(buf1, (b + (ldb * (kr + 1))), (n0_partial_rem * sizeof(int8_t)));
|
||||
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
|
||||
// Read b[1,0], b[1,1], b[1,2]......., b[1,15]
|
||||
b_vec[1] = _mm_loadu_si128((__m128i *)buf1);
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + (kr_new * NR)), inter_vec[0]);
|
||||
// Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
|
||||
// Increment to ignore the padded bits
|
||||
kr_new += 2;
|
||||
}
|
||||
|
||||
// Handle k partial cases
|
||||
if (k_partial_pieces > 0)
|
||||
{
|
||||
memcpy(buf0, (b + (ldb * (k_full_pieces + 0))), (n0_partial_rem * sizeof(int8_t)));
|
||||
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
|
||||
b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], 0, b[0,1]......., b[0,7], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
|
||||
|
||||
// Store b[0,8], 0, b[0,9]......., b[0,15], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void packb_nr16_u8s8s16o16(
|
||||
int8_t *pack_b_buffer_u8s8s16o16,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t rows)
|
||||
{
|
||||
dim_t k_full_pieces_blks = rows / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = rows % 2;
|
||||
dim_t NR = 16;
|
||||
dim_t kr_new = 0;
|
||||
|
||||
__m128i b_vec[2], inter_vec[2];
|
||||
|
||||
for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 0))));
|
||||
|
||||
// Read b[1,0], b[1,1], b[1,2]......., b[1,15]
|
||||
b_vec[1] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 1))));
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
|
||||
|
||||
// Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15]
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
|
||||
// Increment to ignore the padded bits
|
||||
kr_new += 2;
|
||||
}
|
||||
|
||||
if (k_partial_pieces > 0)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
|
||||
b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (k_full_pieces + 0))));
|
||||
b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
// Store b[0,0], 0, b[0,1]......., b[0,7], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
|
||||
// Store b[0,8], 0, b[0,9]......., b[0,15], 0
|
||||
_mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
|
||||
}
|
||||
}
|
||||
|
||||
void packb_nr32_u8s8s16o16(
|
||||
int8_t *pack_b_buffer_u8s8s16o16,
|
||||
const int8_t *b,
|
||||
const dim_t ldb,
|
||||
const dim_t cols,
|
||||
const dim_t rows,
|
||||
dim_t *rs_b,
|
||||
dim_t *cs_b)
|
||||
{
|
||||
dim_t NR = 32;
|
||||
|
||||
dim_t n_full_pieces = cols / NR;
|
||||
dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
|
||||
dim_t n_partial_pieces = cols % NR;
|
||||
dim_t k_full_pieces_blks = rows / 2;
|
||||
dim_t k_full_pieces = k_full_pieces_blks * 2;
|
||||
dim_t k_partial_pieces = rows % 2;
|
||||
|
||||
dim_t KC_updated = rows;
|
||||
|
||||
// Making multiple of 2 to suit k in vpmaddubsw
|
||||
KC_updated += (KC_updated & 0x1);
|
||||
|
||||
__m256i b_vec[2], inter_vec[2];
|
||||
|
||||
for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR)
|
||||
{
|
||||
for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,31]
|
||||
b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc));
|
||||
|
||||
// Read b[1,0], b[1,1], b[1,2]......., b[1,31]
|
||||
b_vec[1] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc));
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20);
|
||||
b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31);
|
||||
|
||||
// Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + (kr * NR))), b_vec[0]);
|
||||
// Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + ((kr + 1) * NR))), b_vec[1]);
|
||||
}
|
||||
|
||||
if (k_partial_pieces > 0)
|
||||
{
|
||||
// Read b[0,0], b[0,1], b[0,2]......., b[0,31]
|
||||
b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (k_full_pieces + 0)) + jc));
|
||||
b_vec[1] = _mm256_setzero_si256(); // Initialize with zero for padding
|
||||
|
||||
// Reorder B matrix inputs to suit vpmaddubsw instructions
|
||||
inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]);
|
||||
inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]);
|
||||
|
||||
b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20);
|
||||
b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31);
|
||||
|
||||
// Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + (k_full_pieces * NR))), b_vec[0]);
|
||||
// Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31]
|
||||
_mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + ((k_full_pieces + 1) * NR))), b_vec[1]);
|
||||
}
|
||||
}
|
||||
|
||||
// B matrix packing when n < NR
|
||||
if (n_partial_pieces > 0)
|
||||
{
|
||||
// Split into multiple smaller fringe kernels, so as to maximize
|
||||
// vectorization after packing. Any n0 < NR(32) can be expressed
|
||||
// as n0 = 16 + n`.
|
||||
dim_t n0_16 = n_partial_pieces / 16;
|
||||
dim_t n0_partial_rem = n_partial_pieces % 16;
|
||||
|
||||
dim_t n0_partial_pack = 0;
|
||||
|
||||
if (n0_16 == 1)
|
||||
{
|
||||
packb_nr16_u8s8s16o16(
|
||||
(pack_b_buffer_u8s8s16o16 +
|
||||
(n_full_pieces_loop_limit * KC_updated)),
|
||||
(b + n_full_pieces_loop_limit), ldb, rows);
|
||||
|
||||
n0_partial_pack = 16;
|
||||
}
|
||||
|
||||
if (n0_partial_rem > 0)
|
||||
{
|
||||
packb_nrlt16_u8s8s16o16(
|
||||
(pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) +
|
||||
(n0_partial_pack * KC_updated)),
|
||||
(b + n_full_pieces_loop_limit + n0_partial_pack),
|
||||
ldb, rows, n0_partial_rem);
|
||||
}
|
||||
}
|
||||
|
||||
*rs_b = NR * 2;
|
||||
*cs_b = NR;
|
||||
}
|
||||
#endif
|
||||
@@ -1,510 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#ifndef LPGEMM_S16_KERN_MACROS_H
|
||||
#define LPGEMM_S16_KERN_MACROS_H
|
||||
|
||||
#include "../gelu_avx2.h"
|
||||
#include "../silu_avx2.h"
|
||||
#include "../sigmoid_avx2.h"
|
||||
#include "../math_utils_avx2.h"
|
||||
|
||||
#define S8_MIN (-128)
|
||||
#define S8_MAX (+127)
|
||||
|
||||
/* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0 */
|
||||
#define RELU_SCALE_OP_S16_AVX2(reg) \
|
||||
selector1 = _mm256_setzero_si256();\
|
||||
selector1 = _mm256_cmpgt_epi16 ( selector1, reg ); \
|
||||
\
|
||||
/* Only < 0 elements in b0. */ \
|
||||
b0 = _mm256_and_si256 ( selector1, reg ); \
|
||||
\
|
||||
/* Only >= 0 elements in c_int16_0p0. */ \
|
||||
reg = _mm256_andnot_si256( selector1, reg ); \
|
||||
\
|
||||
/* Only scaling for < 0 elements. */ \
|
||||
b0 = _mm256_mullo_epi16( b0, selector2 ); \
|
||||
\
|
||||
/* Combine the scaled < 0 and >= 0 elements. */ \
|
||||
reg = _mm256_or_si256( b0, reg ); \
|
||||
|
||||
// s16 fma macro
|
||||
#define S16_BETA_FMA(reg,scratch1,scratch2) \
|
||||
scratch1 = _mm256_mullo_epi16( scratch2, scratch1 ); \
|
||||
reg = _mm256_add_epi16( scratch1, reg ); \
|
||||
|
||||
// Beta scale macro, scratch2=beta
|
||||
#define S16_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
|
||||
scratch1 = \
|
||||
_mm256_loadu_si256 \
|
||||
( \
|
||||
( __m256i const* )( c + ( rs_c * ( m_ir + m_ind ) ) + ( n_ind * 16 ) ) \
|
||||
); \
|
||||
S16_BETA_FMA(reg,scratch1,scratch2) \
|
||||
|
||||
// Beta n < 16 scale macro, scratch2=beta
|
||||
#define S16_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \
|
||||
scratch1 = _mm256_loadu_si256( ( __m256i const* )buf_ ); \
|
||||
S16_BETA_FMA(reg,scratch1,scratch2) \
|
||||
|
||||
// Downscale beta scale macro (s8 -> s16), scratch2=beta
|
||||
#define S8_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
|
||||
scratch1 = \
|
||||
_mm256_cvtepi8_epi16 \
|
||||
( \
|
||||
_mm_loadu_si128 \
|
||||
( \
|
||||
( __m128i const* )( ( int8_t* )post_ops_attr.buf_downscale + \
|
||||
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
|
||||
post_ops_attr.post_op_c_j + ( n_ind * 16 ) )\
|
||||
) \
|
||||
); \
|
||||
S16_BETA_FMA(reg,scratch1,scratch2) \
|
||||
|
||||
// Downscale beta scale macro (u8 -> s16), scratch2=beta. Loads 16 u8
// elements from the downscale C buffer, zero-extends them to s16 and
// accumulates beta * C into reg.
#define U8_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
	scratch1 = \
	  _mm256_cvtepu8_epi16 \
	  ( \
	    _mm_loadu_si128 \
	    ( \
	      ( __m128i const* )( ( uint8_t* )post_ops_attr.buf_downscale + \
	      ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
	      post_ops_attr.post_op_c_j + ( n_ind * 16 ) )\
	    ) \
	  ); \
	S16_BETA_FMA(reg,scratch1,scratch2) \
// Downscale beta n < 16 scale macro (s8 -> s16), scratch2=beta.
// Same as S8_S16_BETA_OP but reads from a caller-staged buffer.
#define S8_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \
	scratch1 = _mm256_cvtepi8_epi16( _mm_loadu_si128( ( __m128i const* )buf_ ) ); \
	S16_BETA_FMA(reg,scratch1,scratch2) \

// Downscale beta n < 16 scale macro (u8 -> s16), scratch2=beta.
// Same as U8_S16_BETA_OP but reads from a caller-staged buffer.
#define U8_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \
	scratch1 = _mm256_cvtepu8_epi16( _mm_loadu_si128( ( __m128i const* )buf_ ) ); \
	S16_BETA_FMA(reg,scratch1,scratch2) \
// Copy `bytes` worth of one row (m_ind) of the downscale C buffer into the
// staging buffer buf_ (used before the n < 16 beta macros above).
#define US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, C_type) \
	memcpy \
	( \
	  buf_, \
	  ( ( C_type* )post_ops_attr.buf_downscale + \
	  ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
	  post_ops_attr.post_op_c_j ), bytes \
	); \

// s8 flavor of the staging copy.
#define S8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
	US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,int8_t) \

// u8 flavor of the staging copy.
#define U8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
	US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,uint8_t) \
// Downscale macro: widen 16 s16 lanes to f32, multiply by the per-half scale
// vectors, round to nearest, convert back and saturate-pack to s16, then add
// the zero point. Clobbers temp[], temp_32[], temp_float[], res_1, res_2.
#define CVT_MULRND_CVT16(reg, scale0, scale1, zero_point_0) \
\
	/* Extract the first 128 bits of the register*/ \
	temp[0] = _mm256_extractf128_si256( reg, 0 ); \
	/* Extract the second 128 bits of the register*/ \
	temp[1] = _mm256_extractf128_si256( reg, 1 ); \
\
	temp_32[0] = _mm256_cvtepi16_epi32( temp[0] ); \
	temp_32[1] = _mm256_cvtepi16_epi32( temp[1] ); \
	temp_float[0] = _mm256_cvtepi32_ps( temp_32[0] ); \
	temp_float[1] = _mm256_cvtepi32_ps( temp_32[1] ); \
\
	/* Multiply the C matrix by the scale value*/ \
	res_1 = _mm256_mul_ps( temp_float[0], scale0 ); \
	res_2 = _mm256_mul_ps( temp_float[1], scale1 ); \
\
	/* Round the resultant value to the nearest float value. */ \
	res_1 = \
	  _mm256_round_ps \
	  ( \
	    res_1, ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
	  ); \
	res_2 = \
	  _mm256_round_ps \
	  ( \
	    res_2, ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
	  ); \
\
	/* Convert the clipped float32 scaled rounded value to int32 */ \
	temp_32[0] = _mm256_cvtps_epi32( res_1 ); \
	temp_32[1] = _mm256_cvtps_epi32( res_2 ); \
\
	/* Convert the s32 to s16 (saturating pack). */ \
	reg = _mm256_packs_epi32( temp_32[0], temp_32[1] ); \
\
	/* Permute to undo the lane interleave introduced by the pack. */ \
	reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
	/* Zero point addition.*/ \
	reg = _mm256_add_epi16( reg, zero_point_0 ); \
// Downscale store macro helper: restore lane order after a 16-bit pack and
// store 32 packed 8-bit results for one row into the downscale C buffer.
#define CVT_STORE_S16_SU8_HELPER(reg, m_ind, n_ind, C_type) \
	reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
	_mm256_storeu_si256 \
	( \
	  ( __m256i* )( ( C_type* )post_ops_attr.buf_downscale + \
	  ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
	  post_ops_attr.post_op_c_j + ( n_ind * 32 ) ), \
	  reg \
	); \

// Downscale store macro (s16 -> s8)
#define CVT_STORE_S16_S8(reg0, reg1, m_ind, n_ind) \
	/* Convert the s16 to s8 (signed saturating pack). */ \
	reg0 = _mm256_packs_epi16( reg0, reg1 ); \
	CVT_STORE_S16_SU8_HELPER(reg0, m_ind, n_ind, int8_t) \

// Downscale store macro (s16 -> u8)
#define CVT_STORE_S16_U8(reg0, reg1, m_ind, n_ind) \
	/* Convert the s16 to u8 (unsigned saturating pack). */ \
	reg0 = _mm256_packus_epi16( reg0, reg1 ); \
	CVT_STORE_S16_SU8_HELPER(reg0, m_ind, n_ind, uint8_t) \
// Downscale store helper macro for fringe cases: after a pack across two
// rows, store 16 packed 8-bit results into each of rows m_ind0 and m_ind1.
#define CVT_STORE_S16_US8_2ROW_HELPER(reg, m_ind0, m_ind1, n_ind, C_type) \
	reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
	/* Extract the first 128 bits of the register*/ \
	temp[0] = _mm256_extractf128_si256( reg, 0 ); \
	/* Extract the second 128 bits of the register*/ \
	temp[1] = _mm256_extractf128_si256( reg, 1 ); \
\
	_mm_storeu_si128 \
	( \
	  ( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \
	  ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind0 ) ) + \
	  post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \
	  temp[0] \
	); \
	_mm_storeu_si128 \
	( \
	  ( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \
	  ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind1 ) ) + \
	  post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \
	  temp[1] \
	); \

// Downscale store macro for fringe cases (s16 -> s8)
#define CVT_STORE_S16_S8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \
	/* Convert the s16 to s8 */ \
	reg0 = _mm256_packs_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_2ROW_HELPER(reg0, m_ind0, m_ind1, n_ind, int8_t) \

// Downscale store macro for fringe cases (s16 -> u8)
#define CVT_STORE_S16_U8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \
	/* Convert the s16 to u8 */ \
	reg0 = _mm256_packus_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_2ROW_HELPER(reg0, m_ind0, m_ind1, n_ind, uint8_t) \
// Downscale store helper macro for fringe cases: store 16 packed 8-bit
// results into the single row m_ind0 (only the low 128 bits are used).
#define CVT_STORE_S16_US8_1ROW(reg, m_ind0, n_ind, C_type) \
	reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
	/* Extract the first 128 bits of the register*/ \
	temp[0] = _mm256_extractf128_si256( reg, 0 ); \
\
	_mm_storeu_si128 \
	( \
	  ( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \
	  ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind0 ) ) + \
	  post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \
	  temp[0] \
	); \

// Downscale store (s16 -> s8) macro for fringe cases
#define CVT_STORE_S16_S8_1ROW(reg0, reg1, m_ind0, n_ind) \
	/* Convert the s16 to s8 */ \
	reg0 = _mm256_packs_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_1ROW(reg0, m_ind0, n_ind, int8_t) \

// Downscale store (s16 -> u8) macro for fringe cases
#define CVT_STORE_S16_U8_1ROW(reg0, reg1, m_ind0, n_ind) \
	/* Convert the s16 to u8 */ \
	reg0 = _mm256_packus_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_1ROW(reg0, m_ind0, n_ind, uint8_t) \
// Downscale store helper macro for n < 16 fringe cases: write the two packed
// 128-bit halves into caller-provided staging buffers (one per row).
#define CVT_STORE_S16_US8_2ROW_NLT16(reg, buf0, buf1) \
	reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
	/* Extract the first 128 bits of the register*/ \
	temp[0] = _mm256_extractf128_si256( reg, 0 ); \
	/* Extract the second 128 bits of the register*/ \
	temp[1] = _mm256_extractf128_si256( reg, 1 ); \
\
	_mm_storeu_si128( ( __m128i* )buf0, temp[0] ); \
	_mm_storeu_si128( ( __m128i* )buf1, temp[1] ); \

// Downscale store (s16 -> s8) macro for n < 16 fringe cases
#define CVT_STORE_S16_S8_2ROW_NLT16(reg0, reg1, buf0, buf1) \
	/* Convert the s16 to s8 */ \
	reg0 = _mm256_packs_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_2ROW_NLT16(reg0, buf0, buf1) \

// Downscale store (s16 -> u8) macro for n < 16 fringe cases
#define CVT_STORE_S16_U8_2ROW_NLT16(reg0, reg1, buf0, buf1) \
	/* Convert the s16 to u8 */ \
	reg0 = _mm256_packus_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_2ROW_NLT16(reg0, buf0, buf1) \
// Downscale store helper macro for n < 16 fringe cases: write the low packed
// 128 bits into a caller-provided staging buffer (single row).
#define CVT_STORE_S16_US8_1ROW_NLT16(reg, buf0) \
	reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
	/* Extract the first 128 bits of the register*/ \
	temp[0] = _mm256_extractf128_si256( reg, 0 ); \
\
	_mm_storeu_si128( ( __m128i* )buf0, temp[0] ); \

// Downscale store (s16 -> s8) macro for n < 16 fringe cases
#define CVT_STORE_S16_S8_1ROW_NLT16(reg0, reg1, buf0) \
	/* Convert the s16 to s8 */ \
	reg0 = _mm256_packs_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_1ROW_NLT16(reg0, buf0) \

// Downscale store (s16 -> u8) macro for n < 16 fringe cases
#define CVT_STORE_S16_U8_1ROW_NLT16(reg0, reg1, buf0) \
	/* Convert the s16 to u8 */ \
	reg0 = _mm256_packus_epi16( reg0, reg1 ); \
	CVT_STORE_S16_US8_1ROW_NLT16(reg0, buf0) \
// Copy `bytes` of staged 8-bit results from buf_ out to row m_ind of the
// downscale C buffer (reverse direction of US8_S16_BETA_NLT16_MEMCP_HELPER).
#define CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, C_type) \
	memcpy \
	( \
	  ( ( C_type* )post_ops_attr.buf_downscale + \
	  ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
	  post_ops_attr.post_op_c_j ), buf_, bytes \
	); \

// s8 flavor of the store-back copy.
#define CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
	CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, int8_t) \

// u8 flavor of the store-back copy.
#define CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
	CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, uint8_t) \
//--------------------------------------------------------------------------

/* GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */
// Widens the 16 s16 lanes of reg to two f32 vectors, applies the tanh-GeLU
// f32 kernel to each half, then saturate-packs back to s16 and restores
// lane order. All other arguments are scratch registers for the f32 kernel.
#define GELU_TANH_S16_AVX2(reg, y1, y2, r, r2, x, z, dn, x_tanh, q) \
\
	y1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 0)) ); \
	y2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 1)) ); \
\
	GELU_TANH_F32_AVX2_DEF(y1, r, r2, x, z, dn, x_tanh, q); \
\
	GELU_TANH_F32_AVX2_DEF(y2, r, r2, x, z, dn, x_tanh, q); \
\
	reg = _mm256_packs_epi32(_mm256_cvtps_epi32(y1), _mm256_cvtps_epi32(y2));\
	reg = _mm256_permute4x64_epi64(reg, 0XD8);\
/* ERF GeLU (x) = 0.5* x * (1 + erf (x * 0.707107 )) */
// Same widen/apply/pack pattern as GELU_TANH_S16_AVX2, using the erf-based
// f32 GeLU kernel on each half.
#define GELU_ERF_S16_AVX2(reg, y1, y2, r, x, x_erf) \
\
	y1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 0)) ); \
	y2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 1)) ); \
\
	GELU_ERF_F32_AVX2_DEF(y1, r, x, x_erf); \
\
	GELU_ERF_F32_AVX2_DEF(y2, r, x, x_erf); \
\
	reg = _mm256_packs_epi32(_mm256_cvtps_epi32(y1), _mm256_cvtps_epi32(y2));\
	reg = _mm256_permute4x64_epi64(reg, 0XD8);\
// Clamp the 16 s16 lanes of reg into [min, max] (min/max are __m256i with
// the bounds broadcast).
#define CLIP_S16_AVX2(reg, min, max) \
\
	reg = _mm256_min_epi16( _mm256_max_epi16( reg, min ), max ); \
// Matrix Add post-ops helper macros: accumulate pre-loaded s16 vectors into
// the c_int16_<m>p<n> accumulator registers (names built by token pasting).
#define S16_MATRIX_ADD_1COL(scr0,m_ind) \
	c_int16_ ## m_ind ## p0 = _mm256_add_epi16( scr0, c_int16_ ## m_ind ## p0 ); \

#define S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
	c_int16_ ## m_ind ## p0 = _mm256_add_epi16( scr0, c_int16_ ## m_ind ## p0 ); \
	c_int16_ ## m_ind ## p1 = _mm256_add_epi16( scr1, c_int16_ ## m_ind ## p1 ); \
// Matrix-add load (s8 -> s16): sign-extend 16 s8 elements of the add-matrix
// row into scr (uses file-scope matptr, ldm).
#define S8_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \
	scr = _mm256_cvtepi8_epi16 \
	      ( \
	        _mm_loadu_si128 \
	        ( \
	          ( __m128i const* ) \
	          ( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
	          post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \
	        ) \
	      ); \

// Partial-width (n_rem < 16) variant: stage the row tail through buf first.
#define S8_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \
	memcpy \
	( \
	  ( OTYPE* )buf, \
	  matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
	  post_ops_attr.post_op_c_j + ( 0 * 16 ), \
	  ( n_rem ) * sizeof(OTYPE) \
	); \
	scr0 = _mm256_cvtepi8_epi16 \
	       ( \
	         _mm_loadu_si128( ( __m128i const* )buf ) \
	       ); \
	S16_MATRIX_ADD_1COL(scr0,m_ind); \

// Full-width one-column s8 matrix add.
#define S8_S16_MATRIX_ADD_1COL(scr0,m_ind) \
	S8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
	S16_MATRIX_ADD_1COL(scr0,m_ind); \

// Full-width two-column s8 matrix add.
#define S8_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
	S8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
	S8_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \
	S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \
// Matrix-add load (u8 -> s16): zero-extend 16 u8 elements of the add-matrix
// row into scr (uses file-scope matptr, ldm).
#define U8_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \
	scr = _mm256_cvtepu8_epi16 \
	      ( \
	        _mm_loadu_si128 \
	        ( \
	          ( __m128i const* ) \
	          ( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
	          post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \
	        ) \
	      ); \

// Partial-width (n_rem < 16) variant: stage the row tail through buf first.
#define U8_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \
	memcpy \
	( \
	  ( OTYPE* )buf, \
	  matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
	  post_ops_attr.post_op_c_j + ( 0 * 16 ), \
	  ( n_rem ) * sizeof(OTYPE) \
	); \
	scr0 = _mm256_cvtepu8_epi16 \
	       ( \
	         _mm_loadu_si128( ( __m128i const* )buf ) \
	       ); \
	S16_MATRIX_ADD_1COL(scr0,m_ind); \

// Full-width one-column u8 matrix add.
#define U8_S16_MATRIX_ADD_1COL(scr0,m_ind) \
	U8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
	S16_MATRIX_ADD_1COL(scr0,m_ind); \

// Full-width two-column u8 matrix add.
#define U8_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
	U8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
	U8_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \
	S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \
// Matrix-add load (s16 -> s16): load 16 s16 elements of the add-matrix row
// directly into scr (uses file-scope matptr, ldm).
#define S16_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \
	scr = _mm256_loadu_si256 \
	      ( \
	        (__m256i const *) \
	        ( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
	        post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \
	      ); \

// Partial-width (n_rem < 16) variant: stage the row tail through buf first.
#define S16_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \
	memcpy \
	( \
	  ( OTYPE* )buf, \
	  matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
	  post_ops_attr.post_op_c_j + ( 0 * 16 ), \
	  ( n_rem ) * sizeof(OTYPE) \
	); \
	scr0 = _mm256_loadu_si256( ( __m256i const* )buf ); \
	S16_MATRIX_ADD_1COL(scr0,m_ind); \

// Full-width one-column s16 matrix add.
#define S16_S16_MATRIX_ADD_1COL(scr0,m_ind) \
	S16_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
	S16_MATRIX_ADD_1COL(scr0,m_ind); \

// Full-width two-column s16 matrix add.
#define S16_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
	S16_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
	S16_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \
	S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \
// SiLU/Swish utility macro. `al` is expected to contain the alpha scale as
// floats; the s16 lanes of in_reg are widened to two f32 vectors, the f32
// swish kernel is applied to each, and the results are packed back to s16.
#define SWISH_S16_AVX2(in_reg, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out) \
\
	tmp_reg1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
	                _mm256_extractf128_si256( in_reg, 0 ) ) ); \
	tmp_reg2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
	                _mm256_extractf128_si256( in_reg, 1 ) ) ); \
\
	SWISH_F32_AVX2_DEF(tmp_reg1, al, al_in, r, r2, z, dn, ex_out); \
\
	SWISH_F32_AVX2_DEF(tmp_reg2, al, al_in, r, r2, z, dn, ex_out); \
\
	in_reg = _mm256_packs_epi32(_mm256_cvtps_epi32(tmp_reg1), _mm256_cvtps_epi32(tmp_reg2));\
	in_reg = _mm256_permute4x64_epi64(in_reg, 0XD8);\
// TANH utility macro: widen s16 lanes to two f32 vectors, apply the f32 tanh
// kernel to each half, saturate-pack back to s16 and restore lane order.
#define TANH_S16_AVX2(reg, y1, y2, r, r2, x, z, dn, q) \
\
	y1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 0)) ); \
	y2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 1)) ); \
\
	TANHF_AVX2(y1, r, r2, x, z, dn, q); \
\
	TANHF_AVX2(y2, r, r2, x, z, dn, q); \
\
	reg = _mm256_packs_epi32(_mm256_cvtps_epi32(y1), _mm256_cvtps_epi32(y2));\
	reg = _mm256_permute4x64_epi64(reg, 0XD8);\
// SIGMOID utility macro: widen s16 lanes to two f32 vectors, apply the f32
// sigmoid kernel to each half, saturate-pack back to s16 and restore order.
#define SIGMOID_S16_AVX2(in_reg, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out) \
\
	tmp_reg1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
	                _mm256_extractf128_si256( in_reg, 0 ) ) ); \
	tmp_reg2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
	                _mm256_extractf128_si256( in_reg, 1 ) ) ); \
\
	SIGMOID_F32_AVX2_DEF(tmp_reg1, al_in, r, r2, z, dn, ex_out); \
\
	SIGMOID_F32_AVX2_DEF(tmp_reg2, al_in, r, r2, z, dn, ex_out); \
\
	in_reg = _mm256_packs_epi32(_mm256_cvtps_epi32(tmp_reg1), _mm256_cvtps_epi32(tmp_reg2));\
	in_reg = _mm256_permute4x64_epi64(in_reg, 0XD8);\
// Zero-out the given YMM accumulator registers.
#define ZERO_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3) \
	ymm0 = _mm256_setzero_si256 (); \
	ymm1 = _mm256_setzero_si256 (); \
	ymm2 = _mm256_setzero_si256 (); \
	ymm3 = _mm256_setzero_si256 ();
#endif //LPGEMM_S16_KERN_MACROS_H
|
||||
@@ -1,815 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ADDON_LPGEMM
|
||||
|
||||
#include "lpgemm_s16_kern_macros.h"
|
||||
|
||||
// Load 32 bytes from each of two consecutive A rows (paddr, paddr+stride).
#define LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, paddr, stride ) \
	ymm0 = _mm256_loadu_si256( (__m256i const *)paddr ); \
	ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) );
// u8*s8 multiply-accumulate for two rows: maddubs produces pairwise s16
// sums, which are accumulated into c_reg1/c_reg2 (inter_* are scratch).
#define LPGEMV_N_KERNEL_2_FMA( a_reg1, a_reg2, b_reg, \
                               inter_reg1, inter_reg2, c_reg1, c_reg2 ) \
	inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
	c_reg1 = _mm256_add_epi16(inter_reg1, c_reg1); \
	inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
	c_reg2 = _mm256_add_epi16(inter_reg2, c_reg2);
// Load 32 bytes from each of four consecutive A rows starting at paddr.
#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \
	ymm0 = _mm256_loadu_si256( (__m256i const *)(paddr) ); \
	ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \
	ymm2 = _mm256_loadu_si256( (__m256i const *)(paddr + 2 * stride) ); \
	ymm3 = _mm256_loadu_si256( (__m256i const *)(paddr + 3 * stride) );
// u8*s8 multiply-accumulate for four rows (four maddubs + add pairs, one per
// output accumulator; inter_reg* are scratch).
#define LPGEMV_N_KERNEL_4_FMA( a_reg1, a_reg2, a_reg3, a_reg4, b_reg, \
                               inter_reg1, inter_reg2, \
                               inter_reg3, inter_reg4, \
                               out_reg1, out_reg2, out_reg3, out_reg4 ) \
	inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
	out_reg1 = _mm256_add_epi16(inter_reg1, out_reg1); \
	inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
	out_reg2 = _mm256_add_epi16(inter_reg2, out_reg2); \
	inter_reg3 = _mm256_maddubs_epi16(a_reg3, b_reg); \
	out_reg3 = _mm256_add_epi16(inter_reg3, out_reg3); \
	inter_reg4 = _mm256_maddubs_epi16(a_reg4, b_reg); \
	out_reg4 = _mm256_add_epi16(inter_reg4, out_reg4);
// Horizontal-add reduction tree: fold four 256-bit s16 accumulators down to
// one 128-bit vector of per-row sums in xmm0 (ymm0/ymm1 are clobbered).
#define LPGEMV_YMM2XMM( ymm0, ymm1, ymm2, ymm3, xmm0 ) \
	ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
	ymm1 = _mm256_hadd_epi16( ymm2, ymm3 ); \
	ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
	xmm0 = _mm_add_epi16( _mm256_extracti128_si256( ymm0, 0 ), \
	                      _mm256_extracti128_si256( ymm0, 1 ) );
||||
|
||||
|
||||
LPGEMV_N_EQ1_KERN(uint8_t, int8_t, int16_t, u8s8s16os16)
|
||||
{
|
||||
static void* post_ops_labels[] =
|
||||
{
|
||||
&&POST_OPS_DISABLE,
|
||||
&&POST_OPS_BIAS,
|
||||
&&POST_OPS_RELU,
|
||||
&&POST_OPS_RELU_SCALE,
|
||||
&&POST_OPS_GELU_TANH,
|
||||
&&POST_OPS_GELU_ERF,
|
||||
&&POST_OPS_CLIP,
|
||||
&&POST_OPS_DOWNSCALE,
|
||||
&&POST_OPS_MATRIX_ADD,
|
||||
&&POST_OPS_SWISH,
|
||||
NULL,// Virtual node for matrix_mul, else segfault
|
||||
&&POST_OPS_TANH,
|
||||
&&POST_OPS_SIGMOID
|
||||
};
|
||||
|
||||
uint8_t *a_use = NULL;
|
||||
int8_t *b_use = NULL;
|
||||
int16_t *c_use = NULL;
|
||||
|
||||
lpgemm_post_op_attr post_ops_attr = *(post_op_attr);
|
||||
|
||||
// temp buffer to store output C vector
|
||||
int16_t ctemp[16];
|
||||
|
||||
// temp buffers to store a, b data in k_rem case.
|
||||
uint8_t buf0[32] = {0};
|
||||
uint8_t buf1[32] = {0};
|
||||
uint8_t buf2[32] = {0};
|
||||
uint8_t buf3[32] = {0};
|
||||
uint8_t buf4[32] = {0};
|
||||
uint8_t buf5[32] = {0};
|
||||
uint8_t buf6[32] = {0};
|
||||
uint8_t buf7[32] = {0};
|
||||
int8_t buf8[32] = {0};
|
||||
|
||||
for ( dim_t ir = 0; ir < m0; ir += MR )
|
||||
{
|
||||
dim_t mr0 = bli_min( ( m0 - ir ), MR );
|
||||
dim_t k_iter = k / 32;
|
||||
dim_t k_rem = k % 32;
|
||||
|
||||
__m256i ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7;
|
||||
__m256i ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14;
|
||||
__m256i ymm15;
|
||||
|
||||
__m128i xmm0, xmm1;
|
||||
|
||||
/* zero the accumulator registers */
|
||||
ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 )
|
||||
ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 )
|
||||
|
||||
//update pointers
|
||||
a_use = (uint8_t*)a + ir * rs_a;
|
||||
b_use = (int8_t*)b;
|
||||
c_use = (int16_t*)c + ir * rs_c;
|
||||
|
||||
if( mr0 == MR )
|
||||
{
|
||||
for (dim_t k = 0; k < k_iter; k++)
|
||||
{
|
||||
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)(b_use) );
|
||||
b_use += 32;
|
||||
|
||||
//Load 4x32 elements from row0-row3 of A
|
||||
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a )
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
|
||||
// Load 4x32 elements from row8-row11 of A
|
||||
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
|
||||
( a_use + 4 * rs_a ), rs_a
|
||||
)
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm12, ymm13, ymm14, ymm15
|
||||
)
|
||||
|
||||
a_use += 32;
|
||||
}
|
||||
|
||||
|
||||
|
||||
if( k_rem )
|
||||
{
|
||||
|
||||
uint8_t* restrict a0 = (a_use);
|
||||
uint8_t* restrict a1 = (a_use + rs_a );
|
||||
uint8_t* restrict a2 = (a_use + 2 * rs_a );
|
||||
uint8_t* restrict a3 = (a_use + 3 * rs_a );
|
||||
uint8_t* restrict a4 = (a_use + 4 * rs_a );
|
||||
uint8_t* restrict a5 = (a_use + 5 * rs_a );
|
||||
uint8_t* restrict a6 = (a_use + 6 * rs_a );
|
||||
uint8_t* restrict a7 = (a_use + 7 * rs_a );
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf1[i] = a1[i];
|
||||
buf2[i] = a2[i];
|
||||
buf3[i] = a3[i];
|
||||
buf4[i] = a4[i];
|
||||
buf5[i] = a5[i];
|
||||
buf6[i] = a6[i];
|
||||
buf7[i] = a7[i];
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
//Load 4x32 elements from row0-row3 of A
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
|
||||
// Load 4x32 elements from row8-row11 of A
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf4 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf5 );
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)buf6 );
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)buf7 );
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm12, ymm13, ymm14, ymm15
|
||||
)
|
||||
|
||||
}
|
||||
//Add the registers horizantally to get one
|
||||
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
|
||||
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm1 )
|
||||
|
||||
xmm0 = _mm_hadd_epi16( xmm0, xmm1 );
|
||||
|
||||
// post ops are applied on ymm register though
|
||||
// second half of the register is filled with zeroes.
|
||||
ymm8 = _mm256_setzero_si256();
|
||||
ymm8 = _mm256_inserti128_si256( ymm8, xmm0, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
uint8_t *a_use_fringe = a_use;
|
||||
dim_t mr0_use = mr0;
|
||||
dim_t regidx = 0;
|
||||
|
||||
if( mr0_use >= 4 )
|
||||
{
|
||||
for (dim_t k = 0; k < k_iter; k++)
|
||||
{
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
|
||||
b_use += 32;
|
||||
|
||||
//Load 4x32 elements from row0-row3 of A
|
||||
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
|
||||
a_use, rs_a )
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
|
||||
a_use += 32;
|
||||
}
|
||||
|
||||
if( k_rem )
|
||||
{
|
||||
uint8_t* restrict a0 = (a_use);
|
||||
uint8_t* restrict a1 = (a_use + rs_a );
|
||||
uint8_t* restrict a2 = (a_use + 2 * rs_a );
|
||||
uint8_t* restrict a3 = (a_use + 3 * rs_a );
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf1[i] = a1[i];
|
||||
buf2[i] = a2[i];
|
||||
buf3[i] = a3[i];
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
//Load 4xk_rem elements from row0-row3 of A
|
||||
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
|
||||
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
|
||||
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
|
||||
|
||||
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
|
||||
ymm6, ymm4, ymm5, ymm7, ymm4,
|
||||
ymm8, ymm9, ymm10, ymm11
|
||||
)
|
||||
}
|
||||
|
||||
//update pointers
|
||||
mr0_use -= 4;
|
||||
a_use = a_use_fringe + 4 * rs_a;
|
||||
a_use_fringe = a_use;
|
||||
b_use = (int8_t*)b;
|
||||
|
||||
//Add the registers horizantally to get one
|
||||
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
|
||||
|
||||
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
|
||||
|
||||
int64_t data = _mm_extract_epi64( xmm0, 0);
|
||||
//insert xmm outputs into final output reg based on regidx
|
||||
ymm8 = _mm256_setzero_si256();
|
||||
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
|
||||
regidx++;
|
||||
}
|
||||
|
||||
// Dot product for <= 3
|
||||
if ( mr0_use )
|
||||
{
|
||||
// Dot product for m = 2
|
||||
if ( mr0_use >= 2 )
|
||||
{
|
||||
for ( dim_t k = 0; k < k_iter; k++ )
|
||||
{
|
||||
// Load 0-31 in b[k+0 - k+31]
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
|
||||
|
||||
LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, a_use, rs_a);
|
||||
|
||||
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, ymm4,
|
||||
ymm5, ymm12, ymm13);
|
||||
b_use += 32; // move b pointer to next 32 elements
|
||||
a_use += 32;
|
||||
}
|
||||
if ( k_rem )
|
||||
{
|
||||
uint8_t* restrict a0 = (a_use);
|
||||
uint8_t* restrict a1 = (a_use + rs_a );
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
buf1[i] = a1[i];
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
//Load 2xk_rem elements from row0-row3 of A
|
||||
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
|
||||
|
||||
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6,
|
||||
ymm4, ymm5, ymm12, ymm13 );
|
||||
}
|
||||
|
||||
mr0_use -= 2;
|
||||
a_use = a_use_fringe + 2 * rs_a;
|
||||
a_use_fringe = a_use;
|
||||
b_use = (int8_t*)b;
|
||||
}
|
||||
|
||||
// Dot product for m = 1
|
||||
if ( mr0_use == 1 )
|
||||
{
|
||||
for ( dim_t k = 0; k < k_iter; k++ )
|
||||
{
|
||||
// Load 0-31 in b[k+0 - k+31]
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
|
||||
|
||||
// Load 1x32 elements from row0-row1 of A
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)a_use );
|
||||
|
||||
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
|
||||
ymm14 = _mm256_add_epi16(ymm4, ymm14);
|
||||
|
||||
b_use += 32; // move b pointer to next 32 elements
|
||||
a_use += 32;
|
||||
}
|
||||
if ( k_rem )
|
||||
{
|
||||
uint8_t* restrict a0 = (a_use);
|
||||
|
||||
for( dim_t i = 0; i < k_rem; i++)
|
||||
{
|
||||
buf8[i] = b_use[i];
|
||||
buf0[i] = a0[i];
|
||||
}
|
||||
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
|
||||
|
||||
//Load 1xk_rem elements from row0-row3 of A
|
||||
|
||||
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
|
||||
|
||||
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
|
||||
ymm14 = _mm256_add_epi16(ymm4, ymm14);
|
||||
}
|
||||
|
||||
// When only fringe 1,
|
||||
// update the registers to store in order
|
||||
if ( !( mr0 & 0x2 ) ) ymm12 = ymm14;
|
||||
}
|
||||
|
||||
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm0)
|
||||
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
|
||||
|
||||
int64_t data = _mm_extract_epi64( xmm0, 0);
|
||||
//insert xmm outputs into final output reg based on regidx
|
||||
|
||||
if( regidx == 0 )
|
||||
{
|
||||
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
|
||||
}
|
||||
else
|
||||
{
|
||||
ymm8 = _mm256_insert_epi64( ymm8, data, 1 );
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Load alpha and beta
|
||||
__m256i selector1 = _mm256_set1_epi16(alpha);
|
||||
__m256i selector2 = _mm256_set1_epi16(beta);
|
||||
|
||||
// Scale by alpha
|
||||
ymm8 = _mm256_mullo_epi16(selector1, ymm8);
|
||||
|
||||
if( beta != 0 )
|
||||
{
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
if( post_ops_attr.rs_c_downscale == 1 )
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
|
||||
|
||||
S8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes );
|
||||
|
||||
S8_S16_BETA_OP_NLT16( ymm8, ctemp,
|
||||
selector1, selector2 )
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
|
||||
|
||||
U8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes );
|
||||
|
||||
U8_S16_BETA_OP_NLT16( ymm8, ctemp,
|
||||
selector1, selector2 )
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( (int8_t*)post_ops_attr.buf_downscale
|
||||
+ ( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepi8_epi32
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
S16_BETA_FMA( ymm8, selector1, selector2 );
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( (uint8_t*)post_ops_attr.buf_downscale
|
||||
+ ( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepu8_epi32
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
S16_BETA_FMA( ymm8, selector1, selector2 );
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( rs_c == 1 )
|
||||
{
|
||||
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
|
||||
memcpy( ctemp, c_use, m0_rem_bytes );
|
||||
S16_S16_BETA_OP_NLT16( ymm8, ctemp,
|
||||
selector1, selector2 )
|
||||
}
|
||||
else
|
||||
{
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = c_use[ i * rs_c ];
|
||||
}
|
||||
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
|
||||
S16_BETA_FMA( ymm8, selector1, selector2 );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Post Ops
|
||||
lpgemm_post_op * post_ops_list_temp = post_op;
|
||||
|
||||
post_ops_attr.is_last_k = TRUE;
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP
|
||||
|
||||
|
||||
POST_OPS_BIAS:
|
||||
{
|
||||
|
||||
|
||||
selector1 =
|
||||
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args1) );
|
||||
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_RELU:
|
||||
{
|
||||
selector1 = _mm256_setzero_si256();
|
||||
|
||||
ymm8 = _mm256_max_epi16( selector1, ymm8 );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_RELU_SCALE:
|
||||
{
|
||||
__m256i b0;
|
||||
selector1 = _mm256_setzero_si256();
|
||||
selector2 = _mm256_set1_epi16(
|
||||
*( ( int16_t* )post_ops_list_temp->op_args2 ) );
|
||||
|
||||
RELU_SCALE_OP_S16_AVX2( ymm8 )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_GELU_TANH:
|
||||
{
|
||||
__m256 dn, z, x, r2, r, y1, y2, x_tanh;
|
||||
__m256i q;
|
||||
|
||||
GELU_TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, x_tanh, q )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_GELU_ERF:
|
||||
{
|
||||
__m256 x, r, y1, y2, x_erf;
|
||||
|
||||
GELU_ERF_S16_AVX2(ymm8, y1, y2, r, x, x_erf)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_CLIP:
|
||||
{
|
||||
__m256i min = _mm256_set1_epi16(
|
||||
*( int16_t* )post_ops_list_temp->op_args2 );
|
||||
__m256i max = _mm256_set1_epi16(
|
||||
*( int16_t* )post_ops_list_temp->op_args3 );
|
||||
|
||||
CLIP_S16_AVX2(ymm8, min, max)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_DOWNSCALE:
|
||||
{
|
||||
__m128i temp[2];
|
||||
__m256i temp_32[2];
|
||||
__m256 temp_float[2];
|
||||
__m256 scale_1 = _mm256_setzero_ps();
|
||||
__m256 scale_2 = _mm256_setzero_ps();
|
||||
__m128i _zero_point_0 = _mm_setzero_si128();
|
||||
__m256i zero_point_0 = _mm256_setzero_si256();
|
||||
__m256 res_1, res_2;
|
||||
|
||||
scale_1 =
|
||||
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
|
||||
|
||||
scale_2 =
|
||||
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
|
||||
|
||||
_zero_point_0 = _mm_set1_epi8(
|
||||
*( ( int8_t* )post_ops_list_temp->op_args1 ) );
|
||||
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
|
||||
}
|
||||
|
||||
// Scale first 16 columns of the 2 rows.
|
||||
CVT_MULRND_CVT16(ymm8, scale_1, scale_2, zero_point_0)
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
|
||||
POST_OPS_MATRIX_ADD:
|
||||
{
|
||||
dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;
|
||||
|
||||
if ( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1;
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
memcpy
|
||||
(
|
||||
( int8_t* )ctemp,
|
||||
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
( mr0 ) * sizeof(int8_t)
|
||||
);
|
||||
selector1 = _mm256_cvtepi8_epi16(
|
||||
_mm_loadu_si128( ( __m128i const* )ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepi8_epi16
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
}
|
||||
else if ( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1;
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
memcpy
|
||||
(
|
||||
( uint8_t* )ctemp,
|
||||
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
( mr0 ) * sizeof(uint8_t)
|
||||
);
|
||||
selector1 = _mm256_cvtepu8_epi16(
|
||||
_mm_loadu_si128( ( __m128i const* )ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm256_cvtepu8_epi16
|
||||
( _mm_loadu_si128( (__m128i const*)ctemp ) );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1;
|
||||
|
||||
if( ldm == 1 )
|
||||
{
|
||||
memcpy
|
||||
(
|
||||
( int16_t* )ctemp,
|
||||
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
|
||||
post_ops_attr.post_op_c_j + ( 0 * 16 ),
|
||||
( mr0 ) * sizeof(int16_t)
|
||||
);
|
||||
|
||||
selector1 = _mm256_loadu_si256( ( __m256i const* )ctemp );
|
||||
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
else
|
||||
{
|
||||
int32_t ctemp[16];
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
ctemp[i] = *( matptr +
|
||||
( ( post_ops_attr.post_op_c_i + i )
|
||||
* ldm ) );
|
||||
}
|
||||
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
|
||||
ymm8 = _mm256_add_epi16( selector1, ymm8 );
|
||||
}
|
||||
}
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_SWISH:
|
||||
{
|
||||
selector1 =
|
||||
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) );
|
||||
__m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
|
||||
_mm256_extractf128_si256( selector1, 0 ) ) );
|
||||
|
||||
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
|
||||
__m256i ex_out;
|
||||
|
||||
SWISH_S16_AVX2( ymm8, al, al_in, tmp_reg1,
|
||||
tmp_reg2, r, r2, z, dn, ex_out );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_TANH:
|
||||
{
|
||||
__m256 dn, z, x, r2, r, y1, y2;
|
||||
__m256i q;
|
||||
|
||||
TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, q )
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_SIGMOID:
|
||||
{
|
||||
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
|
||||
__m256i ex_out;
|
||||
|
||||
SIGMOID_S16_AVX2( ymm8, al_in, tmp_reg1,
|
||||
tmp_reg2, r, r2, z, dn, ex_out );
|
||||
|
||||
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
|
||||
}
|
||||
POST_OPS_DISABLE:
|
||||
{
|
||||
if ( post_ops_attr.buf_downscale != NULL )
|
||||
{
|
||||
__m128i temp[2];
|
||||
__m256i zero_reg = _mm256_setzero_si256();
|
||||
if( post_ops_attr.rs_c_downscale == 1 )
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
// Store the results in downscaled type
|
||||
// (int8 instead of int16).
|
||||
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
|
||||
|
||||
CVT_STORE_S16_S8_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes);
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
// Store the results in downscaled type (uint8 instead of int16).
|
||||
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
|
||||
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
|
||||
|
||||
CVT_STORE_S16_U8_NLT16_MEMCP_UTIL( ctemp, 0,
|
||||
m0_rem_dscale_bytes);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( post_ops_attr.c_stor_type == S8 )
|
||||
{
|
||||
int8_t ctemp[16];
|
||||
|
||||
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
*( ( int8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
else if( post_ops_attr.c_stor_type == U8 )
|
||||
{
|
||||
uint8_t ctemp[16];
|
||||
|
||||
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
|
||||
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
*( ( uint8_t* )post_ops_attr.buf_downscale +
|
||||
( post_ops_attr.rs_c_downscale *
|
||||
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if( rs_c == 1 )
|
||||
{
|
||||
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
|
||||
|
||||
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
|
||||
|
||||
memcpy( c_use, ctemp, m0_rem_bytes );
|
||||
}
|
||||
else
|
||||
{
|
||||
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
|
||||
|
||||
for( dim_t i = 0; i < mr0; i++ )
|
||||
{
|
||||
c_use[i * rs_c] = ctemp[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
post_ops_attr.post_op_c_i += MR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user