Implemented batch_matmul for f32 & int8 datatypes

Details:
- The batch matmul performs a series of matmuls, processing
  more than one GEMM problem at once.
- Introduced a new parameter called batch_size for the user
  to indicate number of GEMM problems in a batch/group.
- This operation supports processing GEMM problems with
  different parameters, including dimensions, post-ops,
  storage schemes, etc.
- This operation is optimized for problems where all the
  GEMMs in a batch are of same size and shape.
- For now, the threads are distributed among different GEMM
  problems equally irrespective of their dimensions which
  leads to better performance for batches with identical GEMMs
  but performs sub-optimally for batches with non-identical GEMMs.
- Optimizations for batches with non-identical GEMMs are in progress.
- Added bench and input files for batch_matmul.
- Added logger functionality for batch_matmul APIs.

AMD-Internal: [SWLCSG-2944]
Change-Id: I83e26c1f30a5dd5a31139f6706ac74be0aa6bd9a
This commit is contained in:
Meghana Vankadari
2025-01-04 04:11:37 +05:30
committed by Nallani Bhaskar
parent ef4286a97e
commit 852cdc6a9a
17 changed files with 2947 additions and 481 deletions

View File

@@ -41,34 +41,22 @@
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
{
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
return; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
if( get_jit_kernels_generated() == FALSE )
{
bli_print_msg(" Could not generate bf16bf16f32of32 "
" kernels using JIT.", __FILE__, __LINE__ );
return;
}
#endif
trans_t blis_transa;
trans_t blis_transb;
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"bf16bf16f32of32", \
order, transa, transb, \
batch_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float* ) beta ), \
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
@@ -88,6 +76,33 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
if( jit_kernels_generated == FALSE )
{
bli_print_msg(" Could not generate bf16bf16f32of32 "
" kernels using JIT.", __FILE__, __LINE__ );
goto err_hndl;
}
#endif
trans_t blis_transa;
trans_t blis_transb;
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
@@ -99,9 +114,13 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i]
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
@@ -136,7 +155,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
return;
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
@@ -182,7 +201,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
return;
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
@@ -221,7 +240,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) return;
if( err != BLIS_SUCCESS ) goto err_hndl;
}
@@ -258,35 +277,25 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
post_op_list, F32
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
{
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
return; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
if( get_jit_kernels_generated() == FALSE )
{
bli_print_msg(" Could not generate bf16bf16f32of32 "
" kernels using JIT.", __FILE__, __LINE__ );
return;
}
#endif
trans_t blis_transa;
trans_t blis_transb;
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"bf16bf16f32obf16", \
order, transa, transb, \
batch_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
( ( float* ) beta ), \
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
@@ -306,6 +315,35 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
if( jit_kernels_generated == FALSE )
{
bli_print_msg(" Could not generate bf16bf16f32of32 "
" kernels using JIT.", __FILE__, __LINE__ );
goto err_hndl;
}
#endif
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
@@ -317,9 +355,15 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i]
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
@@ -354,7 +398,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
return;
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
@@ -400,7 +444,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
return;
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
@@ -439,7 +483,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) return;
if( err != BLIS_SUCCESS ) goto err_hndl;
}
@@ -476,4 +520,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
post_op_list, BF16
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -0,0 +1,446 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
// Batch matmul entry point: bfloat16 A x int4 (int8_t-packed) B -> float C.
// Processes batch_size independent GEMM problems, each with its own
// dimensions, transpose flags, storage scheme, pre-ops and post-ops.
// On any per-problem validation failure the whole batch is abandoned
// via the err_hndl label (logger is still stopped).
AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
{
	LPGEMM_START_LOGGER();
	BATCH_LPGEMM_WRITE_LOGGER \
	(
	  "bf16s4f32of32", \
	  order, transa, transb, \
	  batch_size, m, n, k, \
	  ( ( float* ) alpha ), \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  ( ( float* ) beta ), \
	  ldc, post_op_unparsed \
	);

	// Per-problem strides and memory tags; VLAs sized by batch_size.
	inc_t rs_a[batch_size];
	inc_t cs_a[batch_size];
	inc_t rs_b[batch_size];
	inc_t cs_b[batch_size];
	inc_t rs_c[batch_size];
	inc_t cs_c[batch_size];
	AOCL_MEMORY_TAG mtag_a[batch_size];
	AOCL_MEMORY_TAG mtag_b[batch_size];

	lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
	lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];

	// Check if avx512_bf16 ISA is supported; this lpgemm path requires it.
	if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
	{
		bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
				"cannot perform bf16s4f32 gemm.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

#ifdef LPGEMM_BF16_JIT
	// JIT-generated kernels are a prerequisite for this bf16 build.
	if( jit_kernels_generated == FALSE )
	{
		bli_print_msg(" Could not generate bf16s4f32of32 "
				" kernels using JIT.", __FILE__, __LINE__ );
		goto err_hndl;
	}
#endif

	trans_t blis_transa;
	trans_t blis_transb;

	int err_no = 0;
	for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
	{
		// Check validity of this problem's parameters; sets err_no on failure.
		AOCL_BATCH_GEMM_CHECK
		(
		  "batch_bf16s4f32of32",
		  order[bs_i], transa[bs_i], transb[bs_i],
		  bs_i,
		  m[bs_i], n[bs_i], k[bs_i],
		  a[bs_i], lda[bs_i], mem_format_a[bs_i],
		  b[bs_i], ldb[bs_i], mem_format_b[bs_i],
		  c[bs_i], ldc[bs_i],
		  err_no
		);
		if ( err_no != 0 )
		{
			goto err_hndl;
		}

		/* Map BLAS chars to their corresponding BLIS enumerated type value. */
		bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
		bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );

		bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );

		// Only row-major inputs are supported for the s4 path.
		if( is_column_major == TRUE )
		{
			bli_print_msg("Column major inputs not supported.",
					__FILE__, __LINE__);
			goto err_hndl;
		}
		else // row-major
		{
			// Row-major strides; a transposed matrix flips to unit row stride.
			rs_a[bs_i] = lda[bs_i];
			cs_a[bs_i] = 1;
			if( bli_is_trans( blis_transa ) )
			{
				rs_a[bs_i] = 1;
				cs_a[bs_i] = lda[bs_i];
			}
			rs_b[bs_i] = ldb[bs_i];
			cs_b[bs_i] = 1;
			if( bli_is_trans( blis_transb ) )
			{
				rs_b[bs_i] = 1;
				cs_b[bs_i] = ldb[bs_i];
			}
			bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
			bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );

			// Reorder is not supported for A matrix.
			if( mtag_a[bs_i] == REORDERED )
			{
				bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
				goto err_hndl;
			}
			// From 5-loop function point of view,
			// A matrix when in column major storage needs to be packed to
			// row-major storage as kernel expects A matrix to be in
			// row-major format.
			if( bli_is_trans(blis_transa ) )
			{
				mtag_a[bs_i] = PACK;
			}
		}

		rs_c[bs_i] = ldc[bs_i];
		cs_c[bs_i] = 1;

		// From 5-loop function point of view,
		// B matrix needs to be packed in a certain format in order to be
		// loaded and used in bf16 instructions. As such the mtag_b always
		// needs to be either packed or reordered. B matrix as it is
		// (unpacked) cannot be used, and the mtag_b is set to packed to
		// enable runtime packing.
		if ( mtag_b[bs_i] == UNPACKED )
		{
			mtag_b[bs_i] = PACK;
		}

		// Convert pre op struct to pre op linked list format.
		err_t err = lpgemm_translate_to_pre_ops_list
		(
		  post_op_unparsed[bs_i]->pre_ops,
		  pre_op_list[bs_i],
		  m[bs_i], n[bs_i], k[bs_i]
		);
		if ( err != BLIS_SUCCESS ) goto err_hndl;

		// Convert post op struct to post op linked list format.
		err = lpgemm_translate_to_post_ops_list
		(
		  post_op_unparsed[bs_i], post_op_list[bs_i],
		  ( void* )c[bs_i], ( void* )( (order + bs_i) ),
		  m[bs_i], n[bs_i]
		);
		if( err != BLIS_SUCCESS ) goto err_hndl;
	}

	// Initialize a local runtime with global settings if necessary. Note
	// that in the case that a runtime is passed in, we make a local copy.
	rntm_t rntm_g;
	bli_rntm_init_from_global( &rntm_g );
	bli_pba_rntm_set_pba( &rntm_g );

	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );

#ifdef BLIS_ENABLE_OPENMP
	batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
	(
	  batch_size, m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  pre_op_list, post_op_list, F32
	);
#else
	batch_lpgemm_bf16s4f32of32_thread_decorator
	(
	  batch_size, m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  pre_op_list, post_op_list, F32
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}
// Batch matmul entry point: bfloat16 A x int4 (int8_t-packed) B,
// f32 accumulation, bfloat16 output (downconvert). Shares the
// bf16s4f32of32 thread decorators, selecting BF16 output via the
// trailing type flag; C is passed as float** for the common interface.
AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
{
	LPGEMM_START_LOGGER();
	BATCH_LPGEMM_WRITE_LOGGER \
	(
	  "bf16s4f32obf16", \
	  order, transa, transb, \
	  batch_size, m, n, k, \
	  ( ( float* ) alpha ), \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  ( ( float* ) beta ), \
	  ldc, post_op_unparsed \
	);

	// Per-problem strides and memory tags; VLAs sized by batch_size.
	inc_t rs_a[batch_size];
	inc_t cs_a[batch_size];
	inc_t rs_b[batch_size];
	inc_t cs_b[batch_size];
	inc_t rs_c[batch_size];
	inc_t cs_c[batch_size];
	AOCL_MEMORY_TAG mtag_a[batch_size];
	AOCL_MEMORY_TAG mtag_b[batch_size];

	lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
	lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];

	// Check if avx512_bf16 ISA is supported; this lpgemm path requires it.
	if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
	{
		bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
				"cannot perform bf16s4f32 gemm.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

#ifdef LPGEMM_BF16_JIT
	// JIT-generated kernels are a prerequisite for this bf16 build.
	if( jit_kernels_generated == FALSE )
	{
		bli_print_msg(" Could not generate bf16s4f32obf16 "
				" kernels using JIT.", __FILE__, __LINE__ );
		goto err_hndl;
	}
#endif

	trans_t blis_transa;
	trans_t blis_transb;

	int err_no = 0;
	for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
	{
		// Check validity of this problem's parameters; sets err_no on failure.
		AOCL_BATCH_GEMM_CHECK
		(
		  "batch_bf16s4f32obf16",
		  order[bs_i], transa[bs_i], transb[bs_i],
		  bs_i,
		  m[bs_i], n[bs_i], k[bs_i],
		  a[bs_i], lda[bs_i], mem_format_a[bs_i],
		  b[bs_i], ldb[bs_i], mem_format_b[bs_i],
		  c[bs_i], ldc[bs_i],
		  err_no
		);
		if ( err_no != 0 )
		{
			goto err_hndl;
		}

		/* Map BLAS chars to their corresponding BLIS enumerated type value. */
		bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
		bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );

		bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );

		// Only row-major inputs are supported for the s4 path.
		if( is_column_major == TRUE )
		{
			bli_print_msg("Column major inputs not supported.",
					__FILE__, __LINE__);
			goto err_hndl;
		}
		else // row-major
		{
			// Row-major strides; a transposed matrix flips to unit row stride.
			rs_a[bs_i] = lda[bs_i];
			cs_a[bs_i] = 1;
			if( bli_is_trans( blis_transa ) )
			{
				rs_a[bs_i] = 1;
				cs_a[bs_i] = lda[bs_i];
			}
			rs_b[bs_i] = ldb[bs_i];
			cs_b[bs_i] = 1;
			if( bli_is_trans( blis_transb ) )
			{
				rs_b[bs_i] = 1;
				cs_b[bs_i] = ldb[bs_i];
			}
			bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
			bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );

			// Reorder is not supported for A matrix.
			if( mtag_a[bs_i] == REORDERED )
			{
				bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
				goto err_hndl;
			}
			// From 5-loop function point of view,
			// A matrix when in column major storage needs to be packed to
			// row-major storage as kernel expects A matrix to be in
			// row-major format.
			if( bli_is_trans(blis_transa ) )
			{
				mtag_a[bs_i] = PACK;
			}
		}

		rs_c[bs_i] = ldc[bs_i];
		cs_c[bs_i] = 1;

		// From 5-loop function point of view,
		// B matrix needs to be packed in a certain format in order to be
		// loaded and used in bf16 instructions. As such the mtag_b always
		// needs to be either packed or reordered. B matrix as it is
		// (unpacked) cannot be used, and the mtag_b is set to packed to
		// enable runtime packing.
		if ( mtag_b[bs_i] == UNPACKED )
		{
			mtag_b[bs_i] = PACK;
		}

		// Convert pre op struct to pre op linked list format.
		err_t err = lpgemm_translate_to_pre_ops_list
		(
		  post_op_unparsed[bs_i]->pre_ops,
		  pre_op_list[bs_i],
		  m[bs_i], n[bs_i], k[bs_i]
		);
		if ( err != BLIS_SUCCESS ) goto err_hndl;

		// Convert post op struct to post op linked list format.
		err = lpgemm_translate_to_post_ops_list
		(
		  post_op_unparsed[bs_i], post_op_list[bs_i],
		  ( void* )c[bs_i], ( void* )( (order + bs_i) ),
		  m[bs_i], n[bs_i]
		);
		if( err != BLIS_SUCCESS ) goto err_hndl;
	}

	// Initialize a local runtime with global settings if necessary. Note
	// that in the case that a runtime is passed in, we make a local copy.
	rntm_t rntm_g;
	bli_rntm_init_from_global( &rntm_g );
	bli_pba_rntm_set_pba( &rntm_g );

	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );

#ifdef BLIS_ENABLE_OPENMP
	batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
	(
	  batch_size, m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  (float**)c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  pre_op_list, post_op_list, BF16
	);
#else
	batch_lpgemm_bf16s4f32of32_thread_decorator
	(
	  batch_size, m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  (float**)c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  pre_op_list, post_op_list, BF16
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}

View File

@@ -0,0 +1,280 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
// Batch matmul entry point: float A x float B -> float C.
// Processes batch_size independent GEMM problems, each with its own
// dimensions, transpose flags, storage scheme and post-ops. Column-major
// problems are handled by swapping A/B (and m/n) so the kernel always
// sees row-major operands. Any per-problem failure aborts the whole
// batch via the err_hndl label (logger is still stopped).
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
{
	LPGEMM_START_LOGGER();
	BATCH_LPGEMM_WRITE_LOGGER \
	(
	  "f32f32f32of32", \
	  order, transa, transb, \
	  batch_size, m, n, k, \
	  ( ( float* ) alpha ), \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  ( ( float* ) beta ), \
	  ldc, post_op_unparsed \
	);

	trans_t blis_transa;
	trans_t blis_transb;

	// Per-problem strides and memory tags; VLAs sized by batch_size.
	inc_t rs_a[batch_size];
	inc_t cs_a[batch_size];
	inc_t rs_b[batch_size];
	inc_t cs_b[batch_size];
	inc_t rs_c[batch_size];
	inc_t cs_c[batch_size];
	AOCL_MEMORY_TAG mtag_a[batch_size];
	AOCL_MEMORY_TAG mtag_b[batch_size];

	// Effective (possibly swapped) operand pointers and dims per problem,
	// so that column-major inputs can be presented to the kernel as
	// row-major with A and B exchanged.
	float *a_local[batch_size], *b_local[batch_size];
	dim_t m_local[batch_size], n_local[batch_size];

	// Convert post op struct to post op linked list format.
	lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];

	// Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it.
	if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
	{
		bli_print_msg(" AVX2 ISA not supported by processor, "
				"cannot perform f32f32f32 gemm.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	// check for validity of params.
	int err_no = 0;
	for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
	{
		// check for validity of params; sets err_no on failure.
		AOCL_BATCH_GEMM_CHECK
		(
		  "batch_f32f32f32of32",
		  order[bs_i], transa[bs_i], transb[bs_i],
		  bs_i,
		  m[bs_i], n[bs_i], k[bs_i],
		  a[bs_i], lda[bs_i], mem_format_a[bs_i],
		  b[bs_i], ldb[bs_i], mem_format_b[bs_i],
		  c[bs_i], ldc[bs_i],
		  err_no
		);
		if ( err_no != 0 )
		{
			goto err_hndl;
		}

		/* Map BLAS chars to their corresponding BLIS enumerated type value. */
		bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
		bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );

		bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );

		if( is_column_major == TRUE )
		{
			// Column major: A and B roles are swapped for the kernel, so
			// the "A" strides come from ldb/transb and vice versa.
			rs_a[bs_i] = ldb[bs_i];
			cs_a[bs_i] = 1;
			if( bli_is_trans( blis_transb ) )
			{
				rs_a[bs_i] = 1;
				cs_a[bs_i] = ldb[bs_i];
			}
			rs_b[bs_i] = lda[bs_i];
			cs_b[bs_i] = 1;
			if( bli_is_trans( blis_transa ) )
			{
				rs_b[bs_i] = 1;
				cs_b[bs_i] = lda[bs_i];
			}
			// Memory tags are swapped along with the operands.
			bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
			bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );

			// Inputs swapped in column major, A becomes B from kernel point of view.
			// Reorder is not supported for column major matrices.
			if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
			{
				bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
				goto err_hndl;
			}
			// From 5-loop function point of view,
			// A matrix when in column major storage needs to be packed to row-major
			// storage as kernel expects A matrix to be in row-major format.
			// Inputs swapped in column major, A becomes B from kernel point of view.
			if ( bli_is_trans(blis_transb ) )
			{
				mtag_a[bs_i] = PACK;
			}
			// swap m & n in case of col-major matrices
			m_local[bs_i] = n[bs_i];
			n_local[bs_i] = m[bs_i];
			// swap a & b pointers in case of col-major matrices
			a_local[bs_i] = (float*)(b[bs_i]);
			b_local[bs_i] = (float*)(a[bs_i]);
		}
		else // row-major
		{
			// Row-major strides; a transposed matrix flips to unit row stride.
			rs_a[bs_i] = lda[bs_i];
			cs_a[bs_i] = 1;
			if( bli_is_trans( blis_transa ) )
			{
				rs_a[bs_i] = 1;
				cs_a[bs_i] = lda[bs_i];
			}
			rs_b[bs_i] = ldb[bs_i];
			cs_b[bs_i] = 1;
			if( bli_is_trans( blis_transb ) )
			{
				rs_b[bs_i] = 1;
				cs_b[bs_i] = ldb[bs_i];
			}
			bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
			bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );

			// Reorder is not supported for A matrix
			if( mtag_a[bs_i] == REORDERED )
			{
				bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
				goto err_hndl;
			}
			// From 5-loop function point of view,
			// A matrix when in column major storage needs to be packed to row-major
			// storage as kernel expects A matrix to be in row-major format.
			if( bli_is_trans(blis_transa ) )
			{
				mtag_a[bs_i] = PACK;
			}
			// copy the values of m & n
			m_local[bs_i] = m[bs_i];
			n_local[bs_i] = n[bs_i];
			// copy the values of a & b pointers
			a_local[bs_i] = (float*)(a[bs_i]);
			b_local[bs_i] = (float*)(b[bs_i]);
		}

		rs_c[bs_i] = ldc[bs_i];
		cs_c[bs_i] = 1;

		// By default enable packing for B matrix. Before the 5 loop, based on
		// the input dimensions, the smart threading logic will adjust it
		// (disable/enable) accordingly.
		if ( mtag_b[bs_i] == UNPACKED )
		{
			mtag_b[bs_i] = PACK;
		}

		err_t err = lpgemm_translate_to_post_ops_list
		(
		  post_op_unparsed[bs_i], post_op_list[bs_i],
		  ( void* )c[bs_i], ( void* )( (order + bs_i) ),
		  m[bs_i], n[bs_i]
		);
		if( err != BLIS_SUCCESS ) goto err_hndl;
	}

	// Initialize a local runtime with global settings if necessary. Note
	// that in the case that a runtime is passed in, we make a local copy.
	rntm_t rntm_g;
	bli_rntm_init_from_global( &rntm_g );
	bli_pba_rntm_set_pba( &rntm_g );

	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );

#ifdef BLIS_ENABLE_OPENMP
	batch_lpgemm_f32f32f32of32_openmp_thread_decorator
	(
	  batch_size, m_local, n_local, k,
	  (const float**)a_local, rs_a, cs_a, mtag_a,
	  (const float**)b_local, rs_b, cs_b, mtag_b,
	  c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, F32
	);
#else
	// Setting pack A and B by default for non open mp case.
	bli_rntm_set_pack_a( 1, &rntm_g );
	bli_rntm_set_pack_b( 1, &rntm_g );

	batch_lpgemm_f32f32f32of32_thread_decorator
	(
	  batch_size, m_local, n_local, k,
	  (const float**)a_local, rs_a, cs_a, mtag_a,
	  (const float**)b_local, rs_b, cs_b, mtag_b,
	  c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, F32
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}

View File

@@ -0,0 +1,533 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"s8s8s32os32", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
int8_t *a_local[batch_size];
int8_t *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_s8s8s32os32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ((order[bs_i] == 'c') || (order[bs_i] == 'C'));
if ( is_column_major == TRUE )
{
// Column major support disabled for int API's till micro-kernel
// post-ops are updated to account for column major.
if (post_op_unparsed[bs_i] != NULL )
{
bli_print_msg("Column major inputs not supported with Post-ops.",
__FILE__, __LINE__);
goto err_hndl;
}
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (int8_t*)(b[bs_i]);
b_local[bs_i] = (int8_t*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (int8_t*)(a[bs_i]);
b_local[bs_i] = (int8_t*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instrution. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S32
);
#else
batch_lpgemm_s8s8s32o32_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S32
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batch matmul entry point: batch_size independent GEMMs with int8 A,
// int8 B, int32 accumulation and a final int8 (s8) output (s8s8s32os8).
// Every GEMM in the batch carries its own dims, transposes, storage
// order, leading dims, memory tags and post-ops. Requires AVX512-VNNI.
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"s8s8s32os8", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-GEMM strides and memory tags (VLAs sized by batch_size).
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
// Effective (possibly swapped) operand pointers and dims; column-major
// problems are remapped to row-major by swapping A/B and m/n below.
int8_t *a_local[batch_size];
int8_t *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// Per-GEMM validation and stride/tag setup loop below.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_s8s8s32os8",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
// Column major support disabled for int API's till micro-kernel
// post-ops are updated to account for column major.
if (post_op_unparsed[bs_i] != NULL )
{
bli_print_msg("Column major inputs not supported with Post-ops.",
__FILE__, __LINE__);
goto err_hndl;
}
// Column major is handled by computing C^T = B^T * A^T with the
// row-major kernels: operands are swapped, so kernel-A strides are
// derived from user-B's leading dim (ldb), and vice versa.
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
// Memory tags are swapped along with the operands.
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view, the kernel expects its A matrix
// in row-major format; with operands swapped, a transposed user-B
// (kernel A) must be packed to row-major storage at runtime.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (int8_t*)(b[bs_i]);
b_local[bs_i] = (int8_t*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view, the kernel expects the A
// matrix in row-major format; a transposed A therefore needs to be
// packed to row-major storage at runtime.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (int8_t*)(a[bs_i]);
b_local[bs_i] = (int8_t*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function point of view, the B matrix must be packed
// into the layout loaded by the AVX512-VNNI int8 instructions. Hence
// mtag_b always needs to be either PACK or REORDERED; an UNPACKED B
// cannot be used as-is, so it is set to PACK to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
// NOTE(review): c (s8 buffers) is cast to int32_t** to match the s32
// decorator signature; the trailing S8 tag presumably selects the s8
// downscale of the s32 accumulators — verify against the decorator.
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#else
// Single-threaded fallback; same arguments as the OpenMP path.
batch_lpgemm_s8s8s32o32_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -0,0 +1,411 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "aocl_gemm_interface_apis.h"
#include "aocl_gemm_check.h"
#include "lpgemm_types.h"
#include "lpgemm_post_ops.h"
#include "lpgemm_thread_decor_openmp.h"
#include "lpgemm_5loop_interface_apis.h"
#include "lpgemm_config.h"
#include "lpgemm_utils.h"
#include "lpgemm_logger.h"
// Batch matmul entry point: batch_size independent GEMMs with uint8 A,
// int8 B, int32 accumulation and int32 output (u8s8s32os32). Every GEMM
// in the batch carries its own dims, transposes, leading dims, memory
// tags and post-ops. Row-major only; requires AVX512-VNNI.
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"u8s8s32os32", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-GEMM strides and memory tags (VLAs sized by batch_size).
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
goto err_hndl; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// Per-GEMM validation and stride/tag setup loop below.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_u8s8s32os32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
// Unlike the s8s8 variants, this API rejects column major outright
// (no operand-swap fallback is implemented here).
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view, the kernel expects the A
// matrix in row-major format; a transposed A therefore needs to be
// packed to row-major storage at runtime.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function point of view, the B matrix must be packed
// into the layout loaded by the AVX512-VNNI int8 instructions. Hence
// mtag_b always needs to be either PACK or REORDERED; an UNPACKED B
// cannot be used as-is, so it is set to PACK to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
// S32 tag: output kept in the native int32 accumulation type.
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S32
);
#else
// Single-threaded fallback; same arguments as the OpenMP path.
batch_lpgemm_u8s8s32o32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S32
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batch matmul entry point: batch_size independent GEMMs with uint8 A,
// int8 B, int32 accumulation and a final int8 (s8) output (u8s8s32os8).
// Every GEMM in the batch carries its own dims, transposes, leading dims,
// memory tags and post-ops. Row-major only; requires AVX512-VNNI.
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"u8s8s32os8", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-GEMM strides and memory tags (VLAs sized by batch_size).
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
goto err_hndl; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// Per-GEMM validation and stride/tag setup loop below.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_u8s8s32os8",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
// Unlike the s8s8 variants, this API rejects column major outright
// (no operand-swap fallback is implemented here).
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view, the kernel expects the A
// matrix in row-major format; a transposed A therefore needs to be
// packed to row-major storage at runtime.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function point of view, the B matrix must be packed
// into the layout loaded by the AVX512-VNNI int8 instructions. Hence
// mtag_b always needs to be either PACK or REORDERED; an UNPACKED B
// cannot be used as-is, so it is set to PACK to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
// NOTE(review): c (s8 buffers) is cast to int32_t** to match the s32
// decorator signature; the trailing S8 tag presumably selects the s8
// downscale of the s32 accumulators — verify against the decorator.
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#else
// Single-threaded fallback; same arguments as the OpenMP path.
batch_lpgemm_u8s8s32o32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, S8
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -9,14 +9,14 @@
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
@@ -58,190 +58,191 @@ AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32)
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
trans_t blis_transa;
trans_t blis_transb;
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.",
__FILE__, __LINE__);
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.",
__FILE__, __LINE__);
goto err_hndl;
}
}
/* Initialize BLIS. */
bli_init_auto();
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK(
"bf16s4f32of32",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc, err_no);
AOCL_GEMM_CHECK(
"bf16s4f32of32",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc, err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
bool is_row_major = ((order == 'r') || (order == 'R'));
bool is_column_major = ((order == 'c') || (order == 'C'));
bool is_row_major = ((order == 'r') || (order == 'R'));
bool is_column_major = ((order == 'c') || (order == 'C'));
// The strides are set assuming a row major kernel.
inc_t rs_a = lda;
inc_t cs_a = 1;
// The strides are set assuming a row major kernel.
inc_t rs_a = lda;
inc_t cs_a = 1;
if (bli_is_trans(blis_transa))
{
rs_a = 1;
cs_a = lda;
}
inc_t rs_b = ldb;
inc_t cs_b = 1;
if (bli_is_trans(blis_transa))
{
rs_a = 1;
cs_a = lda;
}
inc_t rs_b = ldb;
inc_t cs_b = 1;
if (bli_is_trans(blis_transb))
{
rs_b = 1;
cs_b = ldb;
}
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
if (bli_is_trans(blis_transb))
{
rs_b = 1;
cs_b = ldb;
}
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Reorder is not supported for A matrix
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
// Reorder is not supported for A matrix
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
goto err_hndl;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
}
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
goto err_hndl;
}
}
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instrution. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
{
mtag_b = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
{
mtag_a = PACK;
}
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instrution. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
{
mtag_b = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
{
mtag_a = PACK;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
{
mtag_a = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
{
mtag_b = PACK;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
{
mtag_a = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
{
mtag_b = PACK;
}
// Convert post op struct to post op linked list format.
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
err_t err = lpgemm_translate_to_pre_ops_list
(
post_op_unparsed->pre_ops,
pre_op_list,
m, n, k
);
if (err != BLIS_SUCCESS)
// Convert pre op struct to pre op linked list format.
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
err_t err = lpgemm_translate_to_pre_ops_list
(
post_op_unparsed->pre_ops,
pre_op_list,
m, n, k
);
if (err != BLIS_SUCCESS)
{
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed,
post_op_list,
(void *)c, (void *)(&order),
m, n
);
if (err != BLIS_SUCCESS)
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed,
post_op_list,
(void *)c, (void *)(&order),
m, n
);
if (err != BLIS_SUCCESS)
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
#ifdef BLIS_ENABLE_OPENMP
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
goto err_hndl;
}
else
{
lpgemm_bf16s4f32of32_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, F32
);
}
}
else
{
lpgemm_bf16s4f32of32_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, F32
);
}
#else
// Swapping inputs to induce row major computation for column major inputs.
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
// Swapping inputs to induce row major computation for column major inputs.
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
goto err_hndl;
}
else
{
lpgemm_bf16s4f32of32_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, F32
);
}
}
else
{
lpgemm_bf16s4f32of32_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, F32
);
}
#endif
err_hndl:;
@@ -263,185 +264,188 @@ AOCL_GEMM_MATMUL(bfloat16, int8_t, bfloat16, float, bf16s4f32obf16)
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
trans_t blis_transa;
trans_t blis_transb;
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.",
__FILE__, __LINE__);
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.",
__FILE__, __LINE__);
goto err_hndl;
}
}
/* Initialize BLIS. */
bli_init_auto();
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// check for validity of params.
// check for validity of params.
int err_no = 0;
AOCL_GEMM_CHECK(
"bf16s4f32obf16",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc, err_no);
AOCL_GEMM_CHECK(
"bf16s4f32obf16",
order, transa, transb,
m, n, k,
a, lda, mem_format_a,
b, ldb, mem_format_b,
c, ldc, err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
bool is_row_major = ((order == 'r') || (order == 'R'));
bool is_column_major = ((order == 'c') || (order == 'C'));
bool is_row_major = ((order == 'r') || (order == 'R'));
bool is_column_major = ((order == 'c') || (order == 'C'));
// The strides are set assuming a row major kernel.
inc_t rs_a = lda;
inc_t cs_a = 1;
// The strides are set assuming a row major kernel.
inc_t rs_a = lda;
inc_t cs_a = 1;
if (bli_is_trans(blis_transa))
{
rs_a = 1;
cs_a = lda;
}
if (bli_is_trans(blis_transa))
{
rs_a = 1;
cs_a = lda;
}
inc_t rs_b = ldb;
inc_t cs_b = 1;
inc_t rs_b = ldb;
inc_t cs_b = 1;
if (bli_is_trans(blis_transb))
{
rs_b = 1;
cs_b = ldb;
}
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
if (bli_is_trans(blis_transb))
{
rs_b = 1;
cs_b = ldb;
}
const inc_t rs_c = ldc;
const inc_t cs_c = 1;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
AOCL_MEMORY_TAG mtag_a;
AOCL_MEMORY_TAG mtag_b;
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
// Reorder is not supported for A matrix
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
// Reorder is not supported for A matrix
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
goto err_hndl;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
}
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
goto err_hndl;
}
}
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instrution. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
{
mtag_b = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
{
mtag_a = PACK;
}
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instrution. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
{
mtag_b = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
{
mtag_a = PACK;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
{
mtag_a = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
{
mtag_b = PACK;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
{
mtag_a = PACK;
}
// Inputs swapped in column major, A becomes B from kernel point of view.
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
{
mtag_b = PACK;
}
// Convert post op struct to post op linked list format.
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
err_t err = lpgemm_translate_to_pre_ops_list(
post_op_unparsed->pre_ops, pre_op_list,
m, n, k);
// Convert post op struct to post op linked list format.
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
err_t err = lpgemm_translate_to_pre_ops_list(
post_op_unparsed->pre_ops, pre_op_list,
m, n, k);
if (err != BLIS_SUCCESS)
if (err != BLIS_SUCCESS)
{
goto err_hndl;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err = lpgemm_translate_to_post_ops_list(
post_op_unparsed, post_op_list,
(void *)c, (void *)(&order),
m, n);
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err = lpgemm_translate_to_post_ops_list(
post_op_unparsed, post_op_list,
(void *)c, (void *)(&order),
m, n);
if (err != BLIS_SUCCESS)
if (err != BLIS_SUCCESS)
{
goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global(&rntm_g);
bli_pba_rntm_set_pba(&rntm_g);
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
#ifdef BLIS_ENABLE_OPENMP
// Swapping inputs to induce row major computation for column major inputs.
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
// Swapping inputs to induce row major computation for column major inputs.
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
goto err_hndl;
}
else
{
lpgemm_bf16s4f32of32_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(float *)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, BF16
);
}
}
else
{
lpgemm_bf16s4f32of32_openmp_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(float *)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, BF16
);
}
#else
// Swapping inputs to induce row major computation for column major inputs.
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
// Swapping inputs to induce row major computation for column major inputs.
if (is_column_major == TRUE)
{
// Swapping inputs not possible in case of mixed precision.
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
goto err_hndl;
}
else
{
lpgemm_bf16s4f32of32_thread_decorator(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(float*)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, BF16);
}
}
else
{
lpgemm_bf16s4f32of32_thread_decorator
(
m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(float*)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g, pre_op_list,
post_op_list, BF16
);
}
#endif
err_hndl:;

View File

@@ -109,7 +109,8 @@
m, n, k, \
a, lda, mtag_a, \
b, ldb, mtag_b, \
c, ldc \
c, ldc, \
err_no \
) \
{ \
int32_t info = 0; \
@@ -170,7 +171,7 @@
\
sprintf( print_msg, "** On entry to %6s, parameter number %2i of problem %ld had an illegal value", op_str, info, gemm_no); \
bli_print_msg(print_msg, __FILE__, __LINE__); \
return; \
err_no = info; \
} \
}

View File

@@ -177,5 +177,14 @@ BLIS_EXPORT_ADDON void aocl_batch_gemm_ ## LP_SFX \
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32);
AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16);
#endif // AOCL_GEMM_INTERFACE_H

View File

@@ -45,6 +45,11 @@ static bli_pthread_once_t once_check_lpgemm_logger_init = BLIS_PTHREAD_ONCE_INIT
static bool lpgemm_logger_enabled = FALSE;
bool is_logger_enabled()
{
return lpgemm_logger_enabled;
}
FILE* lpgemm_start_logger_fn(void)
{
lpgemm_init_logger();
@@ -83,7 +88,7 @@ void lpgemm_stop_logger_fn( FILE* fd )
ops_str_len += c_ops_str_len; \
} while ( 0 ); \
static void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str )
void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str )
{
if ( post_ops == NULL )
{
@@ -148,7 +153,7 @@ static void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str )
}
}
static void lpgemm_get_post_ops_str( aocl_post_op* post_ops, char* ops_str )
void lpgemm_get_post_ops_str( aocl_post_op* post_ops, char* ops_str )
{
if ( ( post_ops == NULL ) || ( post_ops->seq_length <= 0 ) )
{
@@ -308,11 +313,54 @@ void lpgemm_write_logger_gemm_fn
}
}
/* Logs one batch-matmul invocation: a header line carrying the op name and
 * batch size, followed by one line per GEMM in the batch with its layout
 * characters, dimensions, leading dims, pre/post-op summary strings and the
 * alpha/beta scalars. No-op when the logger is disabled or fd is NULL.
 *
 * All array parameters (m, n, k, lda, ldb, ldc, order, transa, transb,
 * mem_format_a/b, alpha, beta, post_op_unparsed) are expected to hold
 * batch_size entries. */
void batch_lpgemm_write_logger_gemm_fn
     (
       FILE* fd,
       char* op_type,
       const char* order,
       const char* transa,
       const char* transb,
       const dim_t batch_size,
       const dim_t* m,
       const dim_t* n,
       const dim_t* k,
       const float* alpha,
       const dim_t* lda,
       const char* mem_format_a,
       const dim_t* ldb,
       const char* mem_format_b,
       const float* beta,
       const dim_t* ldc,
       aocl_post_op** post_op_unparsed
     )
{
	if ( ( lpgemm_logger_enabled == TRUE ) && ( fd != NULL ) )
	{
		char pre_ops_str[1024] = {0};
		char post_ops_str[2048] = {0};

		fprintf( fd, "%s:bs=%ld\n", op_type, batch_size );
		for ( dim_t i = 0; i < batch_size; i++ )
		{
			/* Reset the scratch strings on every iteration so a GEMM whose
			 * formatter writes nothing cannot inherit the previous GEMM's
			 * op string. */
			pre_ops_str[0] = '\0';
			post_ops_str[0] = '\0';

			lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str );
			lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str );
			fprintf( fd, "%c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
					":pre_ops=[%s]:post_ops=[%s] %f %f\n",
					order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i],
					m[i], n[i], k[i], lda[i], ldb[i], ldc[i],
					pre_ops_str, post_ops_str,
					(float)(alpha[i]), (float)(beta[i]) );
		}
	}
}
/* Writes a single "time:<seconds>" separator line to the log file.
 * No-op when the logger is disabled or fd is NULL.
 * The body previously contained both the old untagged fprintf and the new
 * tagged one (a collapsed diff), which would have printed the time twice;
 * only the tagged form is kept. */
void lpgemm_write_logger_time_break_fn( FILE* fd, double stime )
{
	if ( ( lpgemm_logger_enabled == TRUE ) && ( fd != NULL ) )
	{
		fprintf( fd, "time:%f \n", stime );
	}
}

View File

@@ -42,6 +42,9 @@
FILE* lpgemm_start_logger_fn(void);
void lpgemm_stop_logger_fn( FILE* fd );
void lpgemm_get_post_ops_str( aocl_post_op* post_ops, char* ops_str );
void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str );
bool is_logger_enabled();
void lpgemm_write_logger_gemm_fn
(
FILE* fd,
@@ -61,6 +64,26 @@ void lpgemm_write_logger_gemm_fn
const dim_t ldc,
aocl_post_op* post_op_unparsed
);
void batch_lpgemm_write_logger_gemm_fn
(
FILE* fd,
char* op_type,
const char* order,
const char* transa,
const char* transb,
const dim_t batch_size,
const dim_t* m,
const dim_t* n,
const dim_t* k,
const float* alpha,
const dim_t* lda,
const char* mem_format_a,
const dim_t* ldb,
const char* mem_format_b,
const float* beta,
const dim_t* ldc,
aocl_post_op** post_op_unparsed
);
void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
#define LPGEMM_START_LOGGER() \
@@ -81,6 +104,33 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
#define LPGEMM_WRITE_LOGGER(...) \
lpgemm_write_logger_gemm_fn( fd, __VA_ARGS__ ); \
#define BATCH_LPGEMM_WRITE_LOGGER( op_type, order, transa, transb, \
batch_size, m, n, k, \
alpha, lda, mem_format_a, \
ldb, mem_format_b, beta, \
ldc, post_op_unparsed ) \
{ \
if ( ( is_logger_enabled() ) && ( fd != NULL ) ) \
{ \
char pre_ops_str[1024] = {0}; \
\
char post_ops_str[2048] = {0}; \
\
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size); \
for( dim_t i = 0; i < batch_size; i++ ) \
{ \
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str ); \
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str ); \
fprintf( fd, "%c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
":pre_ops=[%s]:post_ops=[%s] %f %f\n", \
order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i], \
m[i], n[i], k[i], lda[i], ldb[i], ldc[i], \
pre_ops_str, post_ops_str, \
(float)(alpha[i]), (float)(beta[i]) ); \
} \
} \
}
#else
#define LPGEMM_START_LOGGER(...)
@@ -89,6 +139,12 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
#define LPGEMM_WRITE_LOGGER(...)
#define BATCH_LPGEMM_WRITE_LOGGER(op_type, order, transa, transb, \
batch_size, m, n, k, \
alpha, lda, mem_format_a, \
ldb, mem_format_b, beta, \
ldc, post_op_unparsed)
#endif
void lpgemm_init_logger();

View File

@@ -72,35 +72,6 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32);
LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32);
LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16);
#define BATCH_LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \
void batch_lpgemm_rowvar_ ## LP_SFX \
( \
const dim_t m, \
const dim_t n, \
const dim_t k, \
const A_type** a, \
const dim_t rs_a, \
const dim_t cs_a, \
const AOCL_MEMORY_TAG mtag_a, \
const B_type** b, \
dim_t rs_b, \
dim_t cs_b, \
AOCL_MEMORY_TAG mtag_b, \
C_type** c, \
const dim_t rs_c, \
const dim_t cs_c, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm, \
lpgemm_thrinfo_t* thread, \
lpgemm_cntx_t* lcntx, \
lpgemm_post_op* post_op_list, \
AOCL_STORAGE_TYPE c_downscale \
) \
BATCH_LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32);
#define LPGEMM_5LOOP1(A_type,B_type,C_type,LP_SFX) \
void lpgemm_rowvar_ ## LP_SFX \
( \

View File

@@ -43,6 +43,23 @@
#define BLIS_LPGEMM_NUM_STATIC_COMMS 96
// Splits the global thread budget across the GEMMs of a batch.
//
// Outputs:
//   n_threads           - total threads available, queried from the runtime.
//   n_gemms_in_parallel - number of GEMMs processed concurrently; forced to
//                         1 in the single-threaded case, otherwise capped at
//                         min(n_threads, batch_size).
//   n_threads_per_gemm  - threads assigned to each concurrent GEMM (integer
//                         division; leftover threads stay idle - see ToDo).
BLIS_INLINE void calculate_n_threads_per_gemm
     (
       dim_t batch_size,
       dim_t* n_threads,
       dim_t* n_gemms_in_parallel,
       dim_t* n_threads_per_gemm,
       rntm_t* rntm_g
     )
{
	// NOTE: the stray line-continuation backslashes left over from a macro
	// conversion have been removed; the logic is unchanged.
	*n_threads = bli_rntm_num_threads( rntm_g );

	// -1 marks "not decided yet"; kept for a future user-override path.
	// Today it is always overwritten by one of the branches below.
	*n_gemms_in_parallel = -1;
	if ( *n_threads == 1 )
	{
		*n_gemms_in_parallel = 1;
	}
	else if ( *n_gemms_in_parallel < 1 )
	{
		*n_gemms_in_parallel = bli_min( *n_threads, batch_size );
	}

	/* ToDo: All the leftover threads might go under-utilized. Could be
	   optimized further. */
	*n_threads_per_gemm = ( *n_threads ) / *n_gemms_in_parallel;
}
BLIS_INLINE dim_t next_factor
(
const dim_t nt,
@@ -437,6 +454,119 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
}
}
// Derives the threading configuration for a batch of s32-accumulating int8
// GEMMs: the total thread count, how many GEMMs run concurrently, the
// threads per GEMM, and the ic (m-dim) x jc (n-dim) factorization of each
// per-GEMM thread group. m/n/k describe one representative GEMM - the batch
// is assumed homogeneous. op_type selects which blocksizes (MR/NR) to query.
BLIS_INLINE void batch_lpgemm_s32o32_get_threading
     (
       dim_t batch_size,
       dim_t* n_threads,
       dim_t* n_gemms_in_parallel,
       dim_t* n_threads_per_gemm,
       dim_t* ic_ways,
       dim_t* jc_ways,
       dim_t m,
       dim_t n,
       dim_t k,
       rntm_t* rntm_g,
       AOCL_OPERATION_TYPE op_type
     )
{
	// First split the global pool into n_gemms_in_parallel groups of
	// n_threads_per_gemm threads each.
	calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );

	if ( ( *n_threads_per_gemm ) > 1 )
	{
		dim_t NR = lpgemm_get_block_size_NR_global_cntx( op_type );
		dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type );

		// Number of MR x NR micro-panel blocks along m and n (rounded up).
		dim_t mr_blks = ( m + MR - 1 ) / MR;
		dim_t nr_blks = ( n + NR - 1 ) / NR;

		if ( n <= NR )
		{
			// At most one micro-panel along n: give all threads to m.
			( *ic_ways ) = ( *n_threads_per_gemm );
			( *jc_ways ) = 1;
			( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
		}
		else if ( m <= MR )
		{
			// At most one micro-panel along m: give all threads to n.
			( *jc_ways ) = ( *n_threads_per_gemm );
			( *ic_ways ) = 1;
			( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
		}
		else
		{
			// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
			bli_thread_partition_2x2( ( *n_threads_per_gemm ), m, n, ic_ways, jc_ways );

			// Rebalance ic/jc only when both dimensions have enough panel
			// blocks to occupy the chosen factorization.
			if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) )
			{
				lpgemm_pnl_wrk_heur_adjust_ic_jc_ways
				(
				  MR, NR, m, n,
				  n_threads_per_gemm, ic_ways, jc_ways
				);
			}
		}
		// The adjustments above may shrink the per-GEMM factorization, so
		// recompute the total thread count actually used.
		( *n_threads ) = ( *n_gemms_in_parallel ) * ( *ic_ways ) * ( *jc_ways );
	}
	else
	{
		// Setting all the values to 1 in case n_threads <= 1. This ensures
		// the threading parameters are valid.
		*n_threads = 1;
		*n_gemms_in_parallel = 1;
		*n_threads_per_gemm = 1;
		*jc_ways = 1;
		*ic_ways = 1;
	}
}
// Threading setup for the u8s8s32os32 batch path. Identical to the generic
// s32o32 heuristic, specialized with the U8S8S32OS32 operation type so the
// matching MR/NR blocksizes are queried from the global context.
BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
     (
       dim_t batch_size,
       dim_t* n_threads,
       dim_t* n_gemms_in_parallel,
       dim_t* n_threads_per_gemm,
       dim_t* ic_ways,
       dim_t* jc_ways,
       dim_t m,
       dim_t n,
       dim_t k,
       rntm_t* rntm_g
     )
{
	// Thin forwarding wrapper; all of the logic lives in the generic routine.
	batch_lpgemm_s32o32_get_threading
	(
	  batch_size,
	  n_threads,
	  n_gemms_in_parallel,
	  n_threads_per_gemm,
	  ic_ways,
	  jc_ways,
	  m, n, k,
	  rntm_g,
	  U8S8S32OS32
	);
}
// Threading setup for the s8s8s32os32 batch path. Identical to the generic
// s32o32 heuristic, specialized with the S8S8S32OS32 operation type so the
// matching MR/NR blocksizes are queried from the global context.
BLIS_INLINE void batch_lpgemm_s8s8s32o32_get_threading
     (
       dim_t batch_size,
       dim_t* n_threads,
       dim_t* n_gemms_in_parallel,
       dim_t* n_threads_per_gemm,
       dim_t* ic_ways,
       dim_t* jc_ways,
       dim_t m,
       dim_t n,
       dim_t k,
       rntm_t* rntm_g
     )
{
	// Thin forwarding wrapper; all of the logic lives in the generic routine.
	batch_lpgemm_s32o32_get_threading
	(
	  batch_size,
	  n_threads,
	  n_gemms_in_parallel,
	  n_threads_per_gemm,
	  ic_ways,
	  jc_ways,
	  m, n, k,
	  rntm_g,
	  S8S8S32OS32
	);
}
BLIS_INLINE void lpgemm_u8s8s32o32_get_threading
(
dim_t* n_threads,
@@ -544,23 +674,6 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
}
}
BLIS_INLINE void calculate_n_threads_per_gemm
(
dim_t batch_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
rntm_t* rntm_g
)
{
*n_threads = bli_rntm_num_threads( rntm_g ); \
*n_gemms_in_parallel = -1; \
if( *n_threads == 1 ) *n_gemms_in_parallel = 1; \
else if( *n_gemms_in_parallel < 1 ) *n_gemms_in_parallel = bli_min(*n_threads, batch_size); \
/* ToDo: All the leftover thrads might go under-utilized. Could be optimized further. */ \
*n_threads_per_gemm = ( *n_threads ) / *n_gemms_in_parallel;
}
BLIS_INLINE void batch_lpgemm_bf16bf16f32of32_get_threading
(
dim_t batch_size,
@@ -723,6 +836,99 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
}
}
}
// Derives the threading configuration for a batch of f32 GEMMs (threads per
// GEMM, number of concurrent GEMMs, and the ic/jc factorization per GEMM),
// then decides whether the runtime should pack A/B based on the SUP
// (skinny/unpacked) thresholds. m/n/k describe one representative GEMM;
// the batch is assumed homogeneous. May set the pack_a/pack_b fields of
// rntm_g as a side effect.
BLIS_INLINE void batch_lpgemm_f32f32f32of32_get_threading
     (
       dim_t batch_size,
       dim_t* n_threads,
       dim_t* n_gemms_in_parallel,
       dim_t* n_threads_per_gemm,
       dim_t* ic_ways,
       dim_t* jc_ways,
       dim_t m,
       dim_t n,
       dim_t k,
       rntm_t* rntm_g
     )
{
	// Split the global pool into per-GEMM thread groups first.
	calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );

	// Query the context for SUP limits.
	const dim_t MT = lpgemm_get_sup_thres_MT_global_cntx( F32F32F32OF32 );
	const dim_t NT = lpgemm_get_sup_thres_NT_global_cntx( F32F32F32OF32 );
	const dim_t KT = lpgemm_get_sup_thres_KT_global_cntx( F32F32F32OF32 );

	// Query the context for various blocksizes.
	dim_t NR = lpgemm_get_block_size_NR_global_cntx( F32F32F32OF32 );
	dim_t MR = lpgemm_get_block_size_MR_global_cntx( F32F32F32OF32 );
	dim_t MC = lpgemm_get_block_size_MC_global_cntx( F32F32F32OF32 );
	dim_t NC = lpgemm_get_block_size_NC_global_cntx( F32F32F32OF32 );
	dim_t KC = lpgemm_get_block_size_KC_global_cntx( F32F32F32OF32 );

	const dim_t MT_2 = MT / 2;

	/* The user is not allowed to set ic_ways or jc_ways */
	if ( ( *n_threads_per_gemm ) > 1 )
	{
		// Number of MR x NR micro-panel blocks along m and n (rounded up).
		dim_t mr_blks = ( m + MR - 1 ) / MR;
		dim_t nr_blks = ( n + NR - 1 ) / NR;

		if ( n <= NR )
		{
			// At most one micro-panel along n: give all threads to m.
			( *ic_ways ) = ( *n_threads_per_gemm );
			( *jc_ways ) = 1;
			( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
		}
		else if ( m <= MR )
		{
			// At most one micro-panel along m: give all threads to n.
			( *jc_ways ) = ( *n_threads_per_gemm );
			( *ic_ways ) = 1;
			( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
		}
		else
		{
			// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
			bli_thread_partition_2x2( ( *n_threads_per_gemm ), m, n, ic_ways, jc_ways );

			// Rebalance ic/jc only when both dimensions have enough panel
			// blocks to occupy the chosen factorization.
			if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) )
			{
				lpgemm_adjust_ic_jc_ways
				(
				  m, n, k,
				  MC, NC, KC, MR, NR,
				  n_threads_per_gemm, ic_ways, jc_ways, 5
				);
			}
		}
		// The adjustments above may shrink the per-GEMM factorization, so
		// recompute the total thread count actually used.
		( *n_threads ) = ( *n_gemms_in_parallel ) * ( *ic_ways ) * ( *jc_ways );
	}
	else
	{
		// Setting all the values to 1 in case n_threads <= 1. This ensures
		// the threading parameters are valid.
		*n_threads = 1;
		*n_gemms_in_parallel = 1;
		*n_threads_per_gemm = 1;
		*jc_ways = 1;
		*ic_ways = 1;
	}

	// Native -> SUP path.
	// Per-thread slab sizes along m and n after the factorization above.
	const dim_t m_ic = m / ( *ic_ways );
	const dim_t n_jc = n / ( *jc_ways );
	const dim_t page_size = bli_info_get_page_size();
	// k extent (in floats) spanning two OS pages - used as a TLB-pressure
	// proxy when deciding whether packing pays off.
	const dim_t page_size_b_floatx2 =
			2 * ( page_size / sizeof( float ) );

	if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) )
	{
		// Enable packing of both A and B when the per-thread problem is
		// large enough, or unconditionally for large k on non-AVX512 parts.
		if (((k >= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
			((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2)))
		{
			bli_rntm_set_pack_b( 1, rntm_g );
			bli_rntm_set_pack_a( 1, rntm_g );
		}
	}
}
#define GEN_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
@@ -944,11 +1150,14 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
} \
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \
void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t m, \
const dim_t m, \
const dim_t n, \
const dim_t k, \
const A_type* a, \
@@ -1048,6 +1257,141 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
GEN_LPGEMM_OPENMP_DECORATOR_MP(bfloat16, int8_t, float, bf16s4f32of32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX, LPGEMM_PARENT_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t batch_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
const A_type** a, \
const dim_t* rs_a, \
const dim_t* cs_a, \
const AOCL_MEMORY_TAG* mtag_a, \
const B_type** b, \
const dim_t* rs_b, \
const dim_t* cs_b, \
AOCL_MEMORY_TAG* mtag_b, \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
AOCL_STORAGE_TYPE c_downscale \
) \
{ \
/* For now, Assuming all the problems in GEMM are of same size.
* To-Do: optimize work distribution for case where a batch contains
GEMM problems of different sizes.
*/ \
dim_t n_threads; \
/* Factorization of threads along m and n dimension respectively.*/ \
dim_t ic_ways; \
dim_t jc_ways; \
dim_t n_gemms_in_parallel; \
dim_t n_threads_per_gemm; \
\
/* Assuming all the problems in GEMM are of same size */ \
batch_lpgemm_ ## LPGEMM_PARENT_SFX ## _get_threading \
( \
batch_size, \
&n_threads, \
&n_gemms_in_parallel, \
&n_threads_per_gemm, \
&ic_ways, &jc_ways, \
m[0], n[0], k[0], rntm_g \
); \
\
/* Set the packing block allocator field of the rntm. This will be
* inherited by all of the child threads when they make local copies of
* the rntm below.*/ \
bli_pba_rntm_set_pba( rntm_g ); \
\
thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \
thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \
err_t bli_errors = BLIS_SUCCESS; \
\
if ( jc_ways * n_gemms_in_parallel > BLIS_LPGEMM_NUM_STATIC_COMMS ) \
{ \
cur_lpgemm_comms = bli_malloc_intl( jc_ways * n_gemms_in_parallel * \
sizeof( thrcomm_t ), &bli_errors ); \
} \
for( dim_t i = 0; i < n_gemms_in_parallel * jc_ways; i++ ) \
{ \
bli_thrcomm_init( ic_ways, &cur_lpgemm_comms[i] ); \
} \
\
dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); \
\
_Pragma( "omp parallel num_threads(n_threads)" ) \
{ \
/* Create a thread-local copy of the master thread's rntm_t. This is
* necessary since we want each thread to be able to track its own
* small block pool_t as it executes down the function stack.*/ \
rntm_t rntm_l = *rntm_g; \
\
/* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects
* for use in blis mt framework inside the respective mat mul driver
* functions.*/ \
lpgemm_thrinfo_t thread; \
thread.n_threads = n_threads_per_gemm; \
thread.tid = omp_get_thread_num() % n_threads_per_gemm; \
thread.ic_ways = ic_ways; \
thread.jc_ways = jc_ways; \
thread.comm = cur_lpgemm_comms + (omp_get_thread_num() / n_threads_per_gemm) * jc_ways; \
\
dim_t gemm_start; \
dim_t gemm_end; \
\
/* This structure is filled only to calculate workload distribution of GEMM problems
among threads. This struct is not passed to 5-loop.
*/ \
thrinfo_t thrinfo; \
thrinfo.n_way = n_gemms_in_parallel; \
thrinfo.work_id = omp_get_thread_num() / n_threads_per_gemm; \
bli_thread_range_sub( &thrinfo, batch_size, 1, FALSE, &gemm_start, &gemm_end ); \
\
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
{ \
/* Decide whether to go with pack-based implementation
or kernel-level implementation */ \
if( ( m[i] / ic_ways ) > MC ) \
{ \
mtag_b[i] = PACK_KC; \
} \
else \
{ \
mtag_b[i] = UNPACKED; \
} \
\
lpgemm_rowvar_ ## LPGEMM_SFX \
( \
m[i], n[i], k[i], \
a[i], rs_a[i], cs_a[i], mtag_a[i], \
b[i], rs_b[i], cs_b[i], mtag_b[i], \
c[i], rs_c[i], cs_c[i],\
alpha[i], \
beta[i], \
&rntm_l, \
&thread, \
lcntx, \
pre_op_list[i], \
post_op_list[i], c_downscale \
); \
} \
} \
if ( jc_ways * n_gemms_in_parallel > BLIS_LPGEMM_NUM_STATIC_COMMS ) \
{ \
bli_free_intl( cur_lpgemm_comms ); \
} \
} \
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(bfloat16, int8_t, float, bf16s4f32of32, bf16bf16f32of32)
BLIS_INLINE void lpgemm_eltwise_ops_bf16of32_get_threading
(
dim_t* n_threads,
@@ -1468,7 +1812,84 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
} \
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
( \
const dim_t batch_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
const A_type** a, \
const dim_t* rs_a, \
const dim_t* cs_a, \
const AOCL_MEMORY_TAG* mtag_a, \
const B_type** b, \
const dim_t* rs_b, \
const dim_t* cs_b, \
const AOCL_MEMORY_TAG* mtag_b, \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
AOCL_STORAGE_TYPE c_downscale \
) \
{ \
dim_t n_threads = 1; \
\
/* Factorization of threads along m and n dimension respectively.*/ \
dim_t ic_ways = 1; \
dim_t jc_ways = 1; \
\
/* Set the packing block allocator field of the rntm. This will be
* inherited by all of the child threads when they make local copies of
* the rntm below.*/ \
bli_pba_rntm_set_pba( rntm_g ); \
\
thrcomm_t static_lpgemm_comm; \
thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \
\
bli_thrcomm_init( ic_ways, cur_lpgemm_comm ); \
/* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects
* for use in blis mt framework inside the respective mat mul driver
* functions.*/ \
lpgemm_thrinfo_t thread; \
thread.n_threads = n_threads; \
thread.tid = 0; \
thread.ic_ways = ic_ways; \
thread.jc_ways = jc_ways; \
thread.comm = cur_lpgemm_comm; \
dim_t gemm_start = 0; \
dim_t gemm_end = batch_size; \
\
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
{ \
lpgemm_rowvar_ ## LPGEMM_SFX \
( \
m[i], n[i], k[i], \
a[i], rs_a[i], cs_a[i], mtag_a[i], \
b[i], rs_b[i], cs_b[i], mtag_b[i], \
c[i], rs_c[i], cs_c[i],\
alpha[i], \
beta[i], \
rntm_g, \
&thread, \
lcntx, \
pre_op_list[i], \
post_op_list[i], c_downscale \
); \
} \
} \
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(bfloat16,int8_t,float,bf16s4f32of32)
#define GEN_UTIL_ELTWISE_OPS_DECORATOR(A_type,B_type,LPGEMM_SFX) \
void lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _thread_decorator \

View File

@@ -101,6 +101,39 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
); \
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN_MXP(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t batch_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
const A_type** a, \
const dim_t* rs_a, \
const dim_t* cs_a, \
const AOCL_MEMORY_TAG* mtag_a, \
const B_type** b, \
const dim_t* rs_b, \
const dim_t* cs_b, \
AOCL_MEMORY_TAG* mtag_b, \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
AOCL_STORAGE_TYPE c_downscale \
); \
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN_MXP(bfloat16,int8_t,float,bf16s4f32of32)
#define GEN_LPGEMM_OPENMP_DECORATOR_FN1(A_type,B_type,C_type,LPGEMM_SFX) \
@@ -116,7 +149,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
const B_type* b, \
const dim_t rs_b, \
const dim_t cs_b, \
const AOCL_MEMORY_TAG mtag_b, \
AOCL_MEMORY_TAG mtag_b, \
C_type* c, \
const dim_t rs_c, \
const dim_t cs_c, \
@@ -213,6 +246,38 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
); \
GEN_BATCH_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
GEN_BATCH_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32)
GEN_BATCH_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
GEN_BATCH_LPGEMM_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_DECORATOR_FN_MP(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
( \
const dim_t bs, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
const A_type** a, \
const dim_t* rs_a, \
const dim_t* cs_a, \
const AOCL_MEMORY_TAG* mtag_a, \
const B_type** b, \
const dim_t* rs_b, \
const dim_t* cs_b, \
const AOCL_MEMORY_TAG* mtag_b, \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
AOCL_STORAGE_TYPE c_downscale \
); \
GEN_BATCH_LPGEMM_DECORATOR_FN_MP(bfloat16,int8_t,float,bf16s4f32of32)
#define GEN_LPGEMM_DECORATOR_FN1(A_type,B_type,C_type,LPGEMM_SFX) \

View File

@@ -1,6 +1,17 @@
bf16bf16f32obf16:bs=5
r t n n p 83 3847 1930 83 3847 3847 scale=vector,zp=scalar,relu,clip
c n t n n 50 1297 707 50 1297 50 scale=vector,zp=vector,bias=na,clip
c n n n n 12 128 1605 12 1605 12 scale=vector,zp=vector
r n t n r 63 95 1319 1319 1319 95 scale=vector,zp=scalar,relu,clip
r n n n r 65 2283 911 911 2283 2283 bias=na,swish
*:bs=5
r t t n n 92 1479 589 92 589 1479 scale=vector,zp=vector,bias=na,clip
r n n n r 67 21 1823 1823 21 21 scale=vector,zp=scalar,relu,clip
r n t n n 43 2240 1553 1553 1553 2240 scale=vector,zp=scalar,relu,clip
r t n n p 143 1943 730 143 1943 1943 bias=na,swish
r n n n r 79 2676 1995 1995 2676 2676 bias=na,swish
bf16s4f32of32:bs=4
r t n n r 43 1110 271 43 1110 1110 scale=vector,zp=vector,bias=na,clip
r t n n r 79 1177 1968 79 1177 1177 scale=vector,zp=scalar,relu,clip
r n t n r 92 2872 1482 1482 1482 2872 scale=vector,zp=vector,bias=na,clip
r n t n r 88 3397 1130 1130 1130 3397 scale=vector,zp=vector
bf16s4f32obf16:bs=5
r n n n r 17 2714 468 468 2714 2714 scale=vector,zp=vector,bias=na,clip
r n n n r 140 3764 1519 1519 3764 3764 scale=vector,zp=vector
r n t n r 17 1758 1034 1034 1034 1758 scale=vector,zp=vector,bias=na,clip
r n n n r 130 1822 1293 1293 1822 1822 scale=vector,zp=vector
r t t n r 21 2771 1882 21 1882 2771 bias=na,swish

View File

@@ -111,6 +111,13 @@ void mat_mul_ ## BLAS_SFX \
GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
double get_gflops
(
@@ -211,6 +218,13 @@ void mat_mul_bench_driver_ ## BLAS_SFX \
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
@@ -1230,6 +1244,13 @@ void mat_mul_bench_main_ ## BLAS_SFX \
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16,bf16bf16f32of32,bf16s4f32of32)
int main( int argc, char** argv )
{
@@ -1428,6 +1449,148 @@ int main( int argc, char** argv )
post_ops_str_dest, FALSE
);
}
if ( ( strcmp( gemm_type_str, "f32f32f32of32" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
for( dim_t i = 0; i < bs; i++ )
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
global_can_dscale = 'y';
global_dscale_out = 'n';
global_pre_op = 'n';
GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, FALSE
);
}
if ( ( strcmp( gemm_type_str, "u8s8s32os32" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
// Copy the original post op str to a temp string buffer.
// Done so that strtok can be applied on the same (strtok
// is a destructive parser.
for( dim_t i = 0; i < bs; i++ )
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, FALSE
);
}
if ( ( strcmp( gemm_type_str, "u8s8s32os8" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
// Copy the original post op str to a temp string buffer.
// Done so that strtok can be applied on the same (strtok
// is a destructive parser.
for( dim_t i = 0; i < bs; i++ )
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
global_dscale_out = 'y';
global_pre_op = 'n';
DSCALE_CLIP_MIN = -128;
DSCALE_CLIP_MAX = +127;
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, FALSE
);
}
if ( ( strcmp( gemm_type_str, "s8s8s32os32" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
// Copy the original post op str to a temp string buffer.
// Done so that strtok can be applied on the same (strtok
// is a destructive parser.
for( dim_t i = 0; i < bs; i++ )
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
global_dscale_out = 'n';
global_pre_op = 'n';
DSCALE_CLIP_MIN = INT_MIN;
DSCALE_CLIP_MAX = INT_MAX;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, FALSE
);
}
if ( ( strcmp( gemm_type_str, "s8s8s32os8" ) == 0 ) ||
( strcmp( gemm_type_str, "*" ) == 0 ) )
{
// Copy the original post op str to a temp string buffer.
// Done so that strtok can be applied on the same (strtok
// is a destructive parser.
for( dim_t i = 0; i < bs; i++ )
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
global_dscale_out = 'y';
global_pre_op = 'n';
DSCALE_CLIP_MIN = -128;
DSCALE_CLIP_MAX = +127;
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, FALSE
);
}
if ( strcmp( gemm_type_str, "bf16s4f32of32" ) == 0 )
{
// Copy the original post op str to a temp string buffer.
// Done so that strtok can be applied on the same (strtok
// is a destructive parser.
for( dim_t i = 0; i < bs; i++ )
{
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
if ( ( op_b[i] != 'r' ) && ( op_b[i] != 'R' ) )
{
printf("Int4 B matrix only permitted if B reodering "
"is enabled.\n");
goto skip_exec;
}
}
global_dscale_out = 'n';
global_pre_op = 'y';
GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32of32)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, TRUE
);
}
if ( strcmp( gemm_type_str, "bf16s4f32obf16" ) == 0 )
{
// Copy the original post op str to a temp string buffer.
// Done so that strtok can be applied on the same (strtok
// is a destructive parser.
for( dim_t i = 0; i < bs; i++ )
{
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
if ( ( op_b[i] != 'r' ) && ( op_b[i] != 'R' ) )
{
printf("Int4 B matrix only permitted if B reodering "
"is enabled.\n");
goto skip_exec;
}
}
global_dscale_out = 'y';
global_pre_op = 'y';
GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32obf16)
(
fin, fout, stor_order, transa, transb, op_a, op_b,
bs, m, n, k, stride_a, stride_b, stride_c,
post_ops_str_dest, TRUE
);
}
skip_exec:
}
}

View File

@@ -6671,7 +6671,7 @@ POST_OPS_DOWNSCALE_5x2F:
post_ops_attr.post_op_c_i + 2 ) );
zero_point3 = _mm_set1_ps( *( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 3 ) );
zero_point0 = _mm_set1_ps( *( (float* )post_ops_list_temp->op_args1 +
zero_point4 = _mm_set1_ps( *( (float* )post_ops_list_temp->op_args1 +
post_ops_attr.post_op_c_i + 4 ) );
}
//c[0, 0-3]

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -141,7 +141,7 @@ void packb_nr64_s8s8s32os32
else
{
packb_nr64_s8s8s32os32_col_major(pack_b_buffer_s8s8s32o32,
pack_b_column_sum, b,
pack_b_column_sum, b,
cs_b, NC, KC, rs_p, cs_p);
}
}
@@ -198,7 +198,7 @@ void packb_nr64_s8s8s32os32_row_major
{
//load the temp buffer to compute column sum of B matrix
sum1 = _mm512_loadu_si512( pack_b_column_sum + jc );
sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 + jc );
sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 + jc );
//offset 16- as 16 int32 elements fit in 1 zmm register
sum3 = _mm512_loadu_si512( pack_b_column_sum + 32 + jc );
sum4 = _mm512_loadu_si512( pack_b_column_sum + 48 + jc );
@@ -212,28 +212,28 @@ void packb_nr64_s8s8s32os32_row_major
d0 = _mm512_loadu_si512( b + ( ldb * ( kr + 3 ) ) + jc );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)),
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 0))))) , mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)),
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 1))))) , mul_128));
sum3 =
sum3 =
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)),
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 2))))) , mul_128));
sum4 =
sum4 =
_mm512_add_epi32 ( sum4, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)),
@@ -262,13 +262,13 @@ void packb_nr64_s8s8s32os32_row_major
a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1]
c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3]
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 );
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 );
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 );
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 );
}
// Handle k remainder.
@@ -282,25 +282,25 @@ void packb_nr64_s8s8s32os32_row_major
d0 = _mm512_setzero_si512();
//add all the columns : sum = add (sum, a0, b0, c0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)))), mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)))), mul_128));
sum3 =
sum3 =
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)))), mul_128));
sum4 =
sum4 =
_mm512_add_epi32 ( sum4, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)),
@@ -315,22 +315,22 @@ void packb_nr64_s8s8s32os32_row_major
d0 = _mm512_setzero_si512();
//add all the columns : sum = add (sum, a0, b0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0))), mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1))), mul_128));
sum3 =
sum3 =
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2))), mul_128));
sum4 =
sum4 =
_mm512_add_epi32 ( sum4, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)),
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3))), mul_128));
@@ -378,13 +378,13 @@ void packb_nr64_s8s8s32os32_row_major
a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1]
c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3]
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 );
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 );
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 );
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 );
}
//store the sum column
@@ -506,14 +506,14 @@ void packb_nr48_s8s8s32os32_row_major
d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)),
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)),
@@ -553,11 +553,11 @@ void packb_nr48_s8s8s32os32_row_major
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) + ( 32 ) );
//add all the columns : sum = add (sum, a0_32, b0_32, c0_32, d0_32)
sum3 =
_mm512_add_epi32
( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
sum3 =
_mm512_add_epi32
( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
_mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )
);
@@ -595,7 +595,7 @@ void packb_nr48_s8s8s32os32_row_major
d0_32 = _mm256_setzero_si256();
//add all the columns : sum = add (sum, a0, b0, c0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
@@ -612,10 +612,10 @@ void packb_nr48_s8s8s32os32_row_major
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) );
d0_16 = _mm_setzero_si128();
sum3 =
_mm512_add_epi32
sum3 =
_mm512_add_epi32
( sum3, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_cvtepi8_epi32( c0_16 ))) , mul_128)
);
}
@@ -627,12 +627,12 @@ void packb_nr48_s8s8s32os32_row_major
d0_32 = _mm256_setzero_si256();
//add all the columns : sum = add (sum, a0, b0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
_mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 0) )) , mul_128 ));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
_mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 1) )) , mul_128 ));
@@ -643,7 +643,7 @@ void packb_nr48_s8s8s32os32_row_major
d0_16 = _mm_setzero_si128();
sum3 =
_mm512_add_epi32
_mm512_add_epi32
( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_cvtepi8_epi32( b0_16 )) , mul_128)
);
@@ -656,11 +656,11 @@ void packb_nr48_s8s8s32os32_row_major
d0_32 = _mm256_setzero_si256();
//add all the columns : sum = add (sum, a0, b0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)) , mul_128));
@@ -669,7 +669,7 @@ void packb_nr48_s8s8s32os32_row_major
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
sum3 =
sum3 =
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
_mm512_cvtepi8_epi32( a0_16 ) , mul_128));
}
@@ -767,14 +767,14 @@ void packb_nr32_s8s8s32os32_row_major
d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)),
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)),
@@ -822,13 +822,13 @@ void packb_nr32_s8s8s32os32_row_major
d0_32 = _mm256_setzero_si256();
//add all the columns : sum = add (sum, a0, b0, c0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)))) , mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)),
@@ -843,12 +843,12 @@ void packb_nr32_s8s8s32os32_row_major
d0_32 = _mm256_setzero_si256();
//add all the columns : sum = add (sum, a0, b0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0))) , mul_128));
sum2 =
sum2 =
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1))) , mul_128));
@@ -861,7 +861,7 @@ void packb_nr32_s8s8s32os32_row_major
d0_32 = _mm256_setzero_si256();
//add all the columns : sum = add (sum, a0, b0)
sum1 =
sum1 =
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128));
@@ -941,10 +941,10 @@ void packb_nr16_s8s8s32os32_row_major
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
_mm512_add_epi32
sum1 =
_mm512_add_epi32
( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
_mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )
);
@@ -983,9 +983,9 @@ void packb_nr16_s8s8s32os32_row_major
d0_16 = _mm_setzero_si128();
sum1 =
_mm512_add_epi32
_mm512_add_epi32
( sum1, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_cvtepi8_epi32( c0_16 ))) , mul_128)
);
}
@@ -996,8 +996,8 @@ void packb_nr16_s8s8s32os32_row_major
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
sum1 =
_mm512_add_epi32
sum1 =
_mm512_add_epi32
( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_cvtepi8_epi32( b0_16 )) , mul_128)
);
@@ -1009,9 +1009,9 @@ void packb_nr16_s8s8s32os32_row_major
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
sum1 =
_mm512_add_epi32
( sum1,
sum1 =
_mm512_add_epi32
( sum1,
_mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 )
);
}
@@ -1090,11 +1090,11 @@ void packb_nrlt16_s8s8s32os32_row_major
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
_mm512_add_epi32
sum1 =
_mm512_add_epi32
( sum1,
_mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
_mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )
);
@@ -1118,7 +1118,7 @@ void packb_nrlt16_s8s8s32os32_row_major
// Last 4x16 elements.
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm );
// The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the
// The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the
// original data, but is here due to the packing in 4 16byte chunks format.
kr_new += 1;
}
@@ -1138,10 +1138,10 @@ void packb_nrlt16_s8s8s32os32_row_major
d0_16 = _mm_setzero_si128();
sum1 =
_mm512_add_epi32
( sum1,
_mm512_add_epi32
( sum1,
_mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
_mm512_cvtepi8_epi32( c0_16 ))) , mul_128)
);
@@ -1156,8 +1156,8 @@ void packb_nrlt16_s8s8s32os32_row_major
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
sum1 =
_mm512_add_epi32
sum1 =
_mm512_add_epi32
( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
_mm512_cvtepi8_epi32( b0_16 )) , mul_128)
);
@@ -1171,7 +1171,7 @@ void packb_nrlt16_s8s8s32os32_row_major
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
sum1 =
sum1 =
_mm512_add_epi32
( sum1,
_mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 )
@@ -1399,7 +1399,7 @@ void packb_nr64_s8s8s32os32_col_major
}
*rs_p = NR * 4;
*cs_p = NR / 4;
*cs_p = NR;
}
//Extract 16 8-bit elements from each 128-bit lane of 512-bit register and convert them into
@@ -1891,7 +1891,7 @@ void packb_nrlt16_s8s8s32o32_col_major
}
// sum/reduce < 16 (max 15) int32 values into one final sum as int.
// insert sum of all columns into one 512 bit, multiply with 128 and
// insert sum of all columns into one 512 bit, multiply with 128 and
// store into pack_b_column_sum
__m512i sum0, sum1;
sum0 = _mm512_set_epi32