Deprecate S16 LPGEMM APIs.

The following S16 APIs are removed:
1. aocl_gemm_u8s8s16os16
2. aocl_gemm_u8s8s16os8
3. aocl_gemm_u8s8s16ou8
4. aocl_gemm_s8s8s16os16
5. aocl_gemm_s8s8s16os8
along with the associated reorder APIs and corresponding
framework elements.

AMD-Internal: [CPUPL-6412]

Change-Id: I251f8b02a4cba5110615ddeb977d86f5c949363b
This commit is contained in:
Mithun Mohan
2025-02-07 11:27:28 +00:00
parent 1f0fb05277
commit bffa92ec93
39 changed files with 1 additions and 22361 deletions

View File

@@ -45,13 +45,10 @@
#include "lpgemm_eltwise_ops_kernels.h"
#include "lpgemm_utils_kernels.h"
#include "lpgemm_pack_bf16.h"
#include "lpgemm_packb_s16.h"
#include "lpgemm_packa_s16.h"
#include "lpgemm_packa.h"
#include "lpgemm_packb.h"
#include "lpgemm_packa_s8.h"
#include "lpgemm_packb_s8.h"
#include "lpgemm_packb_s8s16.h"
#include "lpgemm_pack_f32.h"
#include "lpgemm_jit_typedefs.h"
#ifdef LPGEMM_BF16_JIT

View File

@@ -51,10 +51,8 @@ BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \
AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32);
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32);
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16);
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32);
AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32);
AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16);
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s4s32os32);
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16s4f32of32);
@@ -76,11 +74,9 @@ BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \
AOCL_GEMM_REORDER(float,f32f32f32of32);
AOCL_GEMM_REORDER(int8_t,u8s8s32os32);
AOCL_GEMM_REORDER(int8_t,u8s8s16os16);
AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32);
AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32_reference);
AOCL_GEMM_REORDER(int8_t,s8s8s32os32);
AOCL_GEMM_REORDER(int8_t,s8s8s16os16);
AOCL_GEMM_REORDER(int8_t,u8s4s32os32);
AOCL_GEMM_REORDER(int8_t, bf16s4f32of32);
@@ -136,7 +132,6 @@ BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \
) \
AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16);
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
AOCL_GEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16);
AOCL_GEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32);
@@ -148,11 +143,6 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
AOCL_GEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);
AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16);
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8);
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8);
AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8);
AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32);

View File

@@ -1,200 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_utils_s8.h"
#include "lpgemm_logger.h"
// Deprecated (AMD-Internal CPUPL-6412): front-end entry point for the
// s8s8s16os16 LPGEMM (A: int8, B: int8, C/accum: int16).
// The signature is generated by the AOCL_GEMM_MATMUL macro; the names used
// below (order, transa, transb, m, n, k, alpha, a, lda, mem_format_a,
// b, ldb, mem_format_b, beta, c, ldc, post_op_unparsed) are macro-supplied
// parameters. Validates arguments, maps them to LPGEMM framework objects,
// then dispatches to the OpenMP or single-threaded s8s8s16o16 decorator.
AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
{
LPGEMM_START_LOGGER();
// Log the call parameters before any validation/early exit.
LPGEMM_WRITE_LOGGER \
(
"s8s8s16os16", \
order, transa, transb, \
m, n, k, \
( ( float ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float ) beta ), \
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK
(
"s8s8s16os16",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc,
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Perform BLAS parameter checking. */
// Transpose not supported.
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
{
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
if ( ( order != 'r' ) && ( order != 'R' ) )
{
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
goto err_hndl;
}
// Row-major strides: row stride = leading dimension, unit column stride.
inc_t rs_a = lda;
inc_t cs_a = 1;
inc_t rs_b = ldb;
inc_t cs_b = 1;
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Pack is enabled for row major storage when trans A is true.
// Pack transforms column major matrix to row-major storage as kernel
// expects A matrix to be in row-major format.
if ( bli_is_trans( blis_transa ) )
{
rs_a = 1;
cs_a = lda;
mtag_a = PACK;
}
// B matrix needs to be packed in a certain format in order to be loaded
// and used in VNNI instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if (mtag_b == UNPACKED)
{
mtag_b = PACK;
}
// Only unpacked A supported now for row-major A matrix.
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
{
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed, post_op_list,
( void* )c, ( void* )( &order ),
m, n
);
if( err != BLIS_SUCCESS )
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 );
#ifdef BLIS_ENABLE_OPENMP
lpgemm_s8s8s16o16_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S16
);
#else
lpgemm_s8s8s16o16_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S16
);
#endif
// Single exit point for all error paths; logger must always be stopped.
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -1,178 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "lpgemm_types.h"
#include "lpgemm_config.h"
#include "lpgemm_utils_s8.h"
#include "lpgemm_reorder_s8s16.h"
// Deprecated (AMD-Internal CPUPL-6412): returns the byte size required to
// hold a reordered B matrix for s8s8s16os16 GEMM, or 0 on error / for an
// A-matrix request (A reorder is unsupported). Signature (mat_type, k, n)
// is generated by the AOCL_GEMM_GET_REORDER_BUF_SIZE macro.
AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16)
{
if ((k <= 0) || (n <= 0))
{
return 0; // Error.
}
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
return 0; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
if (input_mat_type == A_MATRIX)
{
return 0; // A reorder not supported.
}
// n == 1 (gemv-like) is stored as-is; otherwise pad n up to a multiple of 16.
dim_t n_reorder;
if( n == 1 )
{
n_reorder = 1;
}
else
{
n_reorder = make_multiple_of_n( n, 16 );
}
// Extra space since packing does length in multiples of 4.
dim_t k_reorder;
if( n == 1 )
{
k_reorder = k;
}
else
{
k_reorder = make_multiple_of_n( k, 4 );
}
// Extra memory of n_reorder * sizeof( int16_t )
// to store sum of every column of B matrix buffer
siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder
+ ( n_reorder * sizeof( int16_t ));
return size_req;
}
// Deprecated (AMD-Internal CPUPL-6412): reorders a row-major int8 B matrix
// (k x n, leading dimension ldb) into the packed layout the s8s8s16os16
// kernels consume, including per-column sums appended after the data.
// Signature (trans, mat_type, input_buf_addr, reorder_buf_addr, k, n, ldb)
// is generated by the AOCL_GEMM_REORDER macro. Returns silently on error.
AOCL_GEMM_REORDER(int8_t,s8s8s16os16)
{
if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) ||
(k <= 0) || (n <= 0) || (ldb < n))
{
return; // Error.
}
trans_t blis_trans;
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(trans, &blis_trans);
if( bli_is_trans( blis_trans ) )
{
bli_print_msg(" Transpose of matrix is not supported in "
"s8s8s16 gemm.", __FILE__, __LINE__ );
return; // Error.
}
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
return; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
if (input_mat_type == A_MATRIX)
{
return; // A reorder not supported.
}
// Special-case n == 1: gather the single column and store its sum
// (scaled by 128) directly after the k int8 elements.
if( n == 1 )
{
int16_t* pack_b_column_sum = ( int16_t* ) ( reorder_buf_addr +
( sizeof( int8_t ) * n * k ));
*pack_b_column_sum = 0;
for( dim_t k0 = 0; k0 < k; k0++ )
{
reorder_buf_addr[k0] = input_buf_addr[ k0 * ldb ];
// NOTE(review): int16 accumulator can overflow for large k
// (k * 127 > INT16_MAX when k > ~258) — confirm intended range.
*pack_b_column_sum += reorder_buf_addr[k0];
}
// Scale by 128 — presumably compensation for an s8->u8 shift applied
// to the A matrix elsewhere; TODO confirm against the kernel.
*pack_b_column_sum *= 128;
return;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 );
// Create dummy b_reorder obj.
lpgemm_obj_t b_reorder;
b_reorder.storage.aligned_buffer = reorder_buf_addr;
// Create dummy original b obj;
lpgemm_obj_t b;
b.storage.aligned_buffer = (void *)input_buf_addr;
b.rs = ldb;
b.width = n;
b.length = k;
aocl_reorderb_nr32_s8s8s16o16( &b, &b_reorder, &rntm_g, lcntx_g );
}

View File

@@ -1,200 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_utils_s8.h"
#include "lpgemm_logger.h"
// Deprecated (AMD-Internal CPUPL-6412): front-end entry point for the
// s8s8s16os8 LPGEMM (A: int8, B: int8, accum: int16, C output: int8).
// Shares the s8s8s16o16 context and decorators with the os16 variant;
// C is passed as ( int16_t* ) and the trailing S8 tag selects int8
// downscale of the output. Signature is generated by AOCL_GEMM_MATMUL.
AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
{
LPGEMM_START_LOGGER();
// Log the call parameters before any validation/early exit.
LPGEMM_WRITE_LOGGER \
(
"s8s8s16os8", \
order, transa, transb, \
m, n, k, \
( ( float ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float ) beta ), \
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
// Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform s8s8s16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK
(
"s8s8s16os8",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc,
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Perform BLAS parameter checking. */
// Transpose not supported.
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
{
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
if ( ( order != 'r' ) && ( order != 'R' ) )
{
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
goto err_hndl;
}
// Row-major strides: row stride = leading dimension, unit column stride.
inc_t rs_a = lda;
inc_t cs_a = 1;
inc_t rs_b = ldb;
inc_t cs_b = 1;
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Pack is enabled for row major storage when trans A is true.
// Pack transforms column major matrix to row-major storage as kernel
// expects A matrix to be in row-major format.
if ( bli_is_trans( blis_transa ) )
{
rs_a = 1;
cs_a = lda;
mtag_a = PACK;
}
// B matrix needs to be packed in a certain format in order to be loaded
// and used in VNNI instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if (mtag_b == UNPACKED)
{
mtag_b = PACK;
}
// Only unpacked A supported now for row-major A matrix.
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
{
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed, post_op_list,
( void* )c, ( void* )( &order ),
m, n
);
if( err != BLIS_SUCCESS )
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 );
#ifdef BLIS_ENABLE_OPENMP
lpgemm_s8s8s16o16_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( int16_t* )c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#else
lpgemm_s8s8s16o16_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( int16_t* )c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#endif
// Single exit point for all error paths; logger must always be stopped.
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -1,200 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_logger.h"
// Deprecated (AMD-Internal CPUPL-6412): front-end entry point for the
// u8s8s16os16 LPGEMM (A: uint8, B: int8, C/accum: int16).
// Signature is generated by the AOCL_GEMM_MATMUL macro; parameter names
// used below are macro-supplied. Validates arguments, maps them to LPGEMM
// framework objects, then dispatches to the OpenMP or single-threaded
// u8s8s16o16 decorator.
AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
{
LPGEMM_START_LOGGER();
// Log the call parameters before any validation/early exit.
LPGEMM_WRITE_LOGGER \
(
"u8s8s16os16", \
order, transa, transb, \
m, n, k, \
( ( float ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float ) beta ), \
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK
(
"u8s8s16os16",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc,
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Perform BLAS parameter checking. */
// Transpose not supported.
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
{
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
if ( ( order != 'r' ) && ( order != 'R' ) )
{
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
goto err_hndl;
}
// Row-major strides: row stride = leading dimension, unit column stride.
inc_t rs_a = lda;
inc_t cs_a = 1;
inc_t rs_b = ldb;
inc_t cs_b = 1;
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Pack is enabled for row major storage when trans A is true.
// Pack transforms column major matrix to row-major storage as kernel
// expects A matrix to be in row-major format.
if ( bli_is_trans( blis_transa ) )
{
rs_a = 1;
cs_a = lda;
mtag_a = PACK;
}
// B matrix needs to be packed in a certain format in order to be loaded
// and used in VNNI instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if (mtag_b == UNPACKED)
{
mtag_b = PACK;
}
// Only unpacked A supported now for row-major A matrix.
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
{
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed, post_op_list,
( void* )c, ( void* )( &order ),
m, n
);
if( err != BLIS_SUCCESS )
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );
#ifdef BLIS_ENABLE_OPENMP
lpgemm_u8s8s16o16_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S16
);
#else
lpgemm_u8s8s16o16_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S16
);
#endif
// Single exit point for all error paths; logger must always be stopped.
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -1,169 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "lpgemm_types.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_reorder_s16.h"
// Deprecated (AMD-Internal CPUPL-6412): returns the byte size required to
// hold a reordered B matrix for u8s8s16os16 GEMM, or 0 on error / for an
// A-matrix request (A reorder is unsupported). Signature (mat_type, k, n)
// is generated by the AOCL_GEMM_GET_REORDER_BUF_SIZE macro.
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16)
{
if ((k <= 0) || (n <= 0))
{
return 0; // Error.
}
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
return 0; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
if (input_mat_type == A_MATRIX)
{
return 0; // A reorder not supported.
}
// Extra space since packing does width in multiples of 16. The vpmaddubsw
// instruction can be used as long as at least one ymm register can be fully
// loaded; and since k_dim needs to be at least 2, having n_dim at least 16
// should give 2x16=32 elements, enough for 1 ymm register.The padding is
// not rounded to NR (=16), since that would result in memory wastage.
dim_t n_reorder;
if( n == 1 )
{
n_reorder = 1;
}
else
{
n_reorder = make_multiple_of_n( n, 16 );
}
// k is padded to an even count to match vpmaddubsw's pairwise accumulation.
dim_t k_reorder;
if( n == 1 )
{
k_reorder = k;
}
else
{
k_reorder = make_multiple_of_n( k, 2 );
}
siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder;
return size_req;
}
// Deprecated (AMD-Internal CPUPL-6412): reorders a row-major int8 B matrix
// (k x n, leading dimension ldb) into the packed layout the u8s8s16os16
// kernels consume. Signature (mat_type, input_buf_addr, reorder_buf_addr,
// k, n, ldb) is generated by the AOCL_GEMM_REORDER macro. Returns silently
// on error.
AOCL_GEMM_REORDER(int8_t,u8s8s16os16)
{
if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) ||
(k <= 0) || (n <= 0) || (ldb < n))
{
return; // Error.
}
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
return; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
AOCL_MATRIX_TYPE input_mat_type;
bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type);
if (input_mat_type == A_MATRIX)
{
return; // A reorder not supported.
}
// Special-case n == 1: the single column is copied as-is — contiguous
// memcpy when ldb == 1, strided gather otherwise.
if( n == 1 )
{
if (ldb == 1)
{
memcpy( reorder_buf_addr, input_buf_addr,
( k * sizeof( int8_t ) ) );
}
else
{
for( dim_t k0 = 0; k0 < k; k0++ )
{
reorder_buf_addr[k0] = input_buf_addr[k0 * ldb];
}
}
return;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );
// Create dummy b_reorder obj.
lpgemm_obj_t b_reorder;
b_reorder.storage.aligned_buffer = reorder_buf_addr;
// Create dummy original b obj;
lpgemm_obj_t b;
b.storage.aligned_buffer = (void *)input_buf_addr;
b.rs = ldb;
b.width = n;
b.length = k;
aocl_reorderb_nr32_u8s8s16o16( &b, &b_reorder, &rntm_g, lcntx_g );
}

View File

@@ -1,199 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_logger.h"
// Deprecated (AMD-Internal CPUPL-6412): front-end entry point for the
// u8s8s16os8 LPGEMM (A: uint8, B: int8, accum: int16, C output: int8).
// Shares the u8s8s16o16 context and decorators with the os16 variant;
// C is passed as ( int16_t* ) and the trailing S8 tag selects int8
// downscale of the output. Signature is generated by AOCL_GEMM_MATMUL.
AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
{
LPGEMM_START_LOGGER();
// Log the call parameters before any validation/early exit.
LPGEMM_WRITE_LOGGER \
(
"u8s8s16os8", \
order, transa, transb, \
m, n, k, \
( ( float ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float ) beta ), \
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK
(
"u8s8s16os8",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc,
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Perform BLAS parameter checking. */
// Transpose not supported.
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
{
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
if ( ( order != 'r' ) && ( order != 'R' ) )
{
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
goto err_hndl;
}
// Row-major strides: row stride = leading dimension, unit column stride.
inc_t rs_a = lda;
inc_t cs_a = 1;
inc_t rs_b = ldb;
inc_t cs_b = 1;
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Pack is enabled for row major storage when trans A is true.
// Pack transforms column major matrix to row-major storage as kernel
// expects A matrix to be in row-major format.
if ( bli_is_trans( blis_transa ) )
{
rs_a = 1;
cs_a = lda;
mtag_a = PACK;
}
// B matrix needs to be packed in a certain format in order to be loaded
// and used in VNNI instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if (mtag_b == UNPACKED)
{
mtag_b = PACK;
}
// Only unpacked A supported now for row-major A matrix.
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
{
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed, post_op_list,
( void* )c, ( void* )( &order ),
m, n
);
if( err != BLIS_SUCCESS )
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );
#ifdef BLIS_ENABLE_OPENMP
lpgemm_u8s8s16o16_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( int16_t* )c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#else
lpgemm_u8s8s16o16_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( int16_t* )c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#endif
// Single exit point for all error paths; logger must always be stopped.
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -1,199 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_logger.h"
AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
{
LPGEMM_START_LOGGER();
LPGEMM_WRITE_LOGGER \
(
"u8s8s16ou8", \
order, transa, transb, \
m, n, k, \
( ( float ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float ) beta ), \
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
// Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform u8s8s16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK
(
"u8s8s16ou8",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc,
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Perform BLAS parameter checking. */
// Transpose not supported.
if ( ( blis_transb != BLIS_NO_TRANSPOSE ) )
{
bli_print_msg(" Transpose of B matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
if ( ( order != 'r' ) && ( order != 'R' ) )
{
bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ );
goto err_hndl;
}
inc_t rs_a = lda;
inc_t cs_a = 1;
inc_t rs_b = ldb;
inc_t cs_b = 1;
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Pack is enabled for row major storage when trans A is true.
// Pack tranforms column major matrix to row-major storage as kernel
// expects A matrix to be in row-major format.
if ( bli_is_trans( blis_transa ) )
{
rs_a = 1;
cs_a = lda;
mtag_a = PACK;
}
// B matrix needs to be packed in a certain format in order to be loaded
// and used in VNNI instrution. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if (mtag_b == UNPACKED)
{
mtag_b = PACK;
}
// Only unpacked A supported now for row-major A matrix.
if ( !( bli_is_trans( blis_transa ) ) && ( mtag_a != UNPACKED ) )
{
bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ );
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed, post_op_list,
( void* )c, ( void* )( &order ),
m, n
);
if( err != BLIS_SUCCESS )
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 );
#ifdef BLIS_ENABLE_OPENMP
lpgemm_u8s8s16o16_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( int16_t* )c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, U8
);
#else
lpgemm_u8s8s16o16_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
( int16_t* )c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, U8
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -39,23 +39,19 @@
// ID = One of the AOCL_OPERATION_TYPE enum.
#define LPGEMM_BLKSZ_MAP_ZEN4 \
XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(F32F32F32OF32, 192, 8064, 512, 6, 64, 1, 6, 64, 1) \
XMACRO(BF16BF16F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(BF16S4F32OF32, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(F32OBF16, 144, 1024, 4096, 6, 64, 0, 0, 2*64, 64/2) \
#define LPGEMM_BLKSZ_MAP_ZEN \
XMACRO(U8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(F32F32F32OF32, 144, 8160, 512, 6, 16, 1, 6, 16, 1) \
XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(S8S8S16OS16, 240, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \
XMACRO(U8S4S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \
XMACRO(BF16S4F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \
XMACRO(F32OBF16, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \

View File

@@ -38,13 +38,10 @@
#include "lpgemm_blksz_map.h"
#include "lpgemm_kernels.h"
#include "lpgemm_pack_bf16.h"
#include "lpgemm_packb_s16.h"
#include "lpgemm_packa_s16.h"
#include "lpgemm_packa.h"
#include "lpgemm_packb.h"
#include "lpgemm_packa_s8.h"
#include "lpgemm_packb_s8.h"
#include "lpgemm_packb_s8s16.h"
#include "lpgemm_pack_f32.h"
#include "lpgemm_logger.h"
#include "lpgemm_thread_utils.h"

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -46,32 +46,26 @@
// AVX512 + VNNI + BF16
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \
KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_BF16_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
PBMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
@@ -85,11 +79,9 @@
UBMACRO(BF16BF16F32OF32, unpackb_nr64_bf16bf16f32of32)
#define LPGEMM_PACKSCLB_FUNC_MAP_AVX512_VNNI_BF16 \
PBSMACRO(U8S8S16OS16, NULL) \
PBSMACRO(U8S8S32OS32, NULL) \
PBSMACRO(BF16BF16F32OF32, NULL) \
PBSMACRO(S8S8S32OS32, NULL) \
PBSMACRO(S8S8S16OS16, NULL) \
PBSMACRO(U8S4S32OS32, NULL) \
PBSMACRO(BF16S4F32OF32, packsclb_nr64_bf16s4f32of32) \
@@ -105,35 +97,29 @@
// AVX512 + VNNI
#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \
KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_VNNI_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16)
#define LPGEMM_PACKBMXP_FUNC_MAP_AVX512_VNNI \
PBMXPMACRO(F32OBF16, packb_mxp_nr64_f32obf16)
#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
PBSMACRO(BF16S4F32OF32, packb_nr64_bf16s4f32of32)
@@ -148,32 +134,26 @@
// AVX512
#define LPGEMM_KERN_FUNC_MAP_AVX512 \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \
KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
KMACRO(BF16S4F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \
KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_KERN_FUNC_UPD_MAP_AVX512_TO_AVX2 \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
#define LPGEMM_PACKA_FUNC_MAP_AVX512 \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(BF16S4F32OF32, packa_mr16_bf16bf16f32of32) \
PAMACRO(S8S8S32OS32, packa_u8s8s32os32) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \
#define LPGEMM_PACKB_FUNC_MAP_AVX512 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \
PBMACRO(F32F32F32OF32, packb_nr64_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, NULL) \
PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, packb_nr64_u8s4s32o32) \
PBMACRO(BF16S4F32OF32, NULL) \
PBSMACRO(BF16S4F32OF32, NULL)
@@ -189,30 +169,24 @@
// AVX2
#define LPGEMM_KERN_FUNC_MAP_AVX2 \
KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \
KMACRO(U8S8S32OS32, NULL) \
KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \
KMACRO(BF16BF16F32OF32, NULL) \
KMACRO(BF16S4F32OF32, NULL) \
KMACRO(S8S8S32OS32, NULL) \
KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \
#define LPGEMM_PACKA_FUNC_MAP_AVX2 \
PAMACRO(U8S8S16OS16, packa_u8s8s16os16) \
PAMACRO(U8S8S32OS32, NULL) \
PAMACRO(BF16BF16F32OF32, NULL) \
KMACRO(BF16S4F32OF32, NULL) \
PAMACRO(S8S8S32OS32, NULL) \
PAMACRO(S8S8S16OS16, packa_u8s8s16os16) \
#define LPGEMM_PACKB_FUNC_MAP_AVX2 \
PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \
PBMACRO(U8S8S32OS32, NULL) \
PBMACRO(F32F32F32OF32, packb_nr16_f32f32f32of32) \
PBMACRO(BF16BF16F32OF32, NULL) \
KMACRO(BF16S4F32OF32, NULL) \
PBMACRO(S8S8S32OS32, NULL) \
PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \
PBMACRO(U8S4S32OS32, NULL) \
PBSMACRO(BF16S4F32OF32, NULL) \

View File

@@ -93,11 +93,9 @@ void lpgemm_rowvar_ ## LP_SFX \
) \
LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32);
LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16);
LPGEMM_5LOOP(float,float,float,f32f32f32of32);
LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32);
LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32);
LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16);
#define LPGEMM_5LOOP1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \

View File

@@ -1,177 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "lpgemm_utils_s8.h"
#include "lpgemm_reorder_s8s16.h"
#include "lpgemm_packb_s8s16.h"
#include "lpgemm_config.h"
// Reorders an int8_t B matrix into the NR=32 blocked layout consumed by
// the s8s8s16o16 kernels and appends a row of int16_t per-column sums at
// the end of the reordered buffer. Work is split across threads along
// the n dimension (JC loop).
//
// b         [in]  : source B object (rs/width/length describe the matrix).
// b_reorder [out] : destination object; its buffer must hold
//                   n_updated*k_updated int8 elements plus n_updated
//                   int16 column sums.
void aocl_reorderb_nr32_s8s8s16o16
(
lpgemm_obj_t* b,
lpgemm_obj_t* b_reorder,
rntm_t* rntm,
lpgemm_cntx_t* lcntx
)
{
dim_t NC = lcntx->blksz.NC;
dim_t KC = lcntx->blksz.KC;
dim_t NR = lcntx->blksz.NR;
// Extracting the matrix properties from the lpgemm object
dim_t rs_b = b->rs;
dim_t n = b->width;
dim_t k = b->length;
// Clamp NC/KC against the actual problem dimensions.
lpgemm_mod_block_size_s16(0, n, k, NULL, &NC, &KC);
dim_t rs_b_reorder;
dim_t cs_b_reorder;
dim_t k_updated = k;
// Making multiple of 2 to suit k in vpmaddubsw
k_updated += (k_updated & 0x1);
dim_t n_updated = make_multiple_of_n( n, 16 );
dim_t n_threads = bli_rntm_num_threads( rntm );
n_threads = ( n_threads > 0 ) ? n_threads : 1;
// To access the last row of B matrix - Column sum of B matrix
int16_t* pack_b_column_sum = ( int16_t* ) ( b_reorder->storage.aligned_buffer + ( sizeof( int8_t ) * n_updated * k_updated ));
// Zero the column-sum row before accumulation in the pack kernel.
for (dim_t idx = 0; idx < n_updated; idx++ )
{
*( pack_b_column_sum + idx ) = 0;
}
// NOTE: both branches below open a brace; the block is closed once after
// the #endif, so the OpenMP and serial paths share one body.
#ifdef BLIS_ENABLE_OPENMP
_Pragma( "omp parallel num_threads(n_threads)" )
{
// Initialise a local thrinfo obj for work split across threads.
thrinfo_t thread_jc;
bli_thrinfo_set_n_way( n_threads, &thread_jc );
bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc );
#else
{
// Initialise a local thrinfo obj for work split across threads.
thrinfo_t thread_jc;
bli_thrinfo_set_n_way( 1, &thread_jc );
bli_thrinfo_set_work_id( 0, &thread_jc );
#endif
// Compute the JC loop thread range for the current thread.
dim_t jc_start, jc_end;
bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );
for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
{
dim_t nc0 = bli_min( ( jc_end - jc ), NC );
dim_t jc_cur_loop = jc;
dim_t jc_cur_loop_rem = 0;
dim_t n_sub_updated;
get_B_panel_reordered_start_offset_width
(
jc, n, NC, 16,
&jc_cur_loop, &jc_cur_loop_rem,
&nc0, &n_sub_updated
);
for ( dim_t pc = 0; pc < k; pc += KC )
{
dim_t kc0 = bli_min( ( k - pc ), KC );
// kc0 needs to be a multiple of 2 so that it can be used with
// vmaddubsw instruction. Padding is added in cases this
// condition is not satisfied, and therefore the kc0 offsets
// used for packed/reordered buffers needs to be updated.
dim_t kc0_updated = make_multiple_of_n( kc0, 2 );
// The offsets are calculated in such a way that it resembles
// the reorder buffer traversal in single threaded reordering.
// The panel boundaries (KCxNC) remain as it is accessed in
// single thread, and as a consequence a thread with jc_start
// inside the panel cannot consider NC range for reorder. It
// has to work with NC' < NC, and the offset is calculated using
// prev NC panels spanning k dim + cur NC panel spanning pc loop
// cur iteration + (NC - NC') spanning current kc0 (<= KC).
//
//Eg: Consider the following reordered buffer diagram:
// t1 t2
// | |
// | |..NC..|
// | | |
// |.NC. |.NC. |NC'|NC"
// pc=0-+-----+-----+---+--+
// KC| | | | |
// | 1 | 3 | 5 |
// pc=KC-+-----+-----+---st-+
// KC| | | | |
// | 2 | 4 | 6 | 7|
// pc=k=2KC-+-----+-----+---+--+
// |jc=0 |jc=NC|jc=2NC|
//
// The numbers 1,2..6,7 denotes the order in which reordered
// KCxNC blocks are stored in memory, ie: block 1 followed by 2
// followed by 3, etc. Given two threads t1 and t2, and t2 needs
// to access point st in the reorder buffer to write the data:
// The offset calculation logic will be:
// jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC,
// n_sub_updated = NC, k = 2KC, kc0_updated = KC
//
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
// + ( n_sub_updated * pc ) <traverse block 5>
// + ( NC' * kc0_updated) <traverse block 6>
( ( packb_s16_s8 )lcntx->packb_fun_ptr )
(
( ( ( int8_t* )b_reorder->storage.aligned_buffer ) +
( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
( jc_cur_loop_rem * kc0_updated ) ),
pack_b_column_sum + jc,
( ( ( int8_t* )b->storage.aligned_buffer ) +
( rs_b * pc ) + jc ),
rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder
);
}
adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
}
}
// Changing the packed matrix properties in the packed matrix object
b_reorder->rs = rs_b_reorder;
b_reorder->cs = cs_b_reorder;
b_reorder->mtag = REORDERED;
}

View File

@@ -1,47 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef LPGEMM_REORDER_S8S16_H
#define LPGEMM_REORDER_S8S16_H
#include "lpgemm_types.h"
// Reorders int8_t matrix b into the NR=32 blocked layout used by the
// s8s8s16o16 kernels, writing into b_reorder; per-column int16 sums are
// appended after the reordered data. Threading is taken from rntm and
// block sizes / pack kernels from lcntx.
void aocl_reorderb_nr32_s8s8s16o16
(
lpgemm_obj_t* b,
lpgemm_obj_t* b_reorder,
rntm_t* rntm,
lpgemm_cntx_t* lcntx
);
#endif // LPGEMM_REORDER_S8S16_H

View File

@@ -1,609 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_packb_s8s16.h"
#include "lpgemm_kernels.h"
#include "lpgemm_utils_s8.h"
#include "lpgemm_config.h"
#include "lpgemm_thrinfo_utils.h"
#include "lpgemm_packa_s16.h"
// Kernel function prototypes
typedef void (*lpgemm_rowvar_s16_s8)
(
const dim_t,
const dim_t,
const dim_t,
const int8_t*,
const dim_t,
const dim_t,
const dim_t,
const int8_t*,
const dim_t,
const dim_t,
int16_t*,
const dim_t,
const dim_t,
const int16_t,
const int16_t,
lpgemm_post_op*,
lpgemm_post_op_attr
);
// GEMV (n == 1) specialization for s8s8s16os16, signature generated by
// the LPGEMV macro. Packs the B vector when needed (building its column
// sum), optionally packs A panels, and invokes the lpgemv_n_one kernel
// over MC-sized row blocks, with rows split across threads.
LPGEMV(int8_t,int8_t,int16_t,s8s8s16os16)
{
dim_t KC = lcntx->blksz.KC;
dim_t MC = lcntx->blksz.MC;
// Strides are updated based on matrix packing/reordering.
int8_t* a_use = ( int8_t* )a;
inc_t rs_a_use = rs_a;
inc_t cs_a_use = cs_a;
int8_t* b_use = ( int8_t* )b;
inc_t rs_b_use = rs_b;
inc_t cs_b_use = cs_b;
int16_t *c_use = NULL;
lpgemm_post_op_attr post_ops_attr;
post_ops_attr.c_stor_type = c_downscale;
// When downscaling (output narrower than s16), c doubles as the
// downscale destination buffer.
if (c_downscale < S16) post_ops_attr.buf_downscale = c;
else post_ops_attr.buf_downscale = NULL;
siz_t mem_a_size_req = 0;
siz_t mem_b_size_req = 0;
mem_t mem_a = BLIS_MEM_INITIALIZER;
mem_t mem_b = BLIS_MEM_INITIALIZER;
int8_t* pack_a_buffer;
int8_t* pack_b_buffer;
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
thrinfo_t thread_jc;
thrinfo_t thread_ic;
lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );
// Increased MR from 6 to 8 to make use of 16 ymm regs
dim_t MR = 8;
// Pack B matrix if rs_b > 1
if( ( mtag_b == PACK ) )
{
// k int8 elements for the B vector plus one trailing int16 column sum.
mem_b_size_req = sizeof( int8_t ) * k + sizeof( int16_t );
lpgemm_alloc_mem_panel
(
mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE,
&mem_b, rntm
);
pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b );
int16_t* pack_b_column_sum = ( int16_t* ) ( pack_b_buffer +
( sizeof( int8_t ) * k ));
*pack_b_column_sum = 0;
// Gather the strided B vector into a contiguous buffer, summing it.
for( dim_t k0 = 0; k0 < k; k0++ )
{
pack_b_buffer[k0] = b[ k0*rs_b ];
*pack_b_column_sum += pack_b_buffer[k0];
}
// NOTE(review): scaled by 128 — presumably compensates for biasing
// signed A to unsigned for vpmaddubsw; confirm against the kernel.
*pack_b_column_sum *= 128;
post_ops_attr.b_col_sum_vec_s16 = pack_b_column_sum;
b_use = pack_b_buffer;
rs_b_use = 1;
cs_b_use = 1;
}
else if ( mtag_b == REORDERED )
{
// Reordered B already carries its column sums right after the data.
post_ops_attr.b_col_sum_vec_s16 = ( int16_t* ) ( b + k );
}
// Compute the IC loop thread range for the current thread.
dim_t ic_start, ic_end;
thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
( thread->n_threads ) : ( thread_ic.n_way );
thread_ic.work_id = thread->tid;
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
{
dim_t mc0 = bli_min((ic_end - ic), MC);
a_use = (int8_t*)a + ic * rs_a;
c_use = c + ic * rs_c;
post_ops_attr.post_op_c_i = ic;
post_ops_attr.post_op_c_j = 0;
post_ops_attr.rs_c_downscale = rs_c;
// Pack the current mc0 x k panel of A when packing was requested
// (e.g. column-major/transposed A).
if( mtag_a == PACK )
{
mem_a_size_req = sizeof( int8_t ) * mc0 * k;
lpgemm_alloc_mem_panel
(
mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
&mem_a, rntm
);
pack_a_buffer = ( int8_t* ) bli_mem_buffer( &mem_a );
( ( packa_s16 ) lcntx->packa_fun_ptr )
(
( uint8_t* )pack_a_buffer,
( uint8_t* )( a + ( rs_a * ic )), rs_a, cs_a,
mc0, k,
&rs_a_use, &cs_a_use
);
a_use = pack_a_buffer;
}
// Call lpgemv_n_one kernel
lpgemv_n_one_s8s8s16os16
(
mc0, k,
a_use, rs_a_use, cs_a_use, mtag_a,
b_use, rs_b_use, cs_b_use, mtag_b,
c_use, rs_c, cs_c,
alpha, beta,
MR, KC,
post_op_list,
&post_ops_attr
);
}
// Release pack buffers
if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
{
bli_pba_release(rntm, &mem_a);
}
if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
{
bli_pba_release(rntm, &mem_b);
}
}
// B should always be packed.
LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16)
{
dim_t NC = lcntx->blksz.NC;
dim_t KC = lcntx->blksz.KC;
dim_t MC = lcntx->blksz.MC;
const dim_t NR = lcntx->blksz.NR;
const dim_t MR = lcntx->blksz.MR;
lpgemm_mod_block_size_s16(m, n, k, &MC, &NC, &KC);
if (mtag_b == UNPACKED)
{
// Error: can only work with packed B now.
return;
}
if( n == 1 )
{
lpgemv_rowvar_s8s8s16os16( m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha,
beta,
rntm,
thread,
lcntx,
post_op_list,
c_downscale );
return;
}
const int8_t *b_use;
const int8_t *a_use;
dim_t rs_a_use = rs_a;
dim_t cs_a_use = cs_a;
dim_t a_block_stride = 0;
dim_t rs_b_use = rs_b;
dim_t cs_b_use = cs_b;
int16_t *c_use_jc = NULL;
int16_t *c_use_ic = NULL;
dim_t rs_c_use = rs_c;
dim_t rs_c_downscale = rs_c;
// Pack buffer for A.
int8_t* pack_a_buffer_s8s8s16o16;
mem_t mem_a = BLIS_MEM_INITIALIZER;
siz_t mem_a_size_req = 0;
// Pack buffer for B.
int8_t *pack_b_buffer_s8s8s16o16;
mem_t mem_b = BLIS_MEM_INITIALIZER;
dim_t packb_min_NR = 16;
siz_t mem_b_size_req = 0;
// Temporary buffer for C accumulation when downscaling is required.
int16_t* temp_scal_c_buffer_s8s8s16o16;
mem_t mem_scale_c = BLIS_MEM_INITIALIZER;
siz_t mem_scale_c_size_req = 0;
// Making multiple of 2 to suit k in vpmaddubsw
dim_t k_updated = make_multiple_of_n( k, 2 );
// Making multiple of 16
dim_t n_updated = make_multiple_of_n( n, 16 );
// To decide whether to apply post ops or not.
bool is_last_k = FALSE;
// To decide whether to use original s8 C or temp buffer for beta scale.
bool is_first_k = FALSE;
lpgemm_post_op_attr post_ops_attr;
post_ops_attr.c_stor_type = c_downscale;
if ( c_downscale < S16 )
{
post_ops_attr.buf_downscale = c;
}
else
{
post_ops_attr.buf_downscale = NULL;
}
// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
thrinfo_t thread_jc;
thrinfo_t thread_ic;
lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic);
// Compute the JC, IC loop thread range for the current thread.
dim_t jc_start, jc_end;
bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
dim_t ic_start, ic_end;
bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);
for (dim_t jc = jc_start; jc < jc_end; jc += NC)
{
dim_t nc0 = bli_min((jc_end - jc), NC);
dim_t jc_cur_loop = jc;
dim_t jc_cur_loop_rem = 0;
dim_t n_sub_updated = 0;
if (mtag_b == REORDERED)
{
get_B_panel_reordered_start_offset_width
(
jc, n, NC, packb_min_NR,
&jc_cur_loop, &jc_cur_loop_rem,
&nc0, &n_sub_updated
);
}
if ( c_downscale == S16 )
{
c_use_jc = c + jc;
}
// Temp accumulaton buffer for C allocation.
else if ( c_downscale < S16 )
{
// Buffer memory is only required if output needs to be
// persisted across iterations of the pc/KC loop.
// It was observed that the locks used while checking out
// a buffer from memory pool had an impact on performance
// and is better to not checkout if k <= KC.
if ( k > KC )
{
mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start );
lpgemm_alloc_mem_panel
(
mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE,
&mem_scale_c, rntm
);
temp_scal_c_buffer_s8s8s16o16 = bli_mem_buffer( &mem_scale_c );
c_use_jc = ( int16_t* )temp_scal_c_buffer_s8s8s16o16;
}
// The temp c buffer stride is modified as opposed to original C matrix.
rs_c_use = nc0;
}
int16_t* pack_b_column_sum = NULL;
for (dim_t pc = 0; pc < k; pc += KC)
{
int16_t beta0 = (pc == 0) ? beta : 1;
dim_t kc0 = bli_min((k - pc), KC);
// No parallelization in k dim, k always starts at 0.
is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE );
post_ops_attr.is_first_k = is_first_k;
is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
post_ops_attr.is_last_k = is_last_k;
// kc0 needs to be a multiple of 2 so that it can be
// used with vpmaddubsw instruction. Padding is added in
// cases this condition is not satisfied, and therefore
// the kc0 offsets used for packed/reordered buffers
// needs to be updated.
dim_t kc0_updated = make_multiple_of_n(kc0, 2);
if (mtag_b == PACK)
{
// Pack B chunks are based on jc work id.
dim_t jc_work_id = bli_thread_work_id(&thread_jc);
// Using child thrinfo (thread_ic) tid to decide chief thread
// per B matrix chunk (jc work id group)
// nc0 needs to be a multiple of 16 since this gives maximum
// vectorization. Packing B always results in buffers with width
// which is a multiple of 16. Subsequently the nc0 offsets used
// for packed/reordered buffers needs to be updated.
dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR);
if (bli_thread_am_ochief(&thread_ic))
{
mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated + ( nc0_updated * sizeof( int16_t ) );
lpgemm_alloc_mem_panel(
mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
&mem_b, rntm);
thread->comm[jc_work_id].sent_object =
bli_mem_buffer(&mem_b);
}
// All threads in work group should wait till chief thread has
// finished allocating the packing buffers.
bli_thrcomm_barrier
(
bli_thread_ocomm_id(&thread_ic),
&thread->comm[jc_work_id]
);
pack_b_buffer_s8s8s16o16 =
(int8_t *)thread->comm[jc_work_id].sent_object;
// Compute the B panel per thread loop range for parallel
// packing using ic_ways number of threads. Since atmost only
// ic_ways threads can be used, the thread_ic attributes are
// used to split the loop range.
dim_t jc_packb_start, jc_packb_end;
bli_thread_range_sub
(
&thread_ic, nc0, NR, FALSE,
&jc_packb_start, &jc_packb_end
);
if ( pc == 0)
{
pack_b_column_sum = ( int16_t* )( pack_b_buffer_s8s8s16o16 + ( sizeof( int8_t ) * nc0_updated * kc0_updated ) );
}
// Ensure thread ranges are valid, especially cases where no:
// of threads available for parallelization are greater than
// no: of B panel NR chunks.
if ((jc_packb_end > jc_packb_start) &&
(jc_packb_start < (jc + nc0)))
{
if ( pc == 0 )
{
for (int idx = jc_packb_start; idx < jc_packb_end; idx++ )
{
*( pack_b_column_sum + idx ) = 0;
}
}
( ( packb_s16_s8 )lcntx->packb_fun_ptr )
(
pack_b_buffer_s8s8s16o16 +
(jc_packb_start * kc0_updated),
pack_b_column_sum + ( cs_b * jc_packb_start ),
(b + (rs_b * pc) + (cs_b * jc) +
(cs_b * jc_packb_start)),
rs_b,
(jc_packb_end - jc_packb_start), kc0,
&rs_b_use, &cs_b_use
);
}
else
{
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
}
// All threads in work group should wait till B matrix packing
// is completed by the participating threads.
bli_thrcomm_barrier
(
bli_thread_ocomm_id(&thread_ic),
&thread->comm[jc_work_id]
);
b_use = pack_b_buffer_s8s8s16o16;
post_ops_attr.b_col_sum_vec_s16 = pack_b_column_sum;
}
else if (mtag_b == REORDERED)
{
// In multi-threaded scenarios, an extra offset into a given
// packed B panel is required, since the jc loop split can
// result in per thread start offset inside the panel, instead
// of panel boundaries.
b_use = b + (jc_cur_loop * k_updated) +
(n_sub_updated * pc) +
(jc_cur_loop_rem * kc0_updated);
lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
post_ops_attr.b_col_sum_vec_s16 = ( ( int16_t* )( b + ( k_updated * n_updated ) ) ) + jc;
}
else
{
// Unpacked B not supported.
return;
}
for (dim_t ic = ic_start; ic < ic_end; ic += MC)
{
dim_t mc0 = bli_min((ic_end - ic), MC);
// Only per thread C matrix is stored in temp buffer, so both
// per thread jc and ic start should be normalized to zero.
if ( c_downscale < S16 )
{
c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) );
}
else
{
c_use_ic = c_use_jc + ( rs_c_use * ic );
}
// Matrix A packed and reordered code path is not triggerred
// currently for row-major inputs since we do not support it yet.
// Pack is enabled for column-major inputs to transform into
// row-major inputs as kernel expects row storage format.
if ( mtag_a == PACK )
{
mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated;
lpgemm_alloc_mem_panel
(
mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK,
&mem_a, rntm
);
pack_a_buffer_s8s8s16o16 = ( int8_t* )bli_mem_buffer( &mem_a );
( ( packa_s16 )lcntx->packa_fun_ptr )
(
( uint8_t* )pack_a_buffer_s8s8s16o16,
( uint8_t* )( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a,
mc0, kc0,
&rs_a_use, &cs_a_use
);
a_use = pack_a_buffer_s8s8s16o16;
if( cs_a == 1 )
{
a_block_stride = kc0_updated;
}
else
{
a_block_stride = rs_a_use;
}
}
else
{
a_use = a + ( rs_a * ic ) + ( cs_a * pc );
cs_a_use = 1;
a_block_stride = rs_a;
}
post_ops_attr.b_sum_offset = 0;
for (dim_t jr = 0; jr < nc0; jr += NR)
{
dim_t nr0 = bli_min((nc0 - jr), NR);
// Post ops meta attributes.
post_ops_attr.post_op_c_i = ic;
post_ops_attr.post_op_c_j = ( jc + jr );
post_ops_attr.rs_c_downscale = rs_c_downscale;
// Calls for reorder B
( ( lpgemm_rowvar_s16_s8 )lcntx->kern_fun_ptr )
(
mc0, nr0, kc0,
a_use, rs_a_use, cs_a_use, a_block_stride,
(b_use + (jr * kc0_updated)), rs_b_use, cs_b_use,
(c_use_ic + jr), rs_c_use, 1,
alpha, beta0,
post_op_list, post_ops_attr
);
post_ops_attr.b_sum_offset += NR;
}
}
}
if (mtag_b == REORDERED)
{
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
}
}
// Release pack buffers.
if (mtag_b == PACK)
{
// All threads in work group should wait till B matrix usage is
// completed by the participating threads.
bli_thrcomm_barrier(
bli_thread_ocomm_id(&thread_jc),
&thread->comm[bli_thread_work_id(&thread_jc)]);
if (bli_thread_am_ochief(&thread_ic))
{
if (bli_mem_is_alloc(&mem_b))
{
bli_pba_release(rntm, &mem_b);
}
}
}
if ( c_downscale < S16 )
{
if ( bli_mem_is_alloc( &mem_scale_c ) )
{
bli_pba_release( rntm, &mem_scale_c );
}
}
}

View File

@@ -293,105 +293,6 @@ BLIS_INLINE void lpgemm_adjust_ic_jc_ways
}
}
// Derive the thread factorization for s16-accumulating LPGEMM variants:
// splits the available threads across the ic (m dim) and jc (n dim) loops.
// Explicit BLIS_IC_NT/BLIS_JC_NT settings take precedence; otherwise the
// factorization is chosen from the problem shape. On exit *n_threads,
// *ic_ways and *jc_ways are always mutually consistent and >= 1.
BLIS_INLINE void lpgemm_s16o16_get_threading
     (
       dim_t*              n_threads,
       dim_t*              ic_ways,
       dim_t*              jc_ways,
       dim_t               m,
       dim_t               n,
       dim_t               k,
       rntm_t*             rntm_g,
       AOCL_OPERATION_TYPE op_type
     )
{
	*n_threads = bli_rntm_num_threads( rntm_g );
	*jc_ways   = bli_rntm_jc_ways( rntm_g );
	*ic_ways   = bli_rntm_ic_ways( rntm_g );

	if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) )
	{
		// At least one of BLIS_IC_NT / BLIS_JC_NT was set explicitly.
		// Any unset factor defaults to 1; the total thread count is
		// their product.
		if ( ( *ic_ways ) <= 0 ) { *ic_ways = 1; }
		if ( ( *jc_ways ) <= 0 ) { *jc_ways = 1; }
		*n_threads = ( *jc_ways ) * ( *ic_ways );
	}
	else if ( ( *n_threads ) > 1 )
	{
		dim_t blk_NR = lpgemm_get_block_size_NR_global_cntx( op_type );
		dim_t blk_MR = lpgemm_get_block_size_MR_global_cntx( op_type );

		if ( n <= blk_NR )
		{
			// n fits in a single NR tile: only the m dim is worth
			// parallelizing.
			( *ic_ways )   = ( *n_threads );
			( *jc_ways )   = 1;
			( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
		}
		else if ( m <= blk_MR )
		{
			// m fits in a single MR tile: only the n dim is worth
			// parallelizing.
			( *jc_ways )   = ( *n_threads );
			( *ic_ways )   = 1;
			( *n_threads ) = ( *ic_ways ) * ( *jc_ways );
		}
		else
		{
			// General case: factorize BLIS_NUM_THREADS across the
			// m x n iteration space.
			bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways );
		}
	}
	else
	{
		// Single threaded (or invalid settings): force valid threading
		// parameters.
		*n_threads = 1;
		*jc_ways   = 1;
		*ic_ways   = 1;
	}
}
// u8s8s16o16-specific wrapper: delegates to the common s16 threading
// heuristic, supplying the U8S8S16OS16 operation type for block sizes.
BLIS_INLINE void lpgemm_u8s8s16o16_get_threading
     (
       dim_t*  n_threads,
       dim_t*  ic_ways,
       dim_t*  jc_ways,
       dim_t   m,
       dim_t   n,
       dim_t   k,
       rntm_t* rntm_g
     )
{
	lpgemm_s16o16_get_threading( n_threads, ic_ways, jc_ways,
	                             m, n, k, rntm_g, U8S8S16OS16 );
}
// s8s8s16o16-specific wrapper: delegates to the common s16 threading
// heuristic, supplying the S8S8S16OS16 operation type for block sizes.
BLIS_INLINE void lpgemm_s8s8s16o16_get_threading
     (
       dim_t*  n_threads,
       dim_t*  ic_ways,
       dim_t*  jc_ways,
       dim_t   m,
       dim_t   n,
       dim_t   k,
       rntm_t* rntm_g
     )
{
	lpgemm_s16o16_get_threading( n_threads, ic_ways, jc_ways,
	                             m, n, k, rntm_g, S8S8S16OS16 );
}
BLIS_INLINE void lpgemm_s32o32_get_threading
(
dim_t* n_threads,
@@ -1162,12 +1063,10 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
} \
} \
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
@@ -1807,12 +1706,10 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
); \
} \
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32)
GEN_LPGEMM_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
GEN_LPGEMM_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16)
#define GEN_LPGEMM_DECORATOR1(A_type,B_type,C_type,LPGEMM_SFX) \
void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \

View File

@@ -66,12 +66,10 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
AOCL_STORAGE_TYPE c_downscale \
); \
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32)
GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \
@@ -211,12 +209,10 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
AOCL_STORAGE_TYPE c_downscale \
); \
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16)
GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32)
GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16)
#define GEN_BATCH_LPGEMM_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \

View File

@@ -1,167 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "lpgemm_utils.h"
#include "lpgemm_reorder_s16.h"
#include "lpgemm_packb_s16.h"
#include "lpgemm_config.h"
// Reorders the B matrix of a u8s8s16o16 GEMM into the packed KCxNC panel
// layout expected by the s16 kernels, parallelized over the n (jc) dim when
// OpenMP is enabled. On exit the reorder object's strides and mtag are
// updated to reflect the packed storage.
void aocl_reorderb_nr32_u8s8s16o16
     (
       lpgemm_obj_t*  b,          // Source B matrix object (unpacked).
       lpgemm_obj_t*  b_reorder,  // Destination object; receives packed data.
       rntm_t*        rntm,       // Runtime object (thread count source).
       lpgemm_cntx_t* lcntx       // Context: block sizes and pack kernel ptr.
     )
{
	dim_t NC = lcntx->blksz.NC;
	dim_t KC = lcntx->blksz.KC;
	dim_t NR = lcntx->blksz.NR;

	// Extracting the matrix properties from the lpgemm object
	dim_t rs_b = b->rs;
	dim_t n = b->width;
	dim_t k = b->length;

	// Adjust NC/KC for the given problem size (m is irrelevant here).
	lpgemm_mod_block_size_s16(0, n, k, NULL, &NC, &KC);

	// Set by the pack kernel; published to b_reorder at the end.
	dim_t rs_b_reorder;
	dim_t cs_b_reorder;

	dim_t k_updated = k;
	// Making multiple of 2 to suit k in vpmaddubsw
	k_updated += (k_updated & 0x1);

	dim_t n_threads = bli_rntm_num_threads( rntm );
	n_threads = ( n_threads > 0 ) ? n_threads : 1;

#ifdef BLIS_ENABLE_OPENMP
	_Pragma( "omp parallel num_threads(n_threads)" )
	{
		// Initialise a local thrinfo obj for work split across threads.
		thrinfo_t thread_jc;
		bli_thrinfo_set_n_way( n_threads, &thread_jc );
		bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc );
#else
	{
		// Initialise a local thrinfo obj for work split across threads.
		thrinfo_t thread_jc;
		bli_thrinfo_set_n_way( 1, &thread_jc );
		bli_thrinfo_set_work_id( 0, &thread_jc );
#endif
		// Compute the JC loop thread range for the current thread.
		dim_t jc_start, jc_end;
		bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end );

		for ( dim_t jc = jc_start; jc < jc_end; jc += NC )
		{
			dim_t nc0 = bli_min( ( jc_end - jc ), NC );

			dim_t jc_cur_loop = jc;
			dim_t jc_cur_loop_rem = 0;
			dim_t n_sub_updated;

			// Align this thread's jc range to the KCxNC panel grid of the
			// reorder buffer (panel width granularity is 16 here).
			get_B_panel_reordered_start_offset_width
			(
			  jc, n, NC, 16,
			  &jc_cur_loop, &jc_cur_loop_rem,
			  &nc0, &n_sub_updated
			);

			for ( dim_t pc = 0; pc < k; pc += KC )
			{
				dim_t kc0 = bli_min( ( k - pc ), KC );

				// kc0 needs to be a multiple of 2 so that it can be used with
				// vmaddubsw instruction. Padding is added in cases this
				// condition is not satisfied, and therefore the kc0 offsets
				// used for packed/reordered buffers needs to be updated.
				dim_t kc0_updated = make_multiple_of_n( kc0, 2 );

				// The offsets are calculated in such a way that it resembles
				// the reorder buffer traversal in single threaded reordering.
				// The panel boundaries (KCxNC) remain as it is accessed in
				// single thread, and as a consequence a thread with jc_start
				// inside the panel cannot consider NC range for reorder. It
				// has to work with NC' < NC, and the offset is calculated using
				// prev NC panels spanning k dim + cur NC panel spanning pc loop
				// cur iteration + (NC - NC') spanning current kc0 (<= KC).
				//
				//Eg: Consider the following reordered buffer diagram:
				//          t1              t2
				//          |               |
				//          |           |..NC..|
				//          |           |      |
				//          |.NC. |.NC. |NC'|NC"
				//     pc=0-+-----+-----+---+--+
				//        KC|     |     |   |  |
				//          | 1   | 3   |   5  |
				//    pc=KC-+-----+-----+---st-+
				//        KC|     |     |   |  |
				//          | 2   | 4   | 6 | 7|
				// pc=k=2KC-+-----+-----+---+--+
				//          |jc=0 |jc=NC|jc=2NC|
				//
				// The numbers 1,2..6,7 denotes the order in which reordered
				// KCxNC blocks are stored in memory, ie: block 1 followed by 2
				// followed by 3, etc. Given two threads t1 and t2, and t2 needs
				// to access point st in the reorder buffer to write the data:
				// The offset calculation logic will be:
				// jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC,
				// n_sub_updated = NC, k = 2KC, kc0_updated = KC
				//
				// st = ( jc_cur_loop * k )    <traverse blocks 1,2,3,4>
				//    + ( n_sub_updated * pc ) <traverse block 5>
				//    + ( NC' * kc0_updated)   <traverse block 6>
				( ( packb_s16 )lcntx->packb_fun_ptr )
				(
				  ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) +
				    ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
				    ( jc_cur_loop_rem * kc0_updated ) ),
				  ( ( ( int8_t* )b->storage.aligned_buffer ) +
				    ( rs_b * pc ) + jc ),
				  rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder
				);
			}
			// Step jc to the next panel boundary.
			adjust_B_panel_reordered_jc( &jc, jc_cur_loop );
		}
	}
	// Changing the packed matrix properties in the packed matrix object
	// NOTE(review): rs_b_reorder/cs_b_reorder are only written by the pack
	// kernel inside the loops — presumably n > 0 and k > 0 are guaranteed
	// by the caller; verify, else these reads are of uninitialized values.
	b_reorder->rs = rs_b_reorder;
	b_reorder->cs = cs_b_reorder;
	b_reorder->mtag = REORDERED;
}

View File

@@ -1,47 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef LPGEMM_REORDER_S16_H
#define LPGEMM_REORDER_S16_H
#include "lpgemm_types.h"
void aocl_reorderb_nr32_u8s8s16o16
(
lpgemm_obj_t* b,
lpgemm_obj_t* b_reorder,
rntm_t* rntm,
lpgemm_cntx_t* lcntx
);
#endif // LPGEMM_REORDER_S16_H

View File

@@ -1,573 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_packb_s16.h"
#include "lpgemm_packa_s16.h"
#include "lpgemm_kernels.h"
#include "lpgemm_utils.h"
#include "lpgemm_config.h"
#include "lpgemm_thrinfo_utils.h"
// Kernel function prototypes

// Signature of the u8s8s16o16 main GEMM micro-kernel dispatched through
// lcntx->kern_fun_ptr. Parameter roles follow the call site in the jr loop
// below (mc0, nr0, kc0, a panel + strides + block stride, b panel + strides,
// c panel + strides, alpha, beta, post-op chain, post-op attributes).
typedef void (*lpgemm_rowvar_s16)
     (
       const dim_t,         // mc0: rows of the C block.
       const dim_t,         // nr0: columns of the C block.
       const dim_t,         // kc0: accumulation (k) extent.
       const uint8_t*,      // a: A panel (u8).
       const dim_t,         // rs_a.
       const dim_t,         // cs_a.
       const dim_t,         // a_block_stride.
       const int8_t*,       // b: packed/reordered B panel (s8).
       const dim_t,         // rs_b.
       const dim_t,         // cs_b.
       int16_t*,            // c: output/accumulation buffer (s16).
       const dim_t,         // rs_c.
       const dim_t,         // cs_c.
       const int16_t,       // alpha.
       const int16_t,       // beta.
       lpgemm_post_op*,     // post-op chain.
       lpgemm_post_op_attr  // post-op metadata.
     );
// GEMV (n == 1) specialization of the u8s8s16os16 LPGEMM: packs B into a
// contiguous vector when strided, optionally packs A per MC chunk, then
// dispatches to the n==1 kernel. Parallelized across the m dim only.
LPGEMV(uint8_t,int8_t,int16_t,u8s8s16os16)
{
	dim_t KC = lcntx->blksz.KC;
	dim_t MC = lcntx->blksz.MC;

	// Strides are updated based on matrix packing/reordering.
	uint8_t* a_use = ( uint8_t* )a;
	inc_t rs_a_use = rs_a;
	inc_t cs_a_use = cs_a;

	int8_t* b_use = ( int8_t* )b;
	inc_t rs_b_use = rs_b;
	inc_t cs_b_use = cs_b;

	int16_t *c_use = NULL;

	lpgemm_post_op_attr post_ops_attr;
	post_ops_attr.c_stor_type = c_downscale;
	// When the output is narrower than s16, c is the downscale target buffer.
	if (c_downscale < S16) post_ops_attr.buf_downscale = c;
	else post_ops_attr.buf_downscale = NULL;

	siz_t mem_a_size_req = 0;
	siz_t mem_b_size_req = 0;
	mem_t mem_a = BLIS_MEM_INITIALIZER;
	mem_t mem_b = BLIS_MEM_INITIALIZER;
	uint8_t* pack_a_buffer;
	int8_t* pack_b_buffer;

	// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
	thrinfo_t thread_jc;
	thrinfo_t thread_ic;
	lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic );

	// Increased MR from 6 to 8 to make use of 16 ymm regs
	dim_t MR = 8;

	// Pack B matrix if rs_b > 1 (gathers the strided column vector into a
	// unit-stride buffer).
	if( ( mtag_b == PACK ) && ( rs_b != 1 ) )
	{
		mem_b_size_req = sizeof( int8_t ) * k;

		lpgemm_alloc_mem_panel
		(
		  mem_b_size_req, BLIS_BUFFER_FOR_GEN_USE,
		  &mem_b, rntm
		);

		pack_b_buffer = ( int8_t* ) bli_mem_buffer( &mem_b );

		for( dim_t k0 = 0; k0 < k; k0++ )
		{
			pack_b_buffer[k0] = b[ k0*rs_b ];
		}

		b_use = pack_b_buffer;
		rs_b_use = 1;
		cs_b_use = 1;
	}

	// Compute the IC loop thread range for the current thread.
	// All threads are folded into the ic dim since n == 1.
	dim_t ic_start, ic_end;
	thread_ic.n_way = ( thread_ic.n_way == 1 ) ?
	                  ( thread->n_threads ) : ( thread_ic.n_way );
	thread_ic.work_id = thread->tid;
	bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);

	for (dim_t ic = ic_start; ic < ic_end; ic += MC)
	{
		dim_t mc0 = bli_min((ic_end - ic), MC);
		a_use = (uint8_t*)a + ic * rs_a;
		c_use = c + ic * rs_c;

		post_ops_attr.post_op_c_i = ic;
		post_ops_attr.post_op_c_j = 0;
		post_ops_attr.rs_c_downscale = rs_c;

		// Pack the A chunk (used to convert column-major input to the
		// row-major layout the kernel expects).
		if( mtag_a == PACK )
		{
			mem_a_size_req = sizeof( uint8_t ) * mc0 * k;

			lpgemm_alloc_mem_panel
			(
			  mem_a_size_req, BLIS_BUFFER_FOR_GEN_USE,
			  &mem_a, rntm
			);

			pack_a_buffer = ( uint8_t* ) bli_mem_buffer( &mem_a );

			( ( packa_s16 ) lcntx->packa_fun_ptr )
			(
			  pack_a_buffer,
			  ( a + ( rs_a * ic )), rs_a, cs_a,
			  mc0, k,
			  &rs_a_use, &cs_a_use
			);
			a_use = pack_a_buffer;
		}

		// Call lpgemv_n_one kernel
		lpgemv_n_one_u8s8s16os16
		(
		  mc0, k,
		  a_use, rs_a_use, cs_a_use, mtag_a,
		  b_use, rs_b_use, cs_b_use, mtag_b,
		  c_use, rs_c, cs_c,
		  alpha, beta,
		  MR, KC,
		  post_op_list,
		  &post_ops_attr
		);
	}

	// Release pack buffers
	if( mtag_a == PACK && bli_mem_is_alloc( &mem_a ) )
	{
		bli_pba_release(rntm, &mem_a);
	}
	if( mtag_b == PACK && bli_mem_is_alloc( &mem_b ) )
	{
		bli_pba_release(rntm, &mem_b);
	}
}
// 5-loop blocked matmul for u8 (A) x s8 (B) -> s16 (C) accumulation.
// B must arrive packed or reordered; A may be packed (column-major inputs)
// or used in place. Work is split across threads in the jc (n) and ic (m)
// dims; B packing is cooperative within each jc work group, synchronized
// with thrcomm barriers.
// B should always be packed.
LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16)
{
	dim_t NC = lcntx->blksz.NC;
	dim_t KC = lcntx->blksz.KC;
	dim_t MC = lcntx->blksz.MC;
	const dim_t NR = lcntx->blksz.NR;
	const dim_t MR = lcntx->blksz.MR;

	// Adjust cache block sizes to the problem dimensions.
	lpgemm_mod_block_size_s16(m, n, k, &MC, &NC, &KC);

	if (mtag_b == UNPACKED)
	{
		// Error: can only work with packed B now.
		return;
	}

	// GEMV fast path.
	if( n == 1 )
	{
		lpgemv_rowvar_u8s8s16os16( m, n, k,
		                           a, rs_a, cs_a, mtag_a,
		                           b, rs_b, cs_b, mtag_b,
		                           c, rs_c, cs_c,
		                           alpha,
		                           beta,
		                           rntm,
		                           thread,
		                           lcntx,
		                           post_op_list,
		                           c_downscale );
		return;
	}

	const int8_t *b_use;
	const uint8_t *a_use;
	dim_t rs_a_use = rs_a;
	dim_t cs_a_use = cs_a;
	dim_t a_block_stride = 0;

	dim_t rs_b_use = rs_b;
	dim_t cs_b_use = cs_b;

	int16_t *c_use_jc = NULL;
	int16_t *c_use_ic = NULL;
	dim_t rs_c_use = rs_c;
	dim_t rs_c_downscale = rs_c;

	// Pack buffer for A.
	uint8_t* pack_a_buffer_u8s8s16o16;
	mem_t mem_a = BLIS_MEM_INITIALIZER;
	siz_t mem_a_size_req = 0;

	// Pack buffer for B.
	int8_t *pack_b_buffer_u8s8s16o16;
	mem_t mem_b = BLIS_MEM_INITIALIZER;
	dim_t packb_min_NR = 16;
	siz_t mem_b_size_req = 0;

	// Temporary buffer for C accumulation when downscaling is required.
	int16_t* temp_scal_c_buffer_u8s8s16o16;
	mem_t mem_scale_c = BLIS_MEM_INITIALIZER;
	siz_t mem_scale_c_size_req = 0;

	// Making multiple of 2 to suit k in vpmaddubsw
	dim_t k_updated = make_multiple_of_n( k, 2 );

	// To decide whether to apply post ops or not.
	bool is_last_k = FALSE;

	// To decide whether to use original s8 C or temp buffer for beta scale.
	bool is_first_k = FALSE;

	lpgemm_post_op_attr post_ops_attr;
	post_ops_attr.c_stor_type = c_downscale;
	if ( c_downscale < S16 )
	{
		post_ops_attr.buf_downscale = c;
	}
	else
	{
		post_ops_attr.buf_downscale = NULL;
	}

	// Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t.
	thrinfo_t thread_jc;
	thrinfo_t thread_ic;
	lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic);

	// Compute the JC, IC loop thread range for the current thread.
	dim_t jc_start, jc_end;
	bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end);
	dim_t ic_start, ic_end;
	bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end);

	for (dim_t jc = jc_start; jc < jc_end; jc += NC)
	{
		dim_t nc0 = bli_min((jc_end - jc), NC);

		dim_t jc_cur_loop = jc;
		dim_t jc_cur_loop_rem = 0;
		dim_t n_sub_updated = 0;

		if (mtag_b == REORDERED)
		{
			// Align this thread's jc range to the KCxNC panel grid of
			// the reordered B buffer.
			get_B_panel_reordered_start_offset_width
			(
			  jc, n, NC, packb_min_NR,
			  &jc_cur_loop, &jc_cur_loop_rem,
			  &nc0, &n_sub_updated
			);
		}

		if ( c_downscale == S16 )
		{
			c_use_jc = c + jc;
		}
		// Temp accumulation buffer for C allocation.
		else if ( c_downscale < S16 )
		{
			// Buffer memory is only required if output needs to be
			// persisted across iterations of the pc/KC loop.
			// It was observed that the locks used while checking out
			// a buffer from memory pool had an impact on performance
			// and is better to not checkout if k <= KC.
			if ( k > KC )
			{
				mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start );

				lpgemm_alloc_mem_panel
				(
				  mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE,
				  &mem_scale_c, rntm
				);

				temp_scal_c_buffer_u8s8s16o16 = bli_mem_buffer( &mem_scale_c );

				c_use_jc = ( int16_t* )temp_scal_c_buffer_u8s8s16o16;
			}

			// The temp c buffer stride is modified as opposed to original C matrix.
			rs_c_use = nc0;
		}

		for (dim_t pc = 0; pc < k; pc += KC)
		{
			// Beta only applies on the first k iteration; later KC chunks
			// accumulate into C.
			int16_t beta0 = (pc == 0) ? beta : 1;
			dim_t kc0 = bli_min((k - pc), KC);

			// No parallelization in k dim, k always starts at 0.
			is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE );
			post_ops_attr.is_first_k = is_first_k;

			is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE );
			post_ops_attr.is_last_k = is_last_k;

			// kc0 needs to be a multiple of 2 so that it can be
			// used with vpmaddubsw instruction. Padding is added in
			// cases this condition is not satisfied, and therefore
			// the kc0 offsets used for packed/reordered buffers
			// needs to be updated.
			dim_t kc0_updated = make_multiple_of_n(kc0, 2);

			if (mtag_b == PACK)
			{
				// Pack B chunks are based on jc work id.
				dim_t jc_work_id = bli_thread_work_id(&thread_jc);

				// Using child thrinfo (thread_ic) tid to decide chief thread
				// per B matrix chunk (jc work id group)
				if (bli_thread_am_ochief(&thread_ic))
				{
					// nc0 needs to be a multiple of 16 since this gives maximum
					// vectorization. Packing B always results in buffers with width
					// which is a multiple of 16. Subsequently the nc0 offsets used
					// for packed/reordered buffers needs to be updated.
					dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR);
					mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated;

					lpgemm_alloc_mem_panel(
					    mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL,
					    &mem_b, rntm);

					thread->comm[jc_work_id].sent_object =
					    bli_mem_buffer(&mem_b);
				}

				// All threads in work group should wait till chief thread has
				// finished allocating the packing buffers.
				bli_thrcomm_barrier
				(
				  bli_thread_ocomm_id(&thread_ic),
				  &thread->comm[jc_work_id]
				);

				pack_b_buffer_u8s8s16o16 =
				    (int8_t *)thread->comm[jc_work_id].sent_object;

				// Compute the B panel per thread loop range for parallel
				// packing using ic_ways number of threads. Since at most only
				// ic_ways threads can be used, the thread_ic attributes are
				// used to split the loop range.
				dim_t jc_packb_start, jc_packb_end;
				bli_thread_range_sub
				(
				  &thread_ic, nc0, NR, FALSE,
				  &jc_packb_start, &jc_packb_end
				);

				// Ensure thread ranges are valid, especially cases where no:
				// of threads available for parallelization are greater than
				// no: of B panel NR chunks.
				if ((jc_packb_end > jc_packb_start) &&
				    (jc_packb_start < (jc + nc0)))
				{
					( ( packb_s16 )lcntx->packb_fun_ptr )
					(
					  pack_b_buffer_u8s8s16o16 +
					      (jc_packb_start * kc0_updated),
					  (b + (rs_b * pc) + (cs_b * jc) +
					      (cs_b * jc_packb_start)),
					  rs_b,
					  (jc_packb_end - jc_packb_start), kc0,
					  &rs_b_use, &cs_b_use
					);
				}
				else
				{
					lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
				}

				// All threads in work group should wait till B matrix packing
				// is completed by the participating threads.
				bli_thrcomm_barrier
				(
				  bli_thread_ocomm_id(&thread_ic),
				  &thread->comm[jc_work_id]
				);
				b_use = pack_b_buffer_u8s8s16o16;
			}
			else if (mtag_b == REORDERED)
			{
				// In multi-threaded scenarios, an extra offset into a given
				// packed B panel is required, since the jc loop split can
				// result in per thread start offset inside the panel, instead
				// of panel boundaries.
				b_use = b + (jc_cur_loop * k_updated) +
				        (n_sub_updated * pc) +
				        (jc_cur_loop_rem * kc0_updated);

				lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use );
			}
			else
			{
				// Unpacked B not supported.
				return;
			}

			for (dim_t ic = ic_start; ic < ic_end; ic += MC)
			{
				dim_t mc0 = bli_min((ic_end - ic), MC);

				// Only per thread C matrix is stored in temp buffer, so both
				// per thread jc and ic start should be normalized to zero.
				if ( c_downscale < S16 )
				{
					c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) );
				}
				else
				{
					c_use_ic = c_use_jc + ( rs_c_use * ic );
				}

				// Matrix A packed and reordered code path is not triggered
				// currently for row-major inputs since we do not support it yet.
				// Pack is enabled for column-major inputs to transform into
				// row-major inputs as kernel expects row storage format.
				if ( mtag_a == PACK )
				{
					mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated;

					lpgemm_alloc_mem_panel
					(
					  mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK,
					  &mem_a, rntm
					);

					pack_a_buffer_u8s8s16o16 = ( uint8_t* )bli_mem_buffer( &mem_a );

					( ( packa_s16 )lcntx->packa_fun_ptr )
					(
					  pack_a_buffer_u8s8s16o16,
					  ( a + ( rs_a * ic ) + ( cs_a * pc ) ), rs_a, cs_a,
					  mc0, kc0,
					  &rs_a_use, &cs_a_use
					);
					a_use = pack_a_buffer_u8s8s16o16;

					if( cs_a == 1 )
					{
						a_block_stride = kc0_updated;
					}
					else
					{
						a_block_stride = rs_a_use;
					}
				}
				else if ( mtag_a == REORDERED )
				{
					lpgemm_get_packa_strides( lcntx, &rs_a_use, &cs_a_use );
					a_use = a + ( pc * m ) + ( kc0_updated * ic );
					a_block_stride = kc0_updated;
				}
				else
				{
					a_use = a + ( rs_a * ic ) + ( cs_a * pc );
					cs_a_use = 1;
					a_block_stride = rs_a;
				}

				for (dim_t jr = 0; jr < nc0; jr += NR)
				{
					dim_t nr0 = bli_min((nc0 - jr), NR);

					// Post ops meta attributes.
					post_ops_attr.post_op_c_i = ic;
					post_ops_attr.post_op_c_j = ( jc + jr );
					post_ops_attr.rs_c_downscale = rs_c_downscale;

					// Calls for reorder B
					( ( lpgemm_rowvar_s16 )lcntx->kern_fun_ptr )
					(
					  mc0, nr0, kc0,
					  a_use, rs_a_use, cs_a_use, a_block_stride,
					  (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use,
					  (c_use_ic + jr), rs_c_use, 1,
					  alpha, beta0,
					  post_op_list, post_ops_attr
					);
				}
			}
		}
		if (mtag_b == REORDERED)
		{
			adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
		}
	}

	// Release pack buffers.
	if (mtag_b == PACK)
	{
		// All threads in work group should wait till B matrix usage is
		// completed by the participating threads.
		bli_thrcomm_barrier(
		    bli_thread_ocomm_id(&thread_jc),
		    &thread->comm[bli_thread_work_id(&thread_jc)]);

		if (bli_thread_am_ochief(&thread_ic))
		{
			if (bli_mem_is_alloc(&mem_b))
			{
				bli_pba_release(rntm, &mem_b);
			}
		}
	}
	if ( c_downscale < S16 )
	{
		if ( bli_mem_is_alloc( &mem_scale_c ) )
		{
			bli_pba_release( rntm, &mem_scale_c );
		}
	}
}

View File

@@ -87,12 +87,10 @@ void lpgemm_rowvar_ ## LP_SFX \
) \
LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64);
LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32);
LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m);
LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m);
LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64);
LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32);
#define LPGEMM_MAIN_KERN1(A_type,B_type,C_type,LP_SFX) \
@@ -144,10 +142,6 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32);
LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64);
LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64);
@@ -201,10 +195,6 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32);
LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32);
#define LPGEMM_M_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
@@ -258,8 +248,6 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32);
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32);
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48);
LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16);
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16);
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32);
LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48);
@@ -275,8 +263,6 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16);
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32);
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48);
LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16);
#define LPGEMM_N_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
@@ -329,14 +315,10 @@ void lpgemm_rowvar_ ## LP_SFX \
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16);
LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16);
#define LPGEMM_N_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -396,10 +378,6 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16);
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16);
LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16);
@@ -432,10 +410,6 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48);
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48);
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48);
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16);
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16);
LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16);
#define LPGEMM_MN_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -496,10 +470,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16);
@@ -512,10 +482,6 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16);
LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16);
#define LPGEMM_MN_LT_NR0_FRINGE_KERN1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \
@@ -600,8 +566,6 @@ void lpgemv_n_one_ ## LP_SFX \
LPGEMV_N_EQ1_KERN(float, float, float,f32f32f32of32);
LPGEMV_N_EQ1_KERN(bfloat16, bfloat16, float,bf16bf16f32of32);
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int32_t,u8s8s32os32);
LPGEMV_N_EQ1_KERN(uint8_t,int8_t,int16_t,u8s8s16os16);
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int32_t,s8s8s32os32);
LPGEMV_N_EQ1_KERN(int8_t,int8_t,int16_t,s8s8s16os16);
#endif //BLIS_LPGEMM_KERN_H

View File

@@ -1,63 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef BLIS_GEMM_S8_INT16_PACKB
#define BLIS_GEMM_S8_INT16_PACKB

// Function pointer type for the s8s8s16 B-matrix packing kernels.
// The arguments mirror packb_nr32_s8s8s16o16 below.
typedef void (*packb_s16_s8)
     (
       int8_t*,       // packed B output buffer
       int16_t*,      // per-column sum accumulator (int16)
       const int8_t*, // source B matrix
       const dim_t,   // leading dimension of source B
       const dim_t,   // number of columns (N) to pack
       const dim_t,   // number of rows (K) to pack
       dim_t*,        // out: row stride of the packed buffer
       dim_t*         // out: column stride of the packed buffer
     );

// Packs a rows x cols panel of the int8 B matrix into the layout used by
// the s8s8s16 compute kernels, interleaving pairs of K rows, and
// accumulates per-column sums of B into pack_b_column_sum. On return,
// *rs_b and *cs_b describe the strides of the packed buffer.
void packb_nr32_s8s8s16o16
     (
       int8_t *pack_b_buffer_s8s8s16o16,
       int16_t *pack_b_column_sum,
       const int8_t *b,
       const dim_t ldb,
       const dim_t cols,
       const dim_t rows,
       dim_t *rs_b,
       dim_t *cs_b
     );

#endif // BLIS_GEMM_S8_INT16_PACKB

View File

@@ -1,62 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef BLIS_GEMM_INT8_U8S8S16_PACKA
#define BLIS_GEMM_INT8_U8S8S16_PACKA

// Function pointer type for the u8s8s16 A-matrix packing kernels.
// The arguments mirror packa_u8s8s16os16 below.
typedef void (*packa_s16)
     (
       uint8_t*,       // packed A output buffer
       const uint8_t*, // source A matrix
       const dim_t,    // row stride of source A
       const dim_t,    // column stride of source A
       const dim_t,    // MC: number of rows to pack
       const dim_t,    // KC: number of columns (K) to pack
       dim_t*,         // out: row stride of the packed buffer
       dim_t*          // out: column stride of the packed buffer
     );

// Packs an MC x KC block of the uint8 A matrix for the u8s8s16 compute
// kernels. On return, *rs_a and *cs_a describe the packed-buffer strides.
// NOTE(review): packing layout details live in the definition of
// packa_u8s8s16os16 (not shown here) — confirm against it if relied upon.
void packa_u8s8s16os16
     (
       uint8_t* pack_a_buffer_u8s8s16o16,
       const uint8_t* a,
       const dim_t rs,
       const dim_t cs,
       const dim_t MC,
       const dim_t KC,
       dim_t* rs_a,
       dim_t* cs_a
     );

#endif //BLIS_GEMM_INT8_U8S8S16_PACKA

View File

@@ -1,60 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef BLIS_GEMM_INT16_PACKB
#define BLIS_GEMM_INT16_PACKB

// Function pointer type for the u8s8s16 B-matrix packing kernels.
// The arguments mirror packb_nr32_u8s8s16o16 below. Unlike the s8s8s16
// variant, no column-sum accumulator is carried.
typedef void (*packb_s16)
     (
       int8_t*,       // packed B output buffer
       const int8_t*, // source B matrix
       const dim_t,   // leading dimension of source B
       const dim_t,   // number of columns (N) to pack
       const dim_t,   // number of rows (K) to pack
       dim_t*,        // out: row stride of the packed buffer
       dim_t*         // out: column stride of the packed buffer
     );

// Packs a rows x cols panel of the int8 B matrix into the layout used by
// the u8s8s16 compute kernels. On return, *rs_b and *cs_b describe the
// strides of the packed buffer.
void packb_nr32_u8s8s16o16
     (
       int8_t *pack_b_buffer_u8s8s16o16,
       const int8_t *b,
       const dim_t ldb,
       const dim_t cols,
       const dim_t rows,
       dim_t *rs_b,
       dim_t *cs_b
     );

#endif // BLIS_GEMM_INT16_PACKB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,412 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
// Packs the n0 < 16 column fringe of an int8 B panel for s8s8s16 GEMM.
// Pairs of K rows are interleaved (b[0,j], b[1,j], ...) so the packed
// layout suits vpmaddubsw, and per-column sums of B, scaled by 128, are
// accumulated into pack_b_column_sum. The 128 scale mirrors the +128
// offset the compute kernels add to the signed A data so that
// vpmaddubsw (u8 x s8) can be used; the kernels later subtract these
// scaled sums from the accumulated result.
//
// pack_b_buffer_s8s8s16o16 : output packed buffer (NR = 16 layout).
// pack_b_column_sum        : in/out running column sums (int16).
// b                        : source B fringe, leading dimension ldb.
// rows                     : K extent of the panel.
// n0_partial_rem           : number of valid columns (< 16); columns
//                            beyond this are zero padding.
void packb_nrlt16_s8s8s16o16
     (
       int8_t *pack_b_buffer_s8s8s16o16,
       int16_t *pack_b_column_sum,
       const int8_t *b,
       const dim_t ldb,
       const dim_t rows,
       dim_t n0_partial_rem
     )
{
	dim_t k_full_pieces_blks = rows / 2;
	dim_t k_full_pieces = k_full_pieces_blks * 2;
	dim_t k_partial_pieces = rows % 2;

	dim_t NR = 16;
	dim_t kr_new = 0;

	// Fix: zero-initialize the staging buffers. The memcpy calls below
	// only fill the first n0_partial_rem bytes; previously the remaining
	// bytes were indeterminate, so stack garbage leaked into the packed
	// buffer and into the column sums of the padded columns.
	int8_t buf0[16] = {0};
	int8_t buf1[16] = {0};

	__m128i b_vec[2], inter_vec[2];
	__m256i sum1;
	__m256i temp1;
	__m256 temp2, temp3;

	// Load the running column sums of the B matrix.
	sum1 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum) );

	for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
	{
		// Stage the valid n0_partial_rem bytes of two consecutive K rows.
		memcpy(buf0, (b + (ldb * (kr + 0))), (n0_partial_rem * sizeof(int8_t)));
		memcpy(buf1, (b + (ldb * (kr + 1))), (n0_partial_rem * sizeof(int8_t)));

		// Read b[0,0]..b[0,15] and b[1,0]..b[1,15] (zero past n0_partial_rem).
		b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
		b_vec[1] = _mm_loadu_si128((__m128i *)buf1);

		// Column-sum update: widen to s16, add the two rows, scale by 128
		// via float, then saturate-pack back to s16 lanes in order.
		temp1 =
			_mm256_add_epi16( _mm256_cvtepi8_epi16( b_vec[0] ), _mm256_cvtepi8_epi16( b_vec[1] ));
		temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
		temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
		temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
		temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
		temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
		temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
		sum1 = _mm256_add_epi16 (sum1, temp1);

		// Interleave the two rows to suit vpmaddubsw.
		inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
		inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);

		// Store b[0,0], b[1,0], b[0,1], ..., b[0,7], b[1,7]
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + (kr_new * NR)), inter_vec[0]);
		// Store b[0,8], b[1,8], b[0,9], ..., b[0,15], b[1,15]
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);

		// Two packed rows consumed per iteration.
		kr_new += 2;
	}

	// Odd trailing K row: pair it with an all-zero row for padding.
	if (k_partial_pieces > 0)
	{
		memcpy(buf0, (b + (ldb * (k_full_pieces + 0))), (n0_partial_rem * sizeof(int8_t)));

		// Read b[0,0]..b[0,15]; second row is zero padding.
		b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
		b_vec[1] = _mm_setzero_si128();

		// Column-sum contribution of the single row, scaled by 128.
		temp1 = ( _mm256_cvtepi8_epi16( b_vec[0] ));
		temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
		temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
		temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
		temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
		temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
		temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
		sum1 = _mm256_add_epi16 (sum1, temp1);

		// Interleave the row with the zero row to suit vpmaddubsw.
		inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
		inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);

		// Store b[0,0], 0, b[0,1], ..., b[0,7], 0
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
		// Store b[0,8], 0, b[0,9], ..., b[0,15], 0
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
	}

	// Write back the updated column sums.
	_mm256_storeu_si256( (__m256i *)(pack_b_column_sum), sum1 );
}
// Packs a full 16-column panel of an int8 B matrix for s8s8s16 GEMM.
// Pairs of K rows are interleaved (b[0,j], b[1,j], ...) so the packed
// layout suits vpmaddubsw, and per-column sums of B, scaled by 128, are
// accumulated into pack_b_column_sum. The 128 scale mirrors the +128
// offset the compute kernels add to the signed A data so vpmaddubsw
// (u8 x s8) can be used; the kernels later subtract these scaled sums.
void packb_nr16_s8s8s16o16(
	int8_t *pack_b_buffer_s8s8s16o16, // Output: packed B buffer (NR = 16).
	int16_t *pack_b_column_sum,       // In/out: running column sums (int16).
	const int8_t *b,                  // Source B panel, leading dim = ldb.
	const dim_t ldb,
	const dim_t rows)                 // K extent of the panel.
{
	dim_t k_full_pieces_blks = rows / 2;
	dim_t k_full_pieces = k_full_pieces_blks * 2;
	dim_t k_partial_pieces = rows % 2;
	dim_t NR = 16;
	dim_t kr_new = 0;
	__m128i b_vec[2], inter_vec[2];
	__m256i sum1;
	__m256i temp1;
	__m256 temp2, temp3;
	//load the temp buffer to compute column sum of B matrix
	sum1 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum) );
	for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
	{
		// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
		b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 0))));
		// Read b[1,0], b[1,1], b[1,2]......., b[1,15]
		b_vec[1] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 1))));
		// Column-sum update: widen to s16, add the two rows, scale by 128
		// via float, then saturate-pack back to s16 lanes in order.
		temp1 =
			_mm256_add_epi16( _mm256_cvtepi8_epi16( b_vec[0] ), _mm256_cvtepi8_epi16( b_vec[1] ));
		temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
		temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
		temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
		temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
		temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
		temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
		sum1 = _mm256_add_epi16 (sum1, temp1);
		// Reorder B matrix inputs to suit vpmaddubsw instructions
		inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
		inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
		// Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7]
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
		// Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15]
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
		// Increment to ignore the padded bits
		kr_new += 2;
	}
	// Odd trailing K row: pair it with an all-zero row for padding.
	if (k_partial_pieces > 0)
	{
		// Read b[0,0], b[0,1], b[0,2]......., b[0,15]
		b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (k_full_pieces + 0))));
		b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding
		// Column-sum contribution of the single row, scaled by 128.
		temp1 = ( _mm256_cvtepi8_epi16( b_vec[0] ));
		temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
		temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
		temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
		temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
		temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
		temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
		sum1 = _mm256_add_epi16 (sum1, temp1);
		// Reorder B matrix inputs to suit vpmaddubsw instructions
		inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
		inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);
		// Store b[0,0], 0, b[0,1]......., b[0,7], 0
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
		// Store b[0,8], 0, b[0,9]......., b[0,15], 0
		_mm_storeu_si128((__m128i *)(pack_b_buffer_s8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
	}
	//store the sum column
	_mm256_storeu_si256( (__m256i *)(pack_b_column_sum), sum1 );
}
// Top-level s8s8s16 B packing routine: packs a rows x cols int8 B panel
// in NR = 32 column strips, interleaving pairs of K rows to suit
// vpmaddubsw, and accumulates per-column sums of B (scaled by 128) into
// pack_b_column_sum. The 128 scale mirrors the +128 offset the compute
// kernels add to the signed A data; the kernels later subtract these
// scaled sums. Column fringes (< 32) are delegated to the nr16 / nrlt16
// helpers. On return, *rs_b and *cs_b give the packed-buffer strides.
void packb_nr32_s8s8s16o16(
	int8_t *pack_b_buffer_s8s8s16o16, // Output: packed B buffer.
	int16_t *pack_b_column_sum,       // In/out: running column sums (int16).
	const int8_t *b,                  // Source B, leading dim = ldb.
	const dim_t ldb,
	const dim_t cols,                 // N extent to pack.
	const dim_t rows,                 // K extent to pack.
	dim_t *rs_b,                      // Out: packed row stride.
	dim_t *cs_b)                      // Out: packed column stride.
{
	dim_t NR = 32;
	dim_t n_full_pieces = cols / NR;
	dim_t n_full_pieces_loop_limit = n_full_pieces * NR;
	dim_t n_partial_pieces = cols % NR;
	dim_t k_full_pieces_blks = rows / 2;
	dim_t k_full_pieces = k_full_pieces_blks * 2;
	dim_t k_partial_pieces = rows % 2;
	dim_t KC_updated = rows;
	// Making multiple of 2 to suit k in vpmaddubsw
	KC_updated += (KC_updated & 0x1);
	//to compute column sum of B matrix
	__m256i sum1, sum2;
	__m256i temp1;
	__m256 temp2, temp3;
	__m256i b_vec[2], inter_vec[2];
	// Pack full 32-column strips.
	for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR)
	{
		//load the temp buffer to compute column sum of B matrix
		sum1 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum + jc) );
		sum2 = _mm256_loadu_si256( (__m256i const *)(pack_b_column_sum + 16 + jc) );
		for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
		{
			// Read b[0,0], b[0,1], b[0,2]......., b[0,31]
			b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc));
			// Read b[1,0], b[1,1], b[1,2]......., b[1,31]
			b_vec[1] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc));
			//add all the columns : sum = add (sum, a0, b0)
			//compute sum1 and sum2 to compute B matrix column sum
			// sum1 covers columns 0-15: widen to s16, add the two rows,
			// scale by 128 via float, saturate-pack back to s16.
			temp1 =
				_mm256_add_epi16( _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 0 )),
						_mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[1], 0 )));
			temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
			temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
			temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
			temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
			temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
			temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
			sum1 = _mm256_add_epi16 (sum1, temp1);
			//compute sum2 (columns 16-31, same scheme)
			temp1 =
				_mm256_add_epi16( _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 1 )),
						_mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[1], 1 )));
			temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
			temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
			temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
			temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
			temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
			temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
			sum2 = _mm256_add_epi16 (sum2, temp1);
			// Reorder B matrix inputs to suit vpmaddubsw instructions
			inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]);
			inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]);
			b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20);
			b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31);
			// Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15]
			_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + (kr * NR))), b_vec[0]);
			// Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31]
			_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + ((kr + 1) * NR))), b_vec[1]);
		}
		// Odd trailing K row: pair it with an all-zero row for padding.
		if (k_partial_pieces > 0)
		{
			// Read b[0,0], b[0,1], b[0,2]......., b[0,31]
			b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (k_full_pieces + 0)) + jc));
			b_vec[1] = _mm256_setzero_si256(); // Initialize with zero for padding
			//compute sum1 (columns 0-15 contribution of the single row)
			temp1 = _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 0 ));
			temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
			temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
			temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
			temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
			temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
			temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
			sum1 = _mm256_add_epi16 (sum1, temp1);
			//compute sum2 (columns 16-31 contribution of the single row)
			temp1 = _mm256_cvtepi8_epi16( _mm256_extractf128_si256( b_vec[0], 1 ));
			temp2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 0)));
			temp2 = _mm256_mul_ps(temp2, _mm256_set1_ps (128));
			temp3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extractf128_si256(temp1, 1)));
			temp3 = _mm256_mul_ps(temp3, _mm256_set1_ps (128));
			temp1 = _mm256_packs_epi32(_mm256_cvtps_epi32(temp2), _mm256_cvtps_epi32(temp3));
			temp1 = _mm256_permute4x64_epi64(temp1, 0XD8);
			sum2 = _mm256_add_epi16 (sum2, temp1);
			// Reorder B matrix inputs to suit vpmaddubsw instructions
			inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]);
			inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]);
			b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20);
			b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31);
			// Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15]
			_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + (k_full_pieces * NR))), b_vec[0]);
			// Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31]
			_mm256_storeu_si256((__m256i *)(pack_b_buffer_s8s8s16o16 + ((jc * KC_updated) + ((k_full_pieces + 1) * NR))), b_vec[1]);
		}
		//store the sum column
		_mm256_storeu_si256( (__m256i *)(pack_b_column_sum + jc), sum1 );
		_mm256_storeu_si256( (__m256i *)(pack_b_column_sum + 16 + jc), sum2 );
	}
	// B matrix packing when n < NR
	if (n_partial_pieces > 0)
	{
		// Split into multiple smaller fringe kernels, so as to maximize
		// vectorization after packing. Any n0 < NR(32) can be expressed
		// as n0 = 16 + n`.
		dim_t n0_16 = n_partial_pieces / 16;
		dim_t n0_partial_rem = n_partial_pieces % 16;
		dim_t n0_partial_pack = 0;
		if (n0_16 == 1)
		{
			// Pack a full 16-column sub-strip of the fringe.
			packb_nr16_s8s8s16o16(
				(pack_b_buffer_s8s8s16o16 +
				 (n_full_pieces_loop_limit * KC_updated)),
				( pack_b_column_sum + ( n_full_pieces_loop_limit ) ),
				(b + n_full_pieces_loop_limit), ldb, rows);
			n0_partial_pack = 16;
		}
		if (n0_partial_rem > 0)
		{
			// Pack the final < 16 column remainder.
			packb_nrlt16_s8s8s16o16(
				(pack_b_buffer_s8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) +
				 (n0_partial_pack * KC_updated)),
				( pack_b_column_sum + n_full_pieces_loop_limit + n0_partial_pack ),
				(b + n_full_pieces_loop_limit + n0_partial_pack),
				ldb, rows, n0_partial_rem);
		}
	}
	// Packed-buffer strides: two interleaved K rows per packed row.
	*rs_b = NR * 2;
	*cs_b = NR;
}
#endif

View File

@@ -1,877 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
#include "../u8s8s16/lpgemm_s16_kern_macros.h"
#define LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, paddr, stride ) \
ymm0 = _mm256_loadu_si256( (__m256i const *)paddr ); \
ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); \
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
#define LPGEMV_N_KERNEL_2_FMA( a_reg1, a_reg2, b_reg, \
inter_reg1, inter_reg2, c_reg1, c_reg2 ) \
inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
c_reg1 = _mm256_add_epi16(inter_reg1, c_reg1); \
inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
c_reg2 = _mm256_add_epi16(inter_reg2, c_reg2);
#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \
ymm0 = _mm256_loadu_si256( (__m256i const *)(paddr) ); \
ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \
ymm2 = _mm256_loadu_si256( (__m256i const *)(paddr + 2 * stride) ); \
ymm3 = _mm256_loadu_si256( (__m256i const *)(paddr + 3 * stride) ); \
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 ); \
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 ); \
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 ); \
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
#define LPGEMV_N_KERNEL_4_FMA( a_reg1, a_reg2, a_reg3, a_reg4, b_reg, \
inter_reg1, inter_reg2, \
inter_reg3, inter_reg4, \
out_reg1, out_reg2, out_reg3, out_reg4 ) \
inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
out_reg1 = _mm256_add_epi16(inter_reg1, out_reg1); \
inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
out_reg2 = _mm256_add_epi16(inter_reg2, out_reg2); \
inter_reg3 = _mm256_maddubs_epi16(a_reg3, b_reg); \
out_reg3 = _mm256_add_epi16(inter_reg3, out_reg3); \
inter_reg4 = _mm256_maddubs_epi16(a_reg4, b_reg); \
out_reg4 = _mm256_add_epi16(inter_reg4, out_reg4);
#define LPGEMV_YMM2XMM( ymm0, ymm1, ymm2, ymm3, xmm0 ) \
ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
ymm1 = _mm256_hadd_epi16( ymm2, ymm3 ); \
ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
xmm0 = _mm_add_epi16( _mm256_extracti128_si256( ymm0, 0 ), \
_mm256_extracti128_si256( ymm0, 1 ) );
LPGEMV_N_EQ1_KERN(int8_t, int8_t, int16_t, s8s8s16os16)
{
static void* post_ops_labels[] =
{
&&POST_OPS_DISABLE,
&&POST_OPS_BIAS,
&&POST_OPS_RELU,
&&POST_OPS_RELU_SCALE,
&&POST_OPS_GELU_TANH,
&&POST_OPS_GELU_ERF,
&&POST_OPS_CLIP,
&&POST_OPS_DOWNSCALE,
&&POST_OPS_MATRIX_ADD,
&&POST_OPS_SWISH,
NULL,// Virtual node for matrix_mul, else segfault
&&POST_OPS_TANH,
&&POST_OPS_SIGMOID
};
int8_t *a_use = NULL;
int8_t *b_use = NULL;
int16_t *c_use = NULL;
lpgemm_post_op_attr post_ops_attr = *(post_op_attr);
// temp buffer to store output C vector
int16_t ctemp[16];
// temp buffers to store a, b data in k_rem case.
int8_t buf0[32] = {0};
int8_t buf1[32] = {0};
int8_t buf2[32] = {0};
int8_t buf3[32] = {0};
int8_t buf4[32] = {0};
int8_t buf5[32] = {0};
int8_t buf6[32] = {0};
int8_t buf7[32] = {0};
int8_t buf8[32] = {0};
uint8_t cvt_uint8 = 128;
__m256i vec_uint8;
int16_t* bsumptr = post_ops_attr.b_col_sum_vec_s16;
for ( dim_t ir = 0; ir < m0; ir += MR )
{
dim_t mr0 = bli_min( ( m0 - ir ), MR );
dim_t k_iter = k / 32;
dim_t k_rem = k % 32;
__m256i ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7;
__m256i ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14;
__m256i ymm15;
__m128i xmm0, xmm1;
/* zero the accumulator registers */
ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 )
ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 )
//update pointers
a_use = (int8_t*)a + ir * rs_a;
b_use = (int8_t*)b;
c_use = (int16_t*)c + ir * rs_c;
if( mr0 == MR )
{
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
for (dim_t k = 0; k < k_iter; k++)
{
ymm6 = _mm256_loadu_si256( (__m256i const *)(b_use) );
b_use += 32;
//Load 4x32 elements from row0-row3 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a )
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
// Load 4x32 elements from row8-row11 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
( a_use + 4 * rs_a ), rs_a
)
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm12, ymm13, ymm14, ymm15
)
a_use += 32;
}
if( k_rem )
{
uint8_t buf_vec_uint8_t[32] = {0};
int8_t* restrict a0 = (a_use);
int8_t* restrict a1 = (a_use + rs_a );
int8_t* restrict a2 = (a_use + 2 * rs_a );
int8_t* restrict a3 = (a_use + 3 * rs_a );
int8_t* restrict a4 = (a_use + 4 * rs_a );
int8_t* restrict a5 = (a_use + 5 * rs_a );
int8_t* restrict a6 = (a_use + 6 * rs_a );
int8_t* restrict a7 = (a_use + 7 * rs_a );
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf1[i] = a1[i];
buf2[i] = a2[i];
buf3[i] = a3[i];
buf4[i] = a4[i];
buf5[i] = a5[i];
buf6[i] = a6[i];
buf7[i] = a7[i];
buf_vec_uint8_t[i] = cvt_uint8;
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
vec_uint8 = _mm256_loadu_si256( ( __m256i const *) buf_vec_uint8_t );
//Load 4x32 elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 );
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
// Load 4x32 elements from row8-row11 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf4 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf5 );
ymm2 = _mm256_loadu_si256( (__m256i const *)buf6 );
ymm3 = _mm256_loadu_si256( (__m256i const *)buf7 );
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 );
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm12, ymm13, ymm14, ymm15
)
}
//Add the registers horizantally to get one
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm1 )
xmm0 = _mm_hadd_epi16( xmm0, xmm1 );
// post ops are applied on ymm register though
// second half of the register is filled with zeroes.
ymm8 = _mm256_setzero_si256();
ymm8 = _mm256_inserti128_si256( ymm8, xmm0, 0);
ymm0 = _mm256_set1_epi16( *bsumptr );
ymm8 = _mm256_sub_epi16( ymm8, ymm0 );
}
else
{
int8_t *a_use_fringe = a_use;
dim_t mr0_use = mr0;
dim_t regidx = 0;
if( mr0_use >= 4 )
{
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
for (dim_t k = 0; k < k_iter; k++)
{
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
b_use += 32;
//Load 4x32 elements from row0-row3 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
a_use, rs_a )
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
a_use += 32;
}
if( k_rem )
{
uint8_t buf_vec_uint8_t[32] = {0};
int8_t* restrict a0 = (a_use);
int8_t* restrict a1 = (a_use + rs_a );
int8_t* restrict a2 = (a_use + 2 * rs_a );
int8_t* restrict a3 = (a_use + 3 * rs_a );
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf1[i] = a1[i];
buf2[i] = a2[i];
buf3[i] = a3[i];
buf_vec_uint8_t[i] = cvt_uint8;
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t );
//Load 4xk_rem elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
ymm2 = _mm256_add_epi8( ymm2, vec_uint8 );
ymm3 = _mm256_add_epi8( ymm3, vec_uint8 );
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
}
//update pointers
mr0_use -= 4;
a_use = a_use_fringe + 4 * rs_a;
a_use_fringe = a_use;
b_use = (int8_t*)b;
//Add the registers horizantally to get one
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
int64_t data = _mm_extract_epi64( xmm0, 0);
//insert xmm outputs into final output reg based on regidx
ymm8 = _mm256_setzero_si256();
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
regidx++;
}
// Dot product for <= 3
if ( mr0_use )
{
// Dot product for m = 2
if ( mr0_use >= 2 )
{
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
for ( dim_t k = 0; k < k_iter; k++ )
{
// Load 0-31 in b[k+0 - k+31]
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, a_use, rs_a);
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, ymm4,
ymm5, ymm12, ymm13);
b_use += 32; // move b pointer to next 32 elements
a_use += 32;
}
if ( k_rem )
{
uint8_t buf_vec_uint8_t[32] = {0};
int8_t* restrict a0 = (a_use);
int8_t* restrict a1 = (a_use + rs_a );
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf1[i] = a1[i];
buf_vec_uint8_t[i] = cvt_uint8;
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t );
//Load 2xk_rem elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
ymm1 = _mm256_add_epi8( ymm1, vec_uint8 );
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6,
ymm4, ymm5, ymm12, ymm13 );
}
mr0_use -= 2;
a_use = a_use_fringe + 2 * rs_a;
a_use_fringe = a_use;
b_use = (int8_t*)b;
}
// Dot product for m = 1
if ( mr0_use == 1 )
{
vec_uint8 = _mm256_set1_epi8 (cvt_uint8);
for ( dim_t k = 0; k < k_iter; k++ )
{
// Load 0-31 in b[k+0 - k+31]
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
// Load 1x32 elements from row0-row1 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)a_use );
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
ymm14 = _mm256_add_epi16(ymm4, ymm14);
b_use += 32; // move b pointer to next 32 elements
a_use += 32;
}
if ( k_rem )
{
uint8_t buf_vec_uint8_t[32] = {0};
int8_t* restrict a0 = (a_use);
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf_vec_uint8_t[i] = cvt_uint8;
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
vec_uint8 = _mm256_loadu_si256( (__m256i const *)buf_vec_uint8_t );
//Load 1xk_rem elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm0 = _mm256_add_epi8( ymm0, vec_uint8 );
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
ymm14 = _mm256_add_epi16(ymm4, ymm14);
}
// When only fringe 1,
// update the registers to store in order
if ( !( mr0 & 0x2 ) ) ymm12 = ymm14;
}
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm0)
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
int64_t data = _mm_extract_epi64( xmm0, 0);
//insert xmm outputs into final output reg based on regidx
if( regidx == 0 )
{
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
}
else
{
ymm8 = _mm256_insert_epi64( ymm8, data, 1 );
}
}
int16_t buf_vec_int16_t[16] = {0};
for( dim_t i = 0; i < mr0; i++)
buf_vec_int16_t[i] = *bsumptr;
ymm0 = _mm256_loadu_si256( ( __m256i const *) buf_vec_int16_t);
ymm8 = _mm256_sub_epi16( ymm8, ymm0 );
}
// Load alpha and beta
__m256i selector1 = _mm256_set1_epi16(alpha);
__m256i selector2 = _mm256_set1_epi16(beta);
// Scale by alpha
ymm8 = _mm256_mullo_epi16(selector1, ymm8);
if( beta != 0 )
{
if ( post_ops_attr.buf_downscale != NULL )
{
if( post_ops_attr.rs_c_downscale == 1 )
{
if( post_ops_attr.c_stor_type == S8 )
{
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
S8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes );
S8_S16_BETA_OP_NLT16( ymm8, ctemp,
selector1, selector2 )
}
else if( post_ops_attr.c_stor_type == U8 )
{
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
U8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes );
U8_S16_BETA_OP_NLT16( ymm8, ctemp,
selector1, selector2 )
}
}
else
{
if( post_ops_attr.c_stor_type == S8 )
{
int8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( (int8_t*)post_ops_attr.buf_downscale
+ ( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm256_cvtepi8_epi32
( _mm_loadu_si128( (__m128i const*)ctemp ) );
S16_BETA_FMA( ymm8, selector1, selector2 );
}
else if( post_ops_attr.c_stor_type == U8 )
{
uint8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( (uint8_t*)post_ops_attr.buf_downscale
+ ( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm256_cvtepu8_epi32
( _mm_loadu_si128( (__m128i const*)ctemp ) );
S16_BETA_FMA( ymm8, selector1, selector2 );
}
}
}
else
{
if( rs_c == 1 )
{
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
memcpy( ctemp, c_use, m0_rem_bytes );
S16_S16_BETA_OP_NLT16( ymm8, ctemp,
selector1, selector2 )
}
else
{
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = c_use[ i * rs_c ];
}
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
S16_BETA_FMA( ymm8, selector1, selector2 );
}
}
}
// Post Ops
lpgemm_post_op * post_ops_list_temp = post_op;
post_ops_attr.is_last_k = TRUE;
POST_OP_LABEL_LASTK_SAFE_JUMP
POST_OPS_BIAS:
{
selector1 =
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args1) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU:
{
selector1 = _mm256_setzero_si256();
ymm8 = _mm256_max_epi16( selector1, ymm8 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU_SCALE:
{
__m256i b0;
selector1 = _mm256_setzero_si256();
selector2 = _mm256_set1_epi16(
*( ( int16_t* )post_ops_list_temp->op_args2 ) );
RELU_SCALE_OP_S16_AVX2( ymm8 )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_TANH:
{
__m256 dn, z, x, r2, r, y1, y2, x_tanh;
__m256i q;
GELU_TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, x_tanh, q )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_ERF:
{
__m256 x, r, y1, y2, x_erf;
GELU_ERF_S16_AVX2(ymm8, y1, y2, r, x, x_erf)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_CLIP:
{
__m256i min = _mm256_set1_epi16(
*( int16_t* )post_ops_list_temp->op_args2 );
__m256i max = _mm256_set1_epi16(
*( int16_t* )post_ops_list_temp->op_args3 );
CLIP_S16_AVX2(ymm8, min, max)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_DOWNSCALE:
{
__m128i temp[2];
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1 = _mm256_setzero_ps();
__m256 scale_2 = _mm256_setzero_ps();
__m128i _zero_point_0 = _mm_setzero_si128();
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;
scale_1 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
scale_2 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
_zero_point_0 = _mm_set1_epi8(
*( ( int8_t* )post_ops_list_temp->op_args1 ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}
// Scale first 16 columns of the 2 rows.
CVT_MULRND_CVT16(ymm8, scale_1, scale_2, zero_point_0)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_MATRIX_ADD:
{
dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;
if ( post_ops_attr.c_stor_type == S8 )
{
int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
memcpy
(
( int8_t* )ctemp,
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
( mr0 ) * sizeof(int8_t)
);
selector1 = _mm256_cvtepi8_epi16(
_mm_loadu_si128( ( __m128i const* )ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
else
{
int8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_cvtepi8_epi16
( _mm_loadu_si128( (__m128i const*)ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
memcpy
(
( uint8_t* )ctemp,
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
( mr0 ) * sizeof(uint8_t)
);
selector1 = _mm256_cvtepu8_epi16(
_mm_loadu_si128( ( __m128i const* )ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
else
{
uint8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_cvtepu8_epi16
( _mm_loadu_si128( (__m128i const*)ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
}
else
{
int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
memcpy
(
( int16_t* )ctemp,
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
( mr0 ) * sizeof(int16_t)
);
selector1 = _mm256_loadu_si256( ( __m256i const* )ctemp );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
else
{
int16_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SWISH:
{
selector1 =
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) );
__m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
_mm256_extractf128_si256( selector1, 0 ) ) );
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
__m256i ex_out;
SWISH_S16_AVX2( ymm8, al, al_in, tmp_reg1,
tmp_reg2, r, r2, z, dn, ex_out );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_TANH:
{
__m256 dn, z, x, r2, r, y1, y2;
__m256i q;
TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, q )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SIGMOID:
{
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
__m256i ex_out;
SIGMOID_S16_AVX2( ymm8, al_in, tmp_reg1,
tmp_reg2, r, r2, z, dn, ex_out );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_DISABLE:
{
if ( post_ops_attr.buf_downscale != NULL )
{
__m128i temp[2];
__m256i zero_reg = _mm256_setzero_si256();
if( post_ops_attr.rs_c_downscale == 1 )
{
if( post_ops_attr.c_stor_type == S8 )
{
// Store the results in downscaled type
// (int8 instead of int16).
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
CVT_STORE_S16_S8_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes);
}
else if( post_ops_attr.c_stor_type == U8 )
{
// Store the results in downscaled type (uint8 instead of int16).
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
CVT_STORE_S16_U8_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes);
}
}
else
{
if( post_ops_attr.c_stor_type == S8 )
{
int8_t ctemp[16];
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
for( dim_t i = 0; i < mr0; i++ )
{
*( ( int8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
else if( post_ops_attr.c_stor_type == U8 )
{
uint8_t ctemp[16];
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
for( dim_t i = 0; i < mr0; i++ )
{
*( ( uint8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
}
}
else
{
if( rs_c == 1 )
{
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
memcpy( c_use, ctemp, m0_rem_bytes );
}
else
{
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
for( dim_t i = 0; i < mr0; i++ )
{
c_use[i * rs_c] = ctemp[i];
}
}
}
post_ops_attr.post_op_c_i += MR;
}
}
}
#endif

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,258 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
// Packs a B panel with fewer than 16 valid columns (n0_partial_rem < 16) for
// the u8s8s16os16 kernels. Pairs of consecutive k rows are byte-interleaved
// so the packed operand order matches what vpmaddubsw consumes; an odd
// trailing row is paired with zeroes.
void packb_nrlt16_u8s8s16o16
(
    int8_t *pack_b_buffer_u8s8s16o16, // out: packed panel, NR(=16) column stride
    const int8_t *b,                  // in : source B matrix
    const dim_t ldb,                  // in : leading dimension (row stride) of b
    const dim_t rows,                 // in : number of k rows to pack
    dim_t n0_partial_rem              // in : valid column count (< 16)
)
{
    dim_t k_full_pieces_blks = rows / 2;
    dim_t k_full_pieces = k_full_pieces_blks * 2;
    dim_t k_partial_pieces = rows % 2;

    dim_t NR = 16;
    dim_t kr_new = 0;

    // Zero-initialize the staging buffers: each memcpy below fills only the
    // first n0_partial_rem bytes, yet _mm_loadu_si128 reads the full 16.
    // Without the initializer the padding columns of the packed buffer would
    // be filled from indeterminate (uninitialized) bytes; with it the padding
    // is deterministically zero.
    int8_t buf0[16] = {0}, buf1[16] = {0};
    __m128i b_vec[2], inter_vec[2];

    for (dim_t kr = 0; kr < k_full_pieces; kr += 2)
    {
        memcpy(buf0, (b + (ldb * (kr + 0))), (n0_partial_rem * sizeof(int8_t)));
        memcpy(buf1, (b + (ldb * (kr + 1))), (n0_partial_rem * sizeof(int8_t)));

        // Read b[0,0..15] and b[1,0..15] (zero padded past n0_partial_rem).
        b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
        b_vec[1] = _mm_loadu_si128((__m128i *)buf1);

        // Reorder B matrix inputs to suit vpmaddubsw instructions.
        inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
        inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);

        // Store b[0,0], b[1,0], b[0,1], ..., b[0,7], b[1,7]
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + (kr_new * NR)), inter_vec[0]);
        // Store b[0,8], b[1,8], b[0,9], ..., b[0,15], b[1,15]
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);

        // Advance past both interleaved output rows.
        kr_new += 2;
    }

    // Handle an odd trailing k row by interleaving it with a zero row.
    if (k_partial_pieces > 0)
    {
        memcpy(buf0, (b + (ldb * (k_full_pieces + 0))), (n0_partial_rem * sizeof(int8_t)));

        b_vec[0] = _mm_loadu_si128((__m128i *)buf0);
        b_vec[1] = _mm_setzero_si128(); // zero row for k padding

        inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]);
        inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]);

        // Store b[0,0], 0, b[0,1], ..., b[0,7], 0
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]);
        // Store b[0,8], 0, b[0,9], ..., b[0,15], 0
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]);
    }
}
// Packs an exactly-16-column B panel for the u8s8s16os16 kernels.
// Consecutive k rows are byte-interleaved in pairs so the packed operand
// order matches what vpmaddubsw consumes; an odd trailing row is paired
// with a row of zeroes.
void packb_nr16_u8s8s16o16(
    int8_t *pack_b_buffer_u8s8s16o16,
    const int8_t *b,
    const dim_t ldb,
    const dim_t rows)
{
    const dim_t NR = 16;

    // k rows that form complete pairs, and whether one unpaired row remains.
    dim_t paired_rows = (rows / 2) * 2;
    dim_t has_odd_row = rows % 2;

    dim_t out_row = 0;
    __m128i row_even, row_odd, ilv_lo, ilv_hi;

    for (dim_t kr = 0; kr < paired_rows; kr += 2)
    {
        // Load 16 bytes from rows kr and kr+1.
        row_even = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 0))));
        row_odd  = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 1))));

        // Byte-interleave the two rows for vpmaddubsw consumption.
        ilv_lo = _mm_unpacklo_epi8(row_even, row_odd);
        ilv_hi = _mm_unpackhi_epi8(row_even, row_odd);

        // b[0,0], b[1,0], ..., b[0,7], b[1,7]
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((out_row + 0) * NR)), ilv_lo);
        // b[0,8], b[1,8], ..., b[0,15], b[1,15]
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((out_row + 1) * NR)), ilv_hi);

        out_row += 2;
    }

    if (has_odd_row)
    {
        // Final unpaired k row is interleaved with zeroes (k padding).
        row_even = _mm_loadu_si128((__m128i const *)(b + (ldb * (paired_rows + 0))));
        row_odd  = _mm_setzero_si128();

        ilv_lo = _mm_unpacklo_epi8(row_even, row_odd);
        ilv_hi = _mm_unpackhi_epi8(row_even, row_odd);

        // b[0,0], 0, b[0,1], ..., b[0,7], 0
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((out_row + 0) * NR)), ilv_lo);
        // b[0,8], 0, b[0,9], ..., b[0,15], 0
        _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((out_row + 1) * NR)), ilv_hi);
    }
}
// Packs a B matrix into NR=32 column panels for the u8s8s16os16 kernels.
// Full 32-column panels are packed here with pairs of k rows byte-interleaved
// (lane order fixed via 128-bit permutes); column fringes dispatch to the
// 16-column and <16-column pack routines. On return *rs_b/*cs_b describe the
// packed panel strides.
void packb_nr32_u8s8s16o16(
    int8_t *pack_b_buffer_u8s8s16o16,
    const int8_t *b,
    const dim_t ldb,
    const dim_t cols,
    const dim_t rows,
    dim_t *rs_b,
    dim_t *cs_b)
{
    const dim_t NR = 32;

    // Columns covered by full 32-wide panels, and the leftover fringe.
    dim_t full_nr_cols = (cols / NR) * NR;
    dim_t fringe_cols = cols % NR;

    // k rows that pair up evenly, plus an optional odd trailing row.
    dim_t paired_rows = (rows / 2) * 2;
    dim_t odd_row = rows % 2;

    // Packed-buffer row count, rounded up to a multiple of 2 so the
    // k dimension suits vpmaddubsw.
    dim_t KC_updated = rows + (rows & 0x1);

    __m256i row0, row1, ilv0, ilv1;

    for (dim_t jc = 0; jc < full_nr_cols; jc += NR)
    {
        // Destination for this 32-column panel.
        int8_t *panel_dst = pack_b_buffer_u8s8s16o16 + (jc * KC_updated);

        for (dim_t kr = 0; kr < paired_rows; kr += 2)
        {
            // Load 32 bytes from rows kr and kr+1 of this column slice.
            row0 = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc));
            row1 = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc));

            // Byte-interleave, then repair the 128-bit lane order.
            ilv0 = _mm256_unpacklo_epi8(row0, row1);
            ilv1 = _mm256_unpackhi_epi8(row0, row1);
            row0 = _mm256_permute2f128_si256(ilv0, ilv1, 0x20);
            row1 = _mm256_permute2f128_si256(ilv0, ilv1, 0x31);

            // B[0,0], B[1,0], ..., B[0,15], B[1,15]
            _mm256_storeu_si256((__m256i *)(panel_dst + (kr * NR)), row0);
            // B[0,16], B[1,16], ..., B[0,31], B[1,31]
            _mm256_storeu_si256((__m256i *)(panel_dst + ((kr + 1) * NR)), row1);
        }

        if (odd_row)
        {
            // Last k row is interleaved with a zero row (k padding).
            row0 = _mm256_loadu_si256((__m256i const *)(b + (ldb * (paired_rows + 0)) + jc));
            row1 = _mm256_setzero_si256();

            ilv0 = _mm256_unpacklo_epi8(row0, row1);
            ilv1 = _mm256_unpackhi_epi8(row0, row1);
            row0 = _mm256_permute2f128_si256(ilv0, ilv1, 0x20);
            row1 = _mm256_permute2f128_si256(ilv0, ilv1, 0x31);

            _mm256_storeu_si256((__m256i *)(panel_dst + (paired_rows * NR)), row0);
            _mm256_storeu_si256((__m256i *)(panel_dst + ((paired_rows + 1) * NR)), row1);
        }
    }

    // Column fringe: express n0 < NR(32) as n0 = 16 + n` and dispatch to the
    // smaller pack kernels so vectorization is preserved after packing.
    if (fringe_cols > 0)
    {
        dim_t fringe_lt16 = fringe_cols % 16;
        dim_t packed_16 = 0;

        if ((fringe_cols / 16) == 1)
        {
            packb_nr16_u8s8s16o16(
                (pack_b_buffer_u8s8s16o16 + (full_nr_cols * KC_updated)),
                (b + full_nr_cols), ldb, rows);
            packed_16 = 16;
        }
        if (fringe_lt16 > 0)
        {
            packb_nrlt16_u8s8s16o16(
                (pack_b_buffer_u8s8s16o16 + (full_nr_cols * KC_updated) +
                 (packed_16 * KC_updated)),
                (b + full_nr_cols + packed_16),
                ldb, rows, fringe_lt16);
        }
    }

    // Packed panel strides consumed by the compute kernels.
    *rs_b = NR * 2;
    *cs_b = NR;
}
#endif

View File

@@ -1,510 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef LPGEMM_S16_KERN_MACROS_H
#define LPGEMM_S16_KERN_MACROS_H
#include "../gelu_avx2.h"
#include "../silu_avx2.h"
#include "../sigmoid_avx2.h"
#include "../math_utils_avx2.h"
// Saturation bounds of a signed 8-bit value (used by s16 -> s8 downscale paths).
#define S8_MIN (-128)
#define S8_MAX (+127)
/* ReLU scale (Parametric ReLU): f(x) = x, when x > 0 and f(x) = a*x when x <= 0.
   Expects selector1 (scratch), selector2 (the scale `a` broadcast as epi16) and
   b0 (scratch __m256i) to be declared at the expansion site. */
#define RELU_SCALE_OP_S16_AVX2(reg) \
selector1 = _mm256_setzero_si256();\
selector1 = _mm256_cmpgt_epi16 ( selector1, reg ); \
\
/* Only < 0 elements in b0. */ \
b0 = _mm256_and_si256 ( selector1, reg ); \
\
/* Only >= 0 elements in c_int16_0p0. */ \
reg = _mm256_andnot_si256( selector1, reg ); \
\
/* Only scaling for < 0 elements. */ \
b0 = _mm256_mullo_epi16( b0, selector2 ); \
\
/* Combine the scaled < 0 and >= 0 elements. */ \
reg = _mm256_or_si256( b0, reg ); \
// s16 fma macro: reg += scratch1 * scratch2 (per epi16 lane); clobbers scratch1.
#define S16_BETA_FMA(reg,scratch1,scratch2) \
scratch1 = _mm256_mullo_epi16( scratch2, scratch1 ); \
reg = _mm256_add_epi16( scratch1, reg ); \
// Beta scale macro, scratch2=beta. Loads 16 s16 elements of C at row
// (m_ir + m_ind), column block n_ind; `c` and `rs_c` must be in scope.
#define S16_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
scratch1 = \
_mm256_loadu_si256 \
( \
( __m256i const* )( c + ( rs_c * ( m_ir + m_ind ) ) + ( n_ind * 16 ) ) \
); \
S16_BETA_FMA(reg,scratch1,scratch2) \
// Beta n < 16 scale macro, scratch2=beta; C elements pre-staged in buf_.
#define S16_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \
scratch1 = _mm256_loadu_si256( ( __m256i const* )buf_ ); \
S16_BETA_FMA(reg,scratch1,scratch2) \
// Downscale beta scale macro (s8 -> s16), scratch2=beta. Reads 16 s8 values
// from post_ops_attr.buf_downscale, sign-extends to s16, then FMAs with beta.
#define S8_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
scratch1 = \
_mm256_cvtepi8_epi16 \
( \
_mm_loadu_si128 \
( \
( __m128i const* )( ( int8_t* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) )\
) \
); \
S16_BETA_FMA(reg,scratch1,scratch2) \
// Downscale beta scale macro (u8 -> s16), scratch2=beta. Same as above but
// zero-extends u8 values from the downscale buffer.
#define U8_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \
scratch1 = \
_mm256_cvtepu8_epi16 \
( \
_mm_loadu_si128 \
( \
( __m128i const* )( ( uint8_t* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) )\
) \
); \
S16_BETA_FMA(reg,scratch1,scratch2) \
// Downscale beta n < 16 scale macro (s8 -> s16), scratch2=beta; input in buf_.
#define S8_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \
scratch1 = _mm256_cvtepi8_epi16( _mm_loadu_si128( ( __m128i const* )buf_ ) ); \
S16_BETA_FMA(reg,scratch1,scratch2) \
// Downscale beta n < 16 scale macro (u8 -> s16), scratch2=beta; input in buf_.
#define U8_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \
scratch1 = _mm256_cvtepu8_epi16( _mm_loadu_si128( ( __m128i const* )buf_ ) ); \
S16_BETA_FMA(reg,scratch1,scratch2) \
// Copies `bytes` bytes of downscale-buffer row (post_op_c_i + m_ind) into
// buf_, for the n < 16 beta paths above; C_type selects s8/u8 element type.
#define US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, C_type) \
memcpy \
( \
buf_, \
( ( C_type* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
post_ops_attr.post_op_c_j ), bytes \
); \
// s8 specialization of the memcpy helper.
#define S8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,int8_t) \
// u8 specialization of the memcpy helper.
#define U8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,uint8_t) \
// Downscale macro: s16 -> float, multiply by scale, round-to-nearest, back to
// s16 (saturating), then add the zero point. Expects temp[2] (__m128i),
// temp_32[2] (__m256i), temp_float[2], res_1, res_2 (__m256) declared at the
// expansion site.
#define CVT_MULRND_CVT16(reg, scale0, scale1, zero_point_0) \
\
/* Extract the first 128 bits of the register*/ \
temp[0] = _mm256_extractf128_si256( reg, 0 ); \
/* Extract the second 128 bits of the register*/ \
temp[1] = _mm256_extractf128_si256( reg, 1 ); \
\
temp_32[0] = _mm256_cvtepi16_epi32( temp[0] ); \
temp_32[1] = _mm256_cvtepi16_epi32( temp[1] ); \
temp_float[0] = _mm256_cvtepi32_ps( temp_32[0] ); \
temp_float[1] = _mm256_cvtepi32_ps( temp_32[1] ); \
\
/* Multiply the C matrix by the scale value*/ \
res_1 = _mm256_mul_ps( temp_float[0], scale0 ); \
res_2 = _mm256_mul_ps( temp_float[1], scale1 ); \
\
/* Round the resultant value to the nearest float value. */ \
res_1 = \
_mm256_round_ps \
( \
res_1, ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \
); \
res_2 = \
_mm256_round_ps \
( \
res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) \
); \
\
/* Convert the clipped float32 scaled rounded value to int32 */ \
temp_32[0] = _mm256_cvtps_epi32( res_1 ); \
temp_32[1] = _mm256_cvtps_epi32( res_2 ); \
\
/* Convert the s32 to s16 (saturating pack) */ \
reg = _mm256_packs_epi32( temp_32[0], temp_32[1] ); \
\
/*Permute to make sure the order is correct*/ \
reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
/* Zero point addition.*/ \
reg = _mm256_add_epi16( reg, zero_point_0 ); \
// Downscale store macro helper: fixes lane order then stores a full 32-byte
// row into post_ops_attr.buf_downscale at row (post_op_c_i + m_ind),
// column block n_ind; C_type selects the s8/u8 element type.
#define CVT_STORE_S16_SU8_HELPER(reg, m_ind, n_ind, C_type) \
reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
_mm256_storeu_si256 \
( \
( __m256i* )( ( C_type* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 32 ) ), \
reg \
); \
// Downscale store macro (s16 -> s8); clobbers reg0.
#define CVT_STORE_S16_S8(reg0, reg1, m_ind, n_ind) \
/* Convert the s16 to s8 (signed saturating pack) */ \
reg0 = _mm256_packs_epi16( reg0, reg1 ); \
CVT_STORE_S16_SU8_HELPER(reg0, m_ind, n_ind, int8_t) \
// Downscale store macro (s16 -> u8); clobbers reg0.
#define CVT_STORE_S16_U8(reg0, reg1, m_ind, n_ind) \
/* Convert the s16 to u8 (unsigned saturating pack) */ \
reg0 = _mm256_packus_epi16( reg0, reg1 ); \
CVT_STORE_S16_SU8_HELPER(reg0, m_ind, n_ind, uint8_t) \
// Downscale store helper macro for fringe cases: splits the packed register
// and stores 16 bytes each into rows m_ind0 and m_ind1 of the downscale
// buffer. Expects temp[2] (__m128i) at the expansion site.
#define CVT_STORE_S16_US8_2ROW_HELPER(reg, m_ind0, m_ind1, n_ind, C_type) \
reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
/* Extract the first 128 bits of the register*/ \
temp[0] = _mm256_extractf128_si256( reg, 0 ); \
/* Extract the second 128 bits of the register*/ \
temp[1] = _mm256_extractf128_si256( reg, 1 ); \
\
_mm_storeu_si128 \
( \
( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind0 ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \
temp[0] \
); \
_mm_storeu_si128 \
( \
( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind1 ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \
temp[1] \
); \
// Downscale store macro for fringe cases (s16 -> s8); clobbers reg0.
#define CVT_STORE_S16_S8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \
/* Convert the s16 to s8 */ \
reg0 = _mm256_packs_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_2ROW_HELPER(reg0, m_ind0, m_ind1, n_ind, int8_t) \
// Downscale store macro for fringe cases (s16 -> u8); clobbers reg0.
#define CVT_STORE_S16_U8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \
/* Convert the s16 to u8 */ \
reg0 = _mm256_packus_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_2ROW_HELPER(reg0, m_ind0, m_ind1, n_ind, uint8_t) \
// Downscale store helper macro for fringe cases: stores only the low 16
// bytes into row m_ind0. Expects temp[2] (__m128i) at the expansion site.
#define CVT_STORE_S16_US8_1ROW(reg, m_ind0, n_ind, C_type) \
reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
/* Extract the first 128 bits of the register*/ \
temp[0] = _mm256_extractf128_si256( reg, 0 ); \
\
_mm_storeu_si128 \
( \
( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind0 ) ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \
temp[0] \
); \
// Downscale store (s16 -> s8) macro for fringe cases; clobbers reg0.
#define CVT_STORE_S16_S8_1ROW(reg0, reg1, m_ind0, n_ind) \
/* Convert the s16 to s8 */ \
reg0 = _mm256_packs_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_1ROW(reg0, m_ind0, n_ind, int8_t) \
// Downscale store (s16 -> u8) macro for fringe cases; clobbers reg0.
#define CVT_STORE_S16_U8_1ROW(reg0, reg1, m_ind0, n_ind) \
/* Convert the s16 to u8 */ \
reg0 = _mm256_packus_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_1ROW(reg0, m_ind0, n_ind, uint8_t) \
// Downscale store helper macro for n < 16 fringe cases: instead of writing
// straight to buf_downscale, the two packed 8-bit rows are spilled to the
// caller-provided scratch buffers buf0/buf1; the *_MEMCP_UTIL macros below
// then copy only the valid 'n' bytes out.
#define CVT_STORE_S16_US8_2ROW_NLT16(reg, buf0, buf1) \
reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
/* Extract the first 128 bits of the register*/ \
temp[0] = _mm256_extractf128_si256( reg, 0 ); \
/* Extract the second 128 bits of the register*/ \
temp[1] = _mm256_extractf128_si256( reg, 1 ); \
\
_mm_storeu_si128( ( __m128i* )buf0, temp[0] ); \
_mm_storeu_si128( ( __m128i* )buf1, temp[1] ); \
// Downscale store (int16 -> s8) macro for n < 16 fringe cases
#define CVT_STORE_S16_S8_2ROW_NLT16(reg0, reg1, buf0, buf1) \
/* Convert the s16 to s8 */ \
reg0 = _mm256_packs_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_2ROW_NLT16(reg0, buf0, buf1) \
// Downscale store (int16 -> u8) macro for n < 16 fringe cases
#define CVT_STORE_S16_U8_2ROW_NLT16(reg0, reg1, buf0, buf1) \
/* Convert the s16 to u8 (unsigned saturation) */ \
reg0 = _mm256_packus_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_2ROW_NLT16(reg0, buf0, buf1) \
// Downscale store helper macro for n < 16 fringe cases (single row spill).
#define CVT_STORE_S16_US8_1ROW_NLT16(reg, buf0) \
reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \
\
/* Extract the first 128 bits of the register*/ \
temp[0] = _mm256_extractf128_si256( reg, 0 ); \
\
_mm_storeu_si128( ( __m128i* )buf0, temp[0] ); \
// Downscale store (s16 -> s8) macro for n < 16 fringe cases
#define CVT_STORE_S16_S8_1ROW_NLT16(reg0, reg1, buf0) \
/* Convert the s16 to s8 */ \
reg0 = _mm256_packs_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_1ROW_NLT16(reg0, buf0) \
// Downscale store (s16 -> u8) macro for n < 16 fringe cases
#define CVT_STORE_S16_U8_1ROW_NLT16(reg0, reg1, buf0) \
/* Convert the s16 to u8 */ \
reg0 = _mm256_packus_epi16( reg0, reg1 ); \
CVT_STORE_S16_US8_1ROW_NLT16(reg0, buf0) \
// Copy 'bytes' valid bytes of the spilled row 'buf_' into row m_ind of the
// downscale output buffer (C_type selects s8/u8 element addressing).
#define CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, C_type) \
memcpy \
( \
( ( C_type* )post_ops_attr.buf_downscale + \
( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \
post_ops_attr.post_op_c_j ), buf_, bytes \
); \
#define CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, int8_t) \
#define CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \
CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, uint8_t) \
//--------------------------------------------------------------------------
/* GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */
// Widens the 16 s16 lanes of 'reg' to two f32 vectors, applies the tanh-based
// GeLU approximation to each half, then packs back to s16 (signed saturation)
// and restores lane order with permute4x64(0xD8). All other arguments are
// scratch registers consumed by GELU_TANH_F32_AVX2_DEF.
#define GELU_TANH_S16_AVX2(reg, y1, y2, r, r2, x, z, dn, x_tanh, q) \
\
y1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 0)) ); \
y2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 1)) ); \
\
GELU_TANH_F32_AVX2_DEF(y1, r, r2, x, z, dn, x_tanh, q); \
\
GELU_TANH_F32_AVX2_DEF(y2, r, r2, x, z, dn, x_tanh, q); \
\
reg = _mm256_packs_epi32(_mm256_cvtps_epi32(y1), _mm256_cvtps_epi32(y2));\
reg = _mm256_permute4x64_epi64(reg, 0XD8);\
/* ERF GeLU (x) = 0.5* x * (1 + erf (x * 0.707107 )) */
// Same widen/process/narrow pattern as GELU_TANH_S16_AVX2, using the
// erf-based GeLU definition.
#define GELU_ERF_S16_AVX2(reg, y1, y2, r, x, x_erf) \
\
y1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 0)) ); \
y2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 1)) ); \
\
GELU_ERF_F32_AVX2_DEF(y1, r, x, x_erf); \
\
GELU_ERF_F32_AVX2_DEF(y2, r, x, x_erf); \
\
reg = _mm256_packs_epi32(_mm256_cvtps_epi32(y1), _mm256_cvtps_epi32(y2));\
reg = _mm256_permute4x64_epi64(reg, 0XD8);\
// Clamp each s16 lane of 'reg' into [min, max].
#define CLIP_S16_AVX2(reg, min, max) \
\
reg = _mm256_min_epi16( _mm256_max_epi16( reg, min ), max ); \
// Matrix Add post-ops helper macros: element-wise add of a second matrix
// into the c_int16_<m>p<n> accumulators. Variants load the addend matrix as
// s8, u8 or s16 (widening to s16 where needed). Expansion scope must provide
// 'matptr', 'ldm', 'post_ops_attr' and the accumulator registers.
#define S16_MATRIX_ADD_1COL(scr0,m_ind) \
c_int16_ ## m_ind ## p0 = _mm256_add_epi16( scr0, c_int16_ ## m_ind ## p0 ); \
#define S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
c_int16_ ## m_ind ## p0 = _mm256_add_epi16( scr0, c_int16_ ## m_ind ## p0 ); \
c_int16_ ## m_ind ## p1 = _mm256_add_epi16( scr1, c_int16_ ## m_ind ## p1 ); \
// Load 16 s8 addend elements and sign-extend to s16.
#define S8_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \
scr = _mm256_cvtepi8_epi16 \
( \
_mm_loadu_si128 \
( \
( __m128i const* ) \
( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \
) \
); \
// Partial-width (n_rem < 16) s8 addend: stage valid bytes through 'buf'
// to avoid reading past the matrix row.
#define S8_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \
memcpy \
( \
( OTYPE* )buf, \
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( 0 * 16 ), \
( n_rem ) * sizeof(OTYPE) \
); \
scr0 = _mm256_cvtepi8_epi16 \
( \
_mm_loadu_si128( ( __m128i const* )buf ) \
); \
S16_MATRIX_ADD_1COL(scr0,m_ind); \
#define S8_S16_MATRIX_ADD_1COL(scr0,m_ind) \
S8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
S16_MATRIX_ADD_1COL(scr0,m_ind); \
#define S8_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
S8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
S8_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \
S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \
// Load 16 u8 addend elements and zero-extend to s16.
#define U8_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \
scr = _mm256_cvtepu8_epi16 \
( \
_mm_loadu_si128 \
( \
( __m128i const* ) \
( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \
) \
); \
// Partial-width (n_rem < 16) u8 addend staged through 'buf'.
#define U8_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \
memcpy \
( \
( OTYPE* )buf, \
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( 0 * 16 ), \
( n_rem ) * sizeof(OTYPE) \
); \
scr0 = _mm256_cvtepu8_epi16 \
( \
_mm_loadu_si128( ( __m128i const* )buf ) \
); \
S16_MATRIX_ADD_1COL(scr0,m_ind); \
#define U8_S16_MATRIX_ADD_1COL(scr0,m_ind) \
U8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
S16_MATRIX_ADD_1COL(scr0,m_ind); \
#define U8_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
U8_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
U8_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \
S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \
// Load 16 s16 addend elements directly (no widening needed).
#define S16_S16_MATRIX_ADD_LOAD(scr,m_ind,n_ind) \
scr = _mm256_loadu_si256 \
( \
(__m256i const *) \
( matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( n_ind * 16 ) ) \
); \
// Partial-width (n_rem < 16) s16 addend staged through 'buf'.
#define S16_S16_MATRIX_ADD_1COL_PAR(buf,scr0,m_ind,n_rem,OTYPE) \
memcpy \
( \
( OTYPE* )buf, \
matptr + ( ( post_ops_attr.post_op_c_i + m_ind ) * ldm ) + \
post_ops_attr.post_op_c_j + ( 0 * 16 ), \
( n_rem ) * sizeof(OTYPE) \
); \
scr0 = _mm256_loadu_si256( ( __m256i const* )buf ); \
S16_MATRIX_ADD_1COL(scr0,m_ind); \
#define S16_S16_MATRIX_ADD_1COL(scr0,m_ind) \
S16_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
S16_MATRIX_ADD_1COL(scr0,m_ind); \
#define S16_S16_MATRIX_ADD_2COL(scr0,scr1,m_ind) \
S16_S16_MATRIX_ADD_LOAD(scr0,m_ind,0); \
S16_S16_MATRIX_ADD_LOAD(scr1,m_ind,1); \
S16_MATRIX_ADD_2COL(scr0,scr1,m_ind); \
// SiLU/Swish utility macro. 'al' must already contain the alpha scale as
// floats; in_reg's 16 s16 lanes are widened to two f32 vectors, passed
// through SWISH_F32_AVX2_DEF, then packed back to s16 (signed saturation)
// with lane order restored via permute4x64(0xD8). Remaining arguments are
// scratch registers.
#define SWISH_S16_AVX2(in_reg, al, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out) \
\
tmp_reg1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
_mm256_extractf128_si256( in_reg, 0 ) ) ); \
tmp_reg2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
_mm256_extractf128_si256( in_reg, 1 ) ) ); \
\
SWISH_F32_AVX2_DEF(tmp_reg1, al, al_in, r, r2, z, dn, ex_out); \
\
SWISH_F32_AVX2_DEF(tmp_reg2, al, al_in, r, r2, z, dn, ex_out); \
\
in_reg = _mm256_packs_epi32(_mm256_cvtps_epi32(tmp_reg1), _mm256_cvtps_epi32(tmp_reg2));\
in_reg = _mm256_permute4x64_epi64(in_reg, 0XD8);\
//TANH utility macro: same widen/apply/narrow pattern using TANHF_AVX2.
#define TANH_S16_AVX2(reg, y1, y2, r, r2, x, z, dn, q) \
\
y1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 0)) ); \
y2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32(_mm256_extractf128_si256(reg, 1)) ); \
\
TANHF_AVX2(y1, r, r2, x, z, dn, q); \
\
TANHF_AVX2(y2, r, r2, x, z, dn, q); \
\
reg = _mm256_packs_epi32(_mm256_cvtps_epi32(y1), _mm256_cvtps_epi32(y2));\
reg = _mm256_permute4x64_epi64(reg, 0XD8);\
// SIGMOID utility macro: same widen/apply/narrow pattern using
// SIGMOID_F32_AVX2_DEF; all non-in_reg arguments are scratch registers.
#define SIGMOID_S16_AVX2(in_reg, al_in, tmp_reg1, tmp_reg2, r, r2, z, dn, ex_out) \
\
tmp_reg1 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
_mm256_extractf128_si256( in_reg, 0 ) ) ); \
tmp_reg2 = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
_mm256_extractf128_si256( in_reg, 1 ) ) ); \
\
SIGMOID_F32_AVX2_DEF(tmp_reg1, al_in, r, r2, z, dn, ex_out); \
\
SIGMOID_F32_AVX2_DEF(tmp_reg2, al_in, r, r2, z, dn, ex_out); \
\
in_reg = _mm256_packs_epi32(_mm256_cvtps_epi32(tmp_reg1), _mm256_cvtps_epi32(tmp_reg2));\
in_reg = _mm256_permute4x64_epi64(in_reg, 0XD8);\
//Zero-out the given YMM accumulator registers (integer zero idiom;
// used to reset dot-product accumulators before each tile).
#define ZERO_ACC_YMM_4_REG(ymm0,ymm1,ymm2,ymm3) \
ymm0 = _mm256_setzero_si256 (); \
ymm1 = _mm256_setzero_si256 (); \
ymm2 = _mm256_setzero_si256 (); \
ymm3 = _mm256_setzero_si256 ();
#endif //LPGEMM_S16_KERN_MACROS_H

View File

@@ -1,815 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <immintrin.h>
#include "blis.h"
#ifdef BLIS_ADDON_LPGEMM
#include "lpgemm_s16_kern_macros.h"
// Load 32 bytes each from two consecutive A rows (paddr, paddr+stride).
#define LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, paddr, stride ) \
ymm0 = _mm256_loadu_si256( (__m256i const *)paddr ); \
ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) );
// Dot-product step for two rows: maddubs multiplies u8 (A) by s8 (B) and
// adds adjacent pairs into s16 lanes (with signed saturation per the
// intrinsic's semantics), which are then accumulated into c_reg*.
#define LPGEMV_N_KERNEL_2_FMA( a_reg1, a_reg2, b_reg, \
inter_reg1, inter_reg2, c_reg1, c_reg2 ) \
inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
c_reg1 = _mm256_add_epi16(inter_reg1, c_reg1); \
inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
c_reg2 = _mm256_add_epi16(inter_reg2, c_reg2);
// Load 32 bytes each from four consecutive A rows.
#define LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, paddr, stride ) \
ymm0 = _mm256_loadu_si256( (__m256i const *)(paddr) ); \
ymm1 = _mm256_loadu_si256( (__m256i const *)(paddr + stride) ); \
ymm2 = _mm256_loadu_si256( (__m256i const *)(paddr + 2 * stride) ); \
ymm3 = _mm256_loadu_si256( (__m256i const *)(paddr + 3 * stride) );
// Dot-product step for four rows (same maddubs+add pattern as the 2-row
// variant, applied to four accumulators).
#define LPGEMV_N_KERNEL_4_FMA( a_reg1, a_reg2, a_reg3, a_reg4, b_reg, \
inter_reg1, inter_reg2, \
inter_reg3, inter_reg4, \
out_reg1, out_reg2, out_reg3, out_reg4 ) \
inter_reg1 = _mm256_maddubs_epi16(a_reg1, b_reg); \
out_reg1 = _mm256_add_epi16(inter_reg1, out_reg1); \
inter_reg2 = _mm256_maddubs_epi16(a_reg2, b_reg); \
out_reg2 = _mm256_add_epi16(inter_reg2, out_reg2); \
inter_reg3 = _mm256_maddubs_epi16(a_reg3, b_reg); \
out_reg3 = _mm256_add_epi16(inter_reg3, out_reg3); \
inter_reg4 = _mm256_maddubs_epi16(a_reg4, b_reg); \
out_reg4 = _mm256_add_epi16(inter_reg4, out_reg4);
// Horizontal reduction: folds four 16-lane s16 accumulators down to one
// xmm holding per-row partial sums (further hadd'ed by the caller).
#define LPGEMV_YMM2XMM( ymm0, ymm1, ymm2, ymm3, xmm0 ) \
ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
ymm1 = _mm256_hadd_epi16( ymm2, ymm3 ); \
ymm0 = _mm256_hadd_epi16( ymm0, ymm1 ); \
xmm0 = _mm_add_epi16( _mm256_extracti128_si256( ymm0, 0 ), \
_mm256_extracti128_si256( ymm0, 1 ) );
// LPGEMV kernel (n == 1 case) for u8 A * s8 B -> s16 C with alpha/beta
// scaling and a chain of fused post-ops dispatched via computed goto.
// The signature is generated by the LPGEMV_N_EQ1_KERN macro (defined
// elsewhere); the body uses the macro-provided parameters m0, k, a, rs_a,
// b, c, rs_c, alpha, beta, MR, post_op and post_op_attr.
// NOTE(review): accumulation is done in s16 via maddubs/add, so large k or
// large inputs can saturate/overflow 16-bit lanes — presumably bounded by
// the framework's k-blocking; confirm against the caller.
LPGEMV_N_EQ1_KERN(uint8_t, int8_t, int16_t, u8s8s16os16)
{
// Jump table indexed by post-op code; order must match the post-op enum.
static void* post_ops_labels[] =
{
&&POST_OPS_DISABLE,
&&POST_OPS_BIAS,
&&POST_OPS_RELU,
&&POST_OPS_RELU_SCALE,
&&POST_OPS_GELU_TANH,
&&POST_OPS_GELU_ERF,
&&POST_OPS_CLIP,
&&POST_OPS_DOWNSCALE,
&&POST_OPS_MATRIX_ADD,
&&POST_OPS_SWISH,
NULL,// Virtual node for matrix_mul, else segfault
&&POST_OPS_TANH,
&&POST_OPS_SIGMOID
};
uint8_t *a_use = NULL;
int8_t *b_use = NULL;
int16_t *c_use = NULL;
lpgemm_post_op_attr post_ops_attr = *(post_op_attr);
// temp buffer to store output C vector
int16_t ctemp[16];
// temp buffers to store a, b data in k_rem case.
// Zero-initialized so the tail-k loads contribute 0 beyond k_rem.
uint8_t buf0[32] = {0};
uint8_t buf1[32] = {0};
uint8_t buf2[32] = {0};
uint8_t buf3[32] = {0};
uint8_t buf4[32] = {0};
uint8_t buf5[32] = {0};
uint8_t buf6[32] = {0};
uint8_t buf7[32] = {0};
int8_t buf8[32] = {0};
// Iterate over C rows in panels of MR (up to 8 rows processed at once).
for ( dim_t ir = 0; ir < m0; ir += MR )
{
dim_t mr0 = bli_min( ( m0 - ir ), MR );
dim_t k_iter = k / 32;
dim_t k_rem = k % 32;
__m256i ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7;
__m256i ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14;
__m256i ymm15;
__m128i xmm0, xmm1;
/* zero the accumulator registers */
ZERO_ACC_YMM_4_REG( ymm8, ymm9, ymm10, ymm11 )
ZERO_ACC_YMM_4_REG( ymm12, ymm13, ymm14, ymm15 )
//update pointers
a_use = (uint8_t*)a + ir * rs_a;
b_use = (int8_t*)b;
c_use = (int16_t*)c + ir * rs_c;
// Full-panel path: all MR rows available.
if( mr0 == MR )
{
for (dim_t k = 0; k < k_iter; k++)
{
ymm6 = _mm256_loadu_si256( (__m256i const *)(b_use) );
b_use += 32;
//Load 4x32 elements from row0-row3 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3, a_use, rs_a )
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
// Load 4x32 elements from row8-row11 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
( a_use + 4 * rs_a ), rs_a
)
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm12, ymm13, ymm14, ymm15
)
a_use += 32;
}
// Tail k: stage the remaining <32 elements through zero-padded buffers.
if( k_rem )
{
uint8_t* restrict a0 = (a_use);
uint8_t* restrict a1 = (a_use + rs_a );
uint8_t* restrict a2 = (a_use + 2 * rs_a );
uint8_t* restrict a3 = (a_use + 3 * rs_a );
uint8_t* restrict a4 = (a_use + 4 * rs_a );
uint8_t* restrict a5 = (a_use + 5 * rs_a );
uint8_t* restrict a6 = (a_use + 6 * rs_a );
uint8_t* restrict a7 = (a_use + 7 * rs_a );
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf1[i] = a1[i];
buf2[i] = a2[i];
buf3[i] = a3[i];
buf4[i] = a4[i];
buf5[i] = a5[i];
buf6[i] = a6[i];
buf7[i] = a7[i];
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
//Load 4x32 elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
// Load 4x32 elements from row8-row11 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf4 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf5 );
ymm2 = _mm256_loadu_si256( (__m256i const *)buf6 );
ymm3 = _mm256_loadu_si256( (__m256i const *)buf7 );
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm12, ymm13, ymm14, ymm15
)
}
//Add the registers horizantally to get one
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm1 )
xmm0 = _mm_hadd_epi16( xmm0, xmm1 );
// post ops are applied on ymm register though
// second half of the register is filled with zeroes.
ymm8 = _mm256_setzero_si256();
ymm8 = _mm256_inserti128_si256( ymm8, xmm0, 0);
}
else
{
// Fringe path: fewer than MR rows left; peel 4, then 2, then 1 rows.
uint8_t *a_use_fringe = a_use;
dim_t mr0_use = mr0;
dim_t regidx = 0;
if( mr0_use >= 4 )
{
for (dim_t k = 0; k < k_iter; k++)
{
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
b_use += 32;
//Load 4x32 elements from row0-row3 of A
LPGEMV_N_KERNEL_4_LOADS( ymm0, ymm1, ymm2, ymm3,
a_use, rs_a )
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
a_use += 32;
}
if( k_rem )
{
uint8_t* restrict a0 = (a_use);
uint8_t* restrict a1 = (a_use + rs_a );
uint8_t* restrict a2 = (a_use + 2 * rs_a );
uint8_t* restrict a3 = (a_use + 3 * rs_a );
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf1[i] = a1[i];
buf2[i] = a2[i];
buf3[i] = a3[i];
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
//Load 4xk_rem elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
ymm2 = _mm256_loadu_si256( (__m256i const *)buf2 );
ymm3 = _mm256_loadu_si256( (__m256i const *)buf3 );
LPGEMV_N_KERNEL_4_FMA( ymm0, ymm1, ymm2, ymm3,
ymm6, ymm4, ymm5, ymm7, ymm4,
ymm8, ymm9, ymm10, ymm11
)
}
//update pointers
mr0_use -= 4;
a_use = a_use_fringe + 4 * rs_a;
a_use_fringe = a_use;
b_use = (int8_t*)b;
//Add the registers horizantally to get one
LPGEMV_YMM2XMM( ymm8, ymm9, ymm10, ymm11, xmm0 )
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
int64_t data = _mm_extract_epi64( xmm0, 0);
//insert xmm outputs into final output reg based on regidx
ymm8 = _mm256_setzero_si256();
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
regidx++;
}
// Dot product for <= 3
if ( mr0_use )
{
// Dot product for m = 2
if ( mr0_use >= 2 )
{
for ( dim_t k = 0; k < k_iter; k++ )
{
// Load 0-31 in b[k+0 - k+31]
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
LPGEMV_N_KERNEL_2_LOADS( ymm0, ymm1, a_use, rs_a);
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6, ymm4,
ymm5, ymm12, ymm13);
b_use += 32; // move b pointer to next 32 elements
a_use += 32;
}
if ( k_rem )
{
uint8_t* restrict a0 = (a_use);
uint8_t* restrict a1 = (a_use + rs_a );
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
buf1[i] = a1[i];
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
//Load 2xk_rem elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm1 = _mm256_loadu_si256( (__m256i const *)buf1 );
LPGEMV_N_KERNEL_2_FMA( ymm0, ymm1, ymm6,
ymm4, ymm5, ymm12, ymm13 );
}
mr0_use -= 2;
a_use = a_use_fringe + 2 * rs_a;
a_use_fringe = a_use;
b_use = (int8_t*)b;
}
// Dot product for m = 1
if ( mr0_use == 1 )
{
for ( dim_t k = 0; k < k_iter; k++ )
{
// Load 0-31 in b[k+0 - k+31]
ymm6 = _mm256_loadu_si256( (__m256i const *)b_use );
// Load 1x32 elements from row0-row1 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)a_use );
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
ymm14 = _mm256_add_epi16(ymm4, ymm14);
b_use += 32; // move b pointer to next 32 elements
a_use += 32;
}
if ( k_rem )
{
uint8_t* restrict a0 = (a_use);
for( dim_t i = 0; i < k_rem; i++)
{
buf8[i] = b_use[i];
buf0[i] = a0[i];
}
ymm6 = _mm256_loadu_si256( (__m256i const *)buf8 );
//Load 1xk_rem elements from row0-row3 of A
ymm0 = _mm256_loadu_si256( (__m256i const *)buf0 );
ymm4 = _mm256_maddubs_epi16(ymm0, ymm6);
ymm14 = _mm256_add_epi16(ymm4, ymm14);
}
// When only fringe 1,
// update the registers to store in order
if ( !( mr0 & 0x2 ) ) ymm12 = ymm14;
}
LPGEMV_YMM2XMM( ymm12, ymm13, ymm14, ymm15, xmm0)
xmm0 = _mm_hadd_epi16( xmm0, xmm0 );
int64_t data = _mm_extract_epi64( xmm0, 0);
//insert xmm outputs into final output reg based on regidx
if( regidx == 0 )
{
ymm8 = _mm256_insert_epi64( ymm8, data, 0 );
}
else
{
ymm8 = _mm256_insert_epi64( ymm8, data, 1 );
}
}
}
// Load alpha and beta
__m256i selector1 = _mm256_set1_epi16(alpha);
__m256i selector2 = _mm256_set1_epi16(beta);
// Scale by alpha
ymm8 = _mm256_mullo_epi16(selector1, ymm8);
// Beta handling: blend in prior C (or prior downscaled C) when beta != 0.
if( beta != 0 )
{
if ( post_ops_attr.buf_downscale != NULL )
{
if( post_ops_attr.rs_c_downscale == 1 )
{
if( post_ops_attr.c_stor_type == S8 )
{
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
S8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes );
S8_S16_BETA_OP_NLT16( ymm8, ctemp,
selector1, selector2 )
}
else if( post_ops_attr.c_stor_type == U8 )
{
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
U8_S16_BETA_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes );
U8_S16_BETA_OP_NLT16( ymm8, ctemp,
selector1, selector2 )
}
}
else
{
// Strided downscale buffer: gather one element per row.
if( post_ops_attr.c_stor_type == S8 )
{
int8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( (int8_t*)post_ops_attr.buf_downscale
+ ( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm256_cvtepi8_epi32
( _mm_loadu_si128( (__m128i const*)ctemp ) );
S16_BETA_FMA( ymm8, selector1, selector2 );
}
else if( post_ops_attr.c_stor_type == U8 )
{
uint8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( (uint8_t*)post_ops_attr.buf_downscale
+ ( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) );
}
selector1 = _mm256_cvtepu8_epi32
( _mm_loadu_si128( (__m128i const*)ctemp ) );
S16_BETA_FMA( ymm8, selector1, selector2 );
}
}
}
else
{
if( rs_c == 1 )
{
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
memcpy( ctemp, c_use, m0_rem_bytes );
S16_S16_BETA_OP_NLT16( ymm8, ctemp,
selector1, selector2 )
}
else
{
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = c_use[ i * rs_c ];
}
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
S16_BETA_FMA( ymm8, selector1, selector2 );
}
}
}
// Post Ops: walk the post-op list via the computed-goto label table.
lpgemm_post_op * post_ops_list_temp = post_op;
post_ops_attr.is_last_k = TRUE;
POST_OP_LABEL_LASTK_SAFE_JUMP
POST_OPS_BIAS:
{
selector1 =
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args1) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU:
{
selector1 = _mm256_setzero_si256();
ymm8 = _mm256_max_epi16( selector1, ymm8 );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_RELU_SCALE:
{
__m256i b0;
selector1 = _mm256_setzero_si256();
selector2 = _mm256_set1_epi16(
*( ( int16_t* )post_ops_list_temp->op_args2 ) );
RELU_SCALE_OP_S16_AVX2( ymm8 )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_TANH:
{
__m256 dn, z, x, r2, r, y1, y2, x_tanh;
__m256i q;
GELU_TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, x_tanh, q )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_GELU_ERF:
{
__m256 x, r, y1, y2, x_erf;
GELU_ERF_S16_AVX2(ymm8, y1, y2, r, x, x_erf)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_CLIP:
{
__m256i min = _mm256_set1_epi16(
*( int16_t* )post_ops_list_temp->op_args2 );
__m256i max = _mm256_set1_epi16(
*( int16_t* )post_ops_list_temp->op_args3 );
CLIP_S16_AVX2(ymm8, min, max)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_DOWNSCALE:
{
// Requantize: scale by float factor and add zero point (sign- or
// zero-extended depending on the output storage type).
__m128i temp[2];
__m256i temp_32[2];
__m256 temp_float[2];
__m256 scale_1 = _mm256_setzero_ps();
__m256 scale_2 = _mm256_setzero_ps();
__m128i _zero_point_0 = _mm_setzero_si128();
__m256i zero_point_0 = _mm256_setzero_si256();
__m256 res_1, res_2;
scale_1 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
scale_2 =
_mm256_set1_ps( *( ( float* )post_ops_list_temp->scale_factor ) );
_zero_point_0 = _mm_set1_epi8(
*( ( int8_t* )post_ops_list_temp->op_args1 ) );
if ( post_ops_attr.c_stor_type == S8 )
{
zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 );
}
else if ( post_ops_attr.c_stor_type == U8 )
{
zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 );
}
// Scale first 16 columns of the 2 rows.
CVT_MULRND_CVT16(ymm8, scale_1, scale_2, zero_point_0)
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_MATRIX_ADD:
{
dim_t ldm = *( dim_t* )post_ops_list_temp->op_args3;
if ( post_ops_attr.c_stor_type == S8 )
{
int8_t* matptr = ( int8_t* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
memcpy
(
( int8_t* )ctemp,
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
( mr0 ) * sizeof(int8_t)
);
selector1 = _mm256_cvtepi8_epi16(
_mm_loadu_si128( ( __m128i const* )ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
else
{
int8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_cvtepi8_epi16
( _mm_loadu_si128( (__m128i const*)ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
}
else if ( post_ops_attr.c_stor_type == U8 )
{
uint8_t* matptr = ( uint8_t* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
memcpy
(
( uint8_t* )ctemp,
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
( mr0 ) * sizeof(uint8_t)
);
selector1 = _mm256_cvtepu8_epi16(
_mm_loadu_si128( ( __m128i const* )ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
else
{
uint8_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_cvtepu8_epi16
( _mm_loadu_si128( (__m128i const*)ctemp ) );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
}
else
{
int16_t* matptr = ( int16_t* )post_ops_list_temp->op_args1;
if( ldm == 1 )
{
memcpy
(
( int16_t* )ctemp,
matptr + ( ( post_ops_attr.post_op_c_i ) * ldm ) +
post_ops_attr.post_op_c_j + ( 0 * 16 ),
( mr0 ) * sizeof(int16_t)
);
selector1 = _mm256_loadu_si256( ( __m256i const* )ctemp );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
else
{
int32_t ctemp[16];
for( dim_t i = 0; i < mr0; i++ )
{
ctemp[i] = *( matptr +
( ( post_ops_attr.post_op_c_i + i )
* ldm ) );
}
selector1 = _mm256_loadu_si256( (__m256i const *)ctemp );
ymm8 = _mm256_add_epi16( selector1, ymm8 );
}
}
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SWISH:
{
selector1 =
_mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) );
__m256 al = _mm256_cvtepi32_ps( _mm256_cvtepi16_epi32( \
_mm256_extractf128_si256( selector1, 0 ) ) );
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
__m256i ex_out;
SWISH_S16_AVX2( ymm8, al, al_in, tmp_reg1,
tmp_reg2, r, r2, z, dn, ex_out );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_TANH:
{
__m256 dn, z, x, r2, r, y1, y2;
__m256i q;
TANH_S16_AVX2( ymm8, y1, y2, r, r2, x, z, dn, q )
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_SIGMOID:
{
__m256 al_in, tmp_reg1, tmp_reg2, r, r2, z, dn;
__m256i ex_out;
SIGMOID_S16_AVX2( ymm8, al_in, tmp_reg1,
tmp_reg2, r, r2, z, dn, ex_out );
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
}
POST_OPS_DISABLE:
{
// Final store: either downscaled (s8/u8) or native s16 C output.
if ( post_ops_attr.buf_downscale != NULL )
{
__m128i temp[2];
__m256i zero_reg = _mm256_setzero_si256();
if( post_ops_attr.rs_c_downscale == 1 )
{
if( post_ops_attr.c_stor_type == S8 )
{
// Store the results in downscaled type
// (int8 instead of int16).
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
dim_t m0_rem_dscale_bytes = mr0 * sizeof( int8_t );
CVT_STORE_S16_S8_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes);
}
else if( post_ops_attr.c_stor_type == U8 )
{
// Store the results in downscaled type (uint8 instead of int16).
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
dim_t m0_rem_dscale_bytes = mr0 * sizeof( uint8_t );
CVT_STORE_S16_U8_NLT16_MEMCP_UTIL( ctemp, 0,
m0_rem_dscale_bytes);
}
}
else
{
// Strided downscale store: scatter one element per row.
if( post_ops_attr.c_stor_type == S8 )
{
int8_t ctemp[16];
CVT_STORE_S16_S8_1ROW_NLT16(ymm8, zero_reg, ctemp);
for( dim_t i = 0; i < mr0; i++ )
{
*( ( int8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
else if( post_ops_attr.c_stor_type == U8 )
{
uint8_t ctemp[16];
CVT_STORE_S16_U8_1ROW_NLT16(ymm8, zero_reg, ctemp);
for( dim_t i = 0; i < mr0; i++ )
{
*( ( uint8_t* )post_ops_attr.buf_downscale +
( post_ops_attr.rs_c_downscale *
( post_ops_attr.post_op_c_i + i ) ) ) = ctemp[i];
}
}
}
}
else
{
if( rs_c == 1 )
{
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
dim_t m0_rem_bytes = mr0 * sizeof( int16_t );
memcpy( c_use, ctemp, m0_rem_bytes );
}
else
{
_mm256_storeu_si256( ( __m256i* )ctemp, ymm8 );
for( dim_t i = 0; i < mr0; i++ )
{
c_use[i * rs_c] = ctemp[i];
}
}
}
// Advance the output-row base for the next MR panel.
post_ops_attr.post_op_c_i += MR;
}
}
}
#endif