mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Implemented batch_matmul for f32 & int8 datatypes
Details: - The batch matmul performs a series of matmuls, processing more than one GEMM problem at once. - Introduced a new parameter called batch_size for the user to indicate number of GEMM problems in a batch/group. - This operation supports processing GEMM problems with different parameters including dims,post-ops,stor-schemes etc., - This operation is optimized for problems where all the GEMMs in a batch are of same size and shape. - For now, the threads are distributed among different GEMM problems equally irrespective of their dimensions which leads to better performance for batches with identical GEMMs but performs sub-optimally for batches with non-identical GEMMs. - Optimizations for batches with non-identical GEMMs is in progress. - Added bench and input files for batch_matmul. - Added logger functionality for batch_matmul APIs. AMD-Internal: [SWLCSG-2944] Change-Id: I83e26c1f30a5dd5a31139f6706ac74be0aa6bd9a
This commit is contained in:
committed by
Nallani Bhaskar
parent
ef4286a97e
commit
852cdc6a9a
@@ -41,34 +41,22 @@
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
{
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
return; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( get_jit_kernels_generated() == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"bf16bf16f32of32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float* ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
@@ -88,6 +76,33 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( jit_kernels_generated == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
int err_no = 0;
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
@@ -99,9 +114,13 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i]
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
@@ -136,7 +155,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
return;
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
@@ -182,7 +201,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
return;
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
@@ -221,7 +240,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) return;
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
@@ -258,35 +277,25 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
{
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
return; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( get_jit_kernels_generated() == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"bf16bf16f32obf16", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float* ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
@@ -306,6 +315,35 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( jit_kernels_generated == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
@@ -317,9 +355,15 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i]
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
@@ -354,7 +398,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
return;
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
@@ -400,7 +444,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
return;
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
@@ -439,7 +483,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) return;
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
@@ -476,4 +520,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
446
addon/aocl_gemm/aocl_batch_gemm_bf16s4f32of32.c
Normal file
446
addon/aocl_gemm/aocl_batch_gemm_bf16s4f32of32.c
Normal file
@@ -0,0 +1,446 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"bf16s4f32of32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float* ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16s4f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( jit_kernels_generated == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16s4f32of32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// Convert pre op struct to pre op linked list format.
|
||||
err_t err = lpgemm_translate_to_pre_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i]->pre_ops,
|
||||
pre_op_list[bs_i],
|
||||
m[bs_i], n[bs_i], k[bs_i]
|
||||
);
|
||||
if (err != BLIS_SUCCESS) goto err_hndl;
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"bf16s4f32obf16", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float* ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( jit_kernels_generated == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16s4f32obf16",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// Convert pre op struct to pre op linked list format.
|
||||
err_t err = lpgemm_translate_to_pre_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i]->pre_ops,
|
||||
pre_op_list[bs_i],
|
||||
m[bs_i], n[bs_i], k[bs_i]
|
||||
);
|
||||
if (err != BLIS_SUCCESS) goto err_hndl;
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, BF16
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
280
addon/aocl_gemm/aocl_batch_gemm_f32f32f32of32.c
Normal file
280
addon/aocl_gemm/aocl_batch_gemm_f32f32f32of32.c
Normal file
@@ -0,0 +1,280 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"f32f32f32of32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
( ( float* ) beta ), \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
float *a_local[batch_size], *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform f32f32f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_f32f32f32of32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (float*)(b[bs_i]);
|
||||
b_local[bs_i] = (float*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (float*)(a[bs_i]);
|
||||
b_local[bs_i] = (float*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// By default enable packing for B matrix. Before the 5 loop, based on
|
||||
// the input dimensions, the smart threading logic will adjust it
|
||||
// (disable/enable) accordingly.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_f32f32f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const float**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const float**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
// Setting pack A and B by default for non open mp case.
|
||||
bli_rntm_set_pack_a( 1, &rntm_g );
|
||||
bli_rntm_set_pack_b( 1, &rntm_g );
|
||||
|
||||
batch_lpgemm_f32f32f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const float**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const float**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
533
addon/aocl_gemm/aocl_batch_gemm_s8s8s32os32.c
Normal file
533
addon/aocl_gemm/aocl_batch_gemm_s8s8s32os32.c
Normal file
@@ -0,0 +1,533 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s32os32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
int8_t *a_local[batch_size];
|
||||
int8_t *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_s8s8s32os32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ((order[bs_i] == 'c') || (order[bs_i] == 'C'));
|
||||
|
||||
if ( is_column_major == TRUE )
|
||||
{
|
||||
// Column major support disabled for int API's till micro-kernel
|
||||
// post-ops are updated to account for column major.
|
||||
if (post_op_unparsed[bs_i] != NULL )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported with Post-ops.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_s8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s32os8", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
int8_t *a_local[batch_size];
|
||||
int8_t *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_s8s8s32os8",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
// Column major support disabled for int API's till micro-kernel
|
||||
// post-ops are updated to account for column major.
|
||||
if (post_op_unparsed[bs_i] != NULL )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported with Post-ops.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S8
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_s8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S8
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
411
addon/aocl_gemm/aocl_batch_gemm_u8s8s32os32.c
Normal file
411
addon/aocl_gemm/aocl_batch_gemm_u8s8s32os32.c
Normal file
@@ -0,0 +1,411 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
#include "aocl_gemm_interface_apis.h"
|
||||
#include "aocl_gemm_check.h"
|
||||
#include "lpgemm_types.h"
|
||||
#include "lpgemm_post_ops.h"
|
||||
#include "lpgemm_thread_decor_openmp.h"
|
||||
#include "lpgemm_5loop_interface_apis.h"
|
||||
#include "lpgemm_config.h"
|
||||
#include "lpgemm_utils.h"
|
||||
#include "lpgemm_logger.h"
|
||||
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"u8s8s32os32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_u8s8s32os32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_u8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"u8s8s32os8", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform u8s8s32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_u8s8s32os8",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S8
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_u8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, S8
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
@@ -9,14 +9,14 @@
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
@@ -58,190 +58,191 @@ AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.",
|
||||
__FILE__, __LINE__);
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
AOCL_GEMM_CHECK(
|
||||
"bf16s4f32of32",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc, err_no);
|
||||
AOCL_GEMM_CHECK(
|
||||
"bf16s4f32of32",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc, err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
|
||||
bool is_row_major = ((order == 'r') || (order == 'R'));
|
||||
bool is_column_major = ((order == 'c') || (order == 'C'));
|
||||
bool is_row_major = ((order == 'r') || (order == 'R'));
|
||||
bool is_column_major = ((order == 'c') || (order == 'C'));
|
||||
|
||||
// The strides are set assuming a row major kernel.
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
// The strides are set assuming a row major kernel.
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
|
||||
if (bli_is_trans(blis_transa))
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
}
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
if (bli_is_trans(blis_transa))
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
}
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
|
||||
if (bli_is_trans(blis_transb))
|
||||
{
|
||||
rs_b = 1;
|
||||
cs_b = ldb;
|
||||
}
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
if (bli_is_trans(blis_transb))
|
||||
{
|
||||
rs_b = 1;
|
||||
cs_b = ldb;
|
||||
}
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
|
||||
// Reorder is not supported for A matrix
|
||||
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
}
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
|
||||
err_t err = lpgemm_translate_to_pre_ops_list
|
||||
(
|
||||
post_op_unparsed->pre_ops,
|
||||
pre_op_list,
|
||||
m, n, k
|
||||
);
|
||||
if (err != BLIS_SUCCESS)
|
||||
// Convert pre op struct to pre op linked list format.
|
||||
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
|
||||
err_t err = lpgemm_translate_to_pre_ops_list
|
||||
(
|
||||
post_op_unparsed->pre_ops,
|
||||
pre_op_list,
|
||||
m, n, k
|
||||
);
|
||||
if (err != BLIS_SUCCESS)
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed,
|
||||
post_op_list,
|
||||
(void *)c, (void *)(&order),
|
||||
m, n
|
||||
);
|
||||
if (err != BLIS_SUCCESS)
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed,
|
||||
post_op_list,
|
||||
(void *)c, (void *)(&order),
|
||||
m, n
|
||||
);
|
||||
if (err != BLIS_SUCCESS)
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
|
||||
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
|
||||
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
#else
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, F32
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
@@ -263,185 +264,188 @@ AOCL_GEMM_MATMUL(bfloat16, int8_t, bfloat16, float, bf16s4f32obf16)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.",
|
||||
__FILE__, __LINE__);
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if (bli_cpuid_is_avx512bf16_supported() == FALSE)
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// check for validity of params.
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
AOCL_GEMM_CHECK(
|
||||
"bf16s4f32obf16",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc, err_no);
|
||||
AOCL_GEMM_CHECK(
|
||||
"bf16s4f32obf16",
|
||||
order, transa, transb,
|
||||
m, n, k,
|
||||
a, lda, mem_format_a,
|
||||
b, ldb, mem_format_b,
|
||||
c, ldc, err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(transb, &blis_transb);
|
||||
|
||||
bool is_row_major = ((order == 'r') || (order == 'R'));
|
||||
bool is_column_major = ((order == 'c') || (order == 'C'));
|
||||
bool is_row_major = ((order == 'r') || (order == 'R'));
|
||||
bool is_column_major = ((order == 'c') || (order == 'C'));
|
||||
|
||||
// The strides are set assuming a row major kernel.
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
// The strides are set assuming a row major kernel.
|
||||
inc_t rs_a = lda;
|
||||
inc_t cs_a = 1;
|
||||
|
||||
if (bli_is_trans(blis_transa))
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
}
|
||||
if (bli_is_trans(blis_transa))
|
||||
{
|
||||
rs_a = 1;
|
||||
cs_a = lda;
|
||||
}
|
||||
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
inc_t rs_b = ldb;
|
||||
inc_t cs_b = 1;
|
||||
|
||||
if (bli_is_trans(blis_transb))
|
||||
{
|
||||
rs_b = 1;
|
||||
cs_b = ldb;
|
||||
}
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
if (bli_is_trans(blis_transb))
|
||||
{
|
||||
rs_b = 1;
|
||||
cs_b = ldb;
|
||||
}
|
||||
const inc_t rs_c = ldc;
|
||||
const inc_t cs_c = 1;
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
AOCL_MEMORY_TAG mtag_a;
|
||||
AOCL_MEMORY_TAG mtag_b;
|
||||
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
|
||||
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
|
||||
// Reorder is not supported for A matrix
|
||||
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
else if ((is_column_major == TRUE) && ((mtag_b == REORDERED) || (mtag_a == REORDERED)))
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
}
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
|
||||
{
|
||||
mtag_a = PACK;
|
||||
}
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
|
||||
{
|
||||
mtag_b = PACK;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
|
||||
err_t err = lpgemm_translate_to_pre_ops_list(
|
||||
post_op_unparsed->pre_ops, pre_op_list,
|
||||
m, n, k);
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
|
||||
err_t err = lpgemm_translate_to_pre_ops_list(
|
||||
post_op_unparsed->pre_ops, pre_op_list,
|
||||
m, n, k);
|
||||
|
||||
if (err != BLIS_SUCCESS)
|
||||
if (err != BLIS_SUCCESS)
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err = lpgemm_translate_to_post_ops_list(
|
||||
post_op_unparsed, post_op_list,
|
||||
(void *)c, (void *)(&order),
|
||||
m, n);
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
err = lpgemm_translate_to_post_ops_list(
|
||||
post_op_unparsed, post_op_list,
|
||||
(void *)c, (void *)(&order),
|
||||
m, n);
|
||||
|
||||
if (err != BLIS_SUCCESS)
|
||||
if (err != BLIS_SUCCESS)
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global(&rntm_g);
|
||||
bli_pba_rntm_set_pba(&rntm_g);
|
||||
|
||||
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
|
||||
lpgemm_cntx_t *lcntx_g = lpgemm_get_global_cntx_obj(BF16S4F32OF32);
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float *)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, BF16
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float *)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, BF16
|
||||
);
|
||||
}
|
||||
#else
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
// Swapping inputs to induce row major computation for column major inputs.
|
||||
if (is_column_major == TRUE)
|
||||
{
|
||||
// Swapping inputs not possible in case of mixed precision.
|
||||
bli_print_msg(" column major not supported yet in bf16s4f32o<f32/bf16>.", __FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_thread_decorator(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float*)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, BF16);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float*)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g, pre_op_list,
|
||||
post_op_list, BF16
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
|
||||
@@ -109,7 +109,8 @@
|
||||
m, n, k, \
|
||||
a, lda, mtag_a, \
|
||||
b, ldb, mtag_b, \
|
||||
c, ldc \
|
||||
c, ldc, \
|
||||
err_no \
|
||||
) \
|
||||
{ \
|
||||
int32_t info = 0; \
|
||||
@@ -170,7 +171,7 @@
|
||||
\
|
||||
sprintf( print_msg, "** On entry to %6s, parameter number %2i of problem %ld had an illegal value", op_str, info, gemm_no); \
|
||||
bli_print_msg(print_msg, __FILE__, __LINE__); \
|
||||
return; \
|
||||
err_no = info; \
|
||||
} \
|
||||
}
|
||||
|
||||
|
||||
@@ -177,5 +177,14 @@ BLIS_EXPORT_ADDON void aocl_batch_gemm_ ## LP_SFX \
|
||||
|
||||
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
|
||||
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
|
||||
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
|
||||
AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32);
|
||||
AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16);
|
||||
|
||||
|
||||
|
||||
#endif // AOCL_GEMM_INTERFACE_H
|
||||
|
||||
@@ -45,6 +45,11 @@ static bli_pthread_once_t once_check_lpgemm_logger_init = BLIS_PTHREAD_ONCE_INIT
|
||||
|
||||
static bool lpgemm_logger_enabled = FALSE;
|
||||
|
||||
bool is_logger_enabled()
|
||||
{
|
||||
return lpgemm_logger_enabled;
|
||||
}
|
||||
|
||||
FILE* lpgemm_start_logger_fn(void)
|
||||
{
|
||||
lpgemm_init_logger();
|
||||
@@ -83,7 +88,7 @@ void lpgemm_stop_logger_fn( FILE* fd )
|
||||
ops_str_len += c_ops_str_len; \
|
||||
} while ( 0 ); \
|
||||
|
||||
static void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str )
|
||||
void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str )
|
||||
{
|
||||
if ( post_ops == NULL )
|
||||
{
|
||||
@@ -148,7 +153,7 @@ static void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str )
|
||||
}
|
||||
}
|
||||
|
||||
static void lpgemm_get_post_ops_str( aocl_post_op* post_ops, char* ops_str )
|
||||
void lpgemm_get_post_ops_str( aocl_post_op* post_ops, char* ops_str )
|
||||
{
|
||||
if ( ( post_ops == NULL ) || ( post_ops->seq_length <= 0 ) )
|
||||
{
|
||||
@@ -308,11 +313,54 @@ void lpgemm_write_logger_gemm_fn
|
||||
}
|
||||
}
|
||||
|
||||
void batch_lpgemm_write_logger_gemm_fn
|
||||
(
|
||||
FILE* fd,
|
||||
char* op_type,
|
||||
const char* order,
|
||||
const char* transa,
|
||||
const char* transb,
|
||||
const dim_t batch_size,
|
||||
const dim_t* m,
|
||||
const dim_t* n,
|
||||
const dim_t* k,
|
||||
const float* alpha,
|
||||
const dim_t* lda,
|
||||
const char* mem_format_a,
|
||||
const dim_t* ldb,
|
||||
const char* mem_format_b,
|
||||
const float* beta,
|
||||
const dim_t* ldc,
|
||||
aocl_post_op** post_op_unparsed
|
||||
)
|
||||
{
|
||||
if ( ( lpgemm_logger_enabled == TRUE ) && ( fd != NULL ) )
|
||||
{
|
||||
char pre_ops_str[1024] = {0};
|
||||
|
||||
char post_ops_str[2048] = {0};
|
||||
|
||||
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size);
|
||||
for( dim_t i = 0; i < batch_size; i++ )
|
||||
{
|
||||
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str );
|
||||
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str );
|
||||
fprintf( fd, "%c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
|
||||
":pre_ops=[%s]:post_ops=[%s] %f %f\n",
|
||||
order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i],
|
||||
m[i], n[i], k[i], lda[i], ldb[i], ldc[i],
|
||||
pre_ops_str, post_ops_str,
|
||||
(float)(alpha[i]), (float)(beta[i]) );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void lpgemm_write_logger_time_break_fn( FILE* fd, double stime )
|
||||
{
|
||||
if ( ( lpgemm_logger_enabled == TRUE ) && ( fd != NULL ) )
|
||||
{
|
||||
fprintf( fd, "%f \n", stime );
|
||||
fprintf( fd, "time:%f \n", stime );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,6 +42,9 @@
|
||||
|
||||
FILE* lpgemm_start_logger_fn(void);
|
||||
void lpgemm_stop_logger_fn( FILE* fd );
|
||||
void lpgemm_get_post_ops_str( aocl_post_op* post_ops, char* ops_str );
|
||||
void lpgemm_get_pre_ops_str( aocl_post_op* post_ops, char* ops_str );
|
||||
bool is_logger_enabled();
|
||||
void lpgemm_write_logger_gemm_fn
|
||||
(
|
||||
FILE* fd,
|
||||
@@ -61,6 +64,26 @@ void lpgemm_write_logger_gemm_fn
|
||||
const dim_t ldc,
|
||||
aocl_post_op* post_op_unparsed
|
||||
);
|
||||
void batch_lpgemm_write_logger_gemm_fn
|
||||
(
|
||||
FILE* fd,
|
||||
char* op_type,
|
||||
const char* order,
|
||||
const char* transa,
|
||||
const char* transb,
|
||||
const dim_t batch_size,
|
||||
const dim_t* m,
|
||||
const dim_t* n,
|
||||
const dim_t* k,
|
||||
const float* alpha,
|
||||
const dim_t* lda,
|
||||
const char* mem_format_a,
|
||||
const dim_t* ldb,
|
||||
const char* mem_format_b,
|
||||
const float* beta,
|
||||
const dim_t* ldc,
|
||||
aocl_post_op** post_op_unparsed
|
||||
);
|
||||
void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
|
||||
|
||||
#define LPGEMM_START_LOGGER() \
|
||||
@@ -81,6 +104,33 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
|
||||
#define LPGEMM_WRITE_LOGGER(...) \
|
||||
lpgemm_write_logger_gemm_fn( fd, __VA_ARGS__ ); \
|
||||
|
||||
#define BATCH_LPGEMM_WRITE_LOGGER( op_type, order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, lda, mem_format_a, \
|
||||
ldb, mem_format_b, beta, \
|
||||
ldc, post_op_unparsed ) \
|
||||
{ \
|
||||
if ( ( is_logger_enabled() ) && ( fd != NULL ) ) \
|
||||
{ \
|
||||
char pre_ops_str[1024] = {0}; \
|
||||
\
|
||||
char post_ops_str[2048] = {0}; \
|
||||
\
|
||||
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size); \
|
||||
for( dim_t i = 0; i < batch_size; i++ ) \
|
||||
{ \
|
||||
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str ); \
|
||||
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str ); \
|
||||
fprintf( fd, "%c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
|
||||
":pre_ops=[%s]:post_ops=[%s] %f %f\n", \
|
||||
order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i], \
|
||||
m[i], n[i], k[i], lda[i], ldb[i], ldc[i], \
|
||||
pre_ops_str, post_ops_str, \
|
||||
(float)(alpha[i]), (float)(beta[i]) ); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define LPGEMM_START_LOGGER(...)
|
||||
@@ -89,6 +139,12 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
|
||||
|
||||
#define LPGEMM_WRITE_LOGGER(...)
|
||||
|
||||
#define BATCH_LPGEMM_WRITE_LOGGER(op_type, order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, lda, mem_format_a, \
|
||||
ldb, mem_format_b, beta, \
|
||||
ldc, post_op_unparsed)
|
||||
|
||||
#endif
|
||||
|
||||
void lpgemm_init_logger();
|
||||
|
||||
@@ -72,35 +72,6 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32);
|
||||
LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16);
|
||||
|
||||
#define BATCH_LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \
|
||||
void batch_lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
const dim_t m, \
|
||||
const dim_t n, \
|
||||
const dim_t k, \
|
||||
const A_type** a, \
|
||||
const dim_t rs_a, \
|
||||
const dim_t cs_a, \
|
||||
const AOCL_MEMORY_TAG mtag_a, \
|
||||
const B_type** b, \
|
||||
dim_t rs_b, \
|
||||
dim_t cs_b, \
|
||||
AOCL_MEMORY_TAG mtag_b, \
|
||||
C_type** c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm, \
|
||||
lpgemm_thrinfo_t* thread, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_post_op* post_op_list, \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
|
||||
BATCH_LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32);
|
||||
|
||||
|
||||
#define LPGEMM_5LOOP1(A_type,B_type,C_type,LP_SFX) \
|
||||
void lpgemm_rowvar_ ## LP_SFX \
|
||||
( \
|
||||
|
||||
@@ -43,6 +43,23 @@
|
||||
|
||||
#define BLIS_LPGEMM_NUM_STATIC_COMMS 96
|
||||
|
||||
BLIS_INLINE void calculate_n_threads_per_gemm
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
*n_threads = bli_rntm_num_threads( rntm_g ); \
|
||||
*n_gemms_in_parallel = -1; \
|
||||
if( *n_threads == 1 ) *n_gemms_in_parallel = 1; \
|
||||
else if( *n_gemms_in_parallel < 1 ) *n_gemms_in_parallel = bli_min(*n_threads, batch_size); \
|
||||
/* ToDo: All the leftover thrads might go under-utilized. Could be optimized further. */ \
|
||||
*n_threads_per_gemm = ( *n_threads ) / *n_gemms_in_parallel;
|
||||
}
|
||||
|
||||
BLIS_INLINE dim_t next_factor
|
||||
(
|
||||
const dim_t nt,
|
||||
@@ -437,6 +454,119 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
|
||||
}
|
||||
}
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_s32o32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g,
|
||||
AOCL_OPERATION_TYPE op_type
|
||||
)
|
||||
{
|
||||
|
||||
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
|
||||
if ( ( *n_threads_per_gemm ) > 1 )
|
||||
{
|
||||
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( op_type );
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type );
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
( *ic_ways ) = ( *n_threads_per_gemm );
|
||||
( *jc_ways ) = 1;
|
||||
( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( m <= MR )
|
||||
{
|
||||
( *jc_ways ) = ( *n_threads_per_gemm );
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads_per_gemm ), m, n, ic_ways, jc_ways );
|
||||
if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) )
|
||||
{
|
||||
lpgemm_pnl_wrk_heur_adjust_ic_jc_ways
|
||||
(
|
||||
MR, NR, m, n,
|
||||
n_threads_per_gemm, ic_ways, jc_ways
|
||||
);
|
||||
}
|
||||
}
|
||||
( *n_threads ) = ( *n_gemms_in_parallel ) * ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Setting all the values to 1 in case n_threads <= 1. This ensures
|
||||
// the threading parameters are valid.
|
||||
*n_threads = 1;
|
||||
*n_gemms_in_parallel = 1;
|
||||
*n_threads_per_gemm = 1;
|
||||
*jc_ways = 1;
|
||||
*ic_ways = 1;
|
||||
}
|
||||
}
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
batch_lpgemm_s32o32_get_threading
|
||||
(
|
||||
batch_size,
|
||||
n_threads, n_gemms_in_parallel, n_threads_per_gemm,
|
||||
ic_ways, jc_ways,
|
||||
m, n, k, rntm_g,
|
||||
U8S8S32OS32
|
||||
);
|
||||
}
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_s8s8s32o32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
batch_lpgemm_s32o32_get_threading
|
||||
(
|
||||
batch_size,
|
||||
n_threads, n_gemms_in_parallel, n_threads_per_gemm,
|
||||
ic_ways, jc_ways,
|
||||
m, n, k, rntm_g,
|
||||
S8S8S32OS32
|
||||
);
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_u8s8s32o32_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
@@ -544,23 +674,6 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
}
|
||||
}
|
||||
|
||||
BLIS_INLINE void calculate_n_threads_per_gemm
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
*n_threads = bli_rntm_num_threads( rntm_g ); \
|
||||
*n_gemms_in_parallel = -1; \
|
||||
if( *n_threads == 1 ) *n_gemms_in_parallel = 1; \
|
||||
else if( *n_gemms_in_parallel < 1 ) *n_gemms_in_parallel = bli_min(*n_threads, batch_size); \
|
||||
/* ToDo: All the leftover thrads might go under-utilized. Could be optimized further. */ \
|
||||
*n_threads_per_gemm = ( *n_threads ) / *n_gemms_in_parallel;
|
||||
}
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_bf16bf16f32of32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
@@ -723,6 +836,99 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
}
|
||||
}
|
||||
}
|
||||
BLIS_INLINE void batch_lpgemm_f32f32f32of32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
dim_t* ic_ways,
|
||||
dim_t* jc_ways,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
|
||||
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
|
||||
// Query the context for SUP limits.
|
||||
const dim_t MT = lpgemm_get_sup_thres_MT_global_cntx( F32F32F32OF32 );
|
||||
const dim_t NT = lpgemm_get_sup_thres_NT_global_cntx( F32F32F32OF32 );
|
||||
const dim_t KT = lpgemm_get_sup_thres_KT_global_cntx( F32F32F32OF32 );
|
||||
|
||||
// Query the context for various blocksizes.
|
||||
dim_t NR = lpgemm_get_block_size_NR_global_cntx( F32F32F32OF32 );
|
||||
dim_t MR = lpgemm_get_block_size_MR_global_cntx( F32F32F32OF32 );
|
||||
dim_t MC = lpgemm_get_block_size_MC_global_cntx( F32F32F32OF32 );
|
||||
dim_t NC = lpgemm_get_block_size_NC_global_cntx( F32F32F32OF32 );
|
||||
dim_t KC = lpgemm_get_block_size_KC_global_cntx( F32F32F32OF32 );
|
||||
|
||||
const dim_t MT_2 = MT / 2;
|
||||
|
||||
/* The user is not allowed to set ic_ways or jc_ways */
|
||||
if ( ( *n_threads_per_gemm ) > 1 )
|
||||
{
|
||||
dim_t mr_blks = ( m + MR - 1 ) / MR;
|
||||
dim_t nr_blks = ( n + NR - 1 ) / NR;
|
||||
|
||||
if ( n <= NR )
|
||||
{
|
||||
( *ic_ways ) = ( *n_threads_per_gemm );
|
||||
( *jc_ways ) = 1;
|
||||
( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else if ( m <= MR )
|
||||
{
|
||||
( *jc_ways ) = ( *n_threads_per_gemm );
|
||||
( *ic_ways ) = 1;
|
||||
( *n_threads_per_gemm ) = ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// If BLIS_NUM_THREADS are set, generate jc,ic from the same.
|
||||
bli_thread_partition_2x2( ( *n_threads_per_gemm ), m, n, ic_ways, jc_ways );
|
||||
if ( ( mr_blks >= ( *ic_ways ) ) && ( nr_blks >= ( *jc_ways ) ) )
|
||||
{
|
||||
lpgemm_adjust_ic_jc_ways
|
||||
(
|
||||
m, n, k,
|
||||
MC, NC, KC, MR, NR,
|
||||
n_threads_per_gemm, ic_ways, jc_ways, 5
|
||||
);
|
||||
}
|
||||
}
|
||||
( *n_threads ) = ( *n_gemms_in_parallel ) * ( *ic_ways ) * ( *jc_ways );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Setting all the values to 1 in case n_threads <= 1. This ensures
|
||||
// the threading parameters are valid.
|
||||
*n_threads = 1;
|
||||
*n_gemms_in_parallel = 1;
|
||||
*n_threads_per_gemm = 1;
|
||||
*jc_ways = 1;
|
||||
*ic_ways = 1;
|
||||
}
|
||||
|
||||
// Native -> SUP path.
|
||||
const dim_t m_ic = m / ( *ic_ways );
|
||||
const dim_t n_jc = n / ( *jc_ways );
|
||||
const dim_t page_size = bli_info_get_page_size();
|
||||
const dim_t page_size_b_floatx2 =
|
||||
2 * ( page_size / sizeof( float ) );
|
||||
|
||||
if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) )
|
||||
{
|
||||
if (((k >= page_size_b_floatx2) && (m_ic > MT_2) && (n_jc >= NT)) ||
|
||||
((bli_cpuid_is_avx512_supported() == FALSE) && (k > page_size_b_floatx2)))
|
||||
{
|
||||
bli_rntm_set_pack_b( 1, rntm_g );
|
||||
bli_rntm_set_pack_a( 1, rntm_g );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define GEN_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
@@ -944,11 +1150,14 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
} \
|
||||
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
|
||||
#define GEN_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t m, \
|
||||
const dim_t m, \
|
||||
const dim_t n, \
|
||||
const dim_t k, \
|
||||
const A_type* a, \
|
||||
@@ -1048,6 +1257,141 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
|
||||
GEN_LPGEMM_OPENMP_DECORATOR_MP(bfloat16, int8_t, float, bf16s4f32of32)
|
||||
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX, LPGEMM_PARENT_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
const A_type** a, \
|
||||
const dim_t* rs_a, \
|
||||
const dim_t* cs_a, \
|
||||
const AOCL_MEMORY_TAG* mtag_a, \
|
||||
const B_type** b, \
|
||||
const dim_t* rs_b, \
|
||||
const dim_t* cs_b, \
|
||||
AOCL_MEMORY_TAG* mtag_b, \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
{ \
|
||||
/* For now, Assuming all the problems in GEMM are of same size.
|
||||
* To-Do: optimize work distribution for case where a batch contains
|
||||
GEMM problems of different sizes.
|
||||
*/ \
|
||||
dim_t n_threads; \
|
||||
/* Factorization of threads along m and n dimension respectively.*/ \
|
||||
dim_t ic_ways; \
|
||||
dim_t jc_ways; \
|
||||
dim_t n_gemms_in_parallel; \
|
||||
dim_t n_threads_per_gemm; \
|
||||
\
|
||||
/* Assuming all the problems in GEMM are of same size */ \
|
||||
batch_lpgemm_ ## LPGEMM_PARENT_SFX ## _get_threading \
|
||||
( \
|
||||
batch_size, \
|
||||
&n_threads, \
|
||||
&n_gemms_in_parallel, \
|
||||
&n_threads_per_gemm, \
|
||||
&ic_ways, &jc_ways, \
|
||||
m[0], n[0], k[0], rntm_g \
|
||||
); \
|
||||
\
|
||||
/* Set the packing block allocator field of the rntm. This will be
|
||||
* inherited by all of the child threads when they make local copies of
|
||||
* the rntm below.*/ \
|
||||
bli_pba_rntm_set_pba( rntm_g ); \
|
||||
\
|
||||
thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \
|
||||
thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \
|
||||
err_t bli_errors = BLIS_SUCCESS; \
|
||||
\
|
||||
if ( jc_ways * n_gemms_in_parallel > BLIS_LPGEMM_NUM_STATIC_COMMS ) \
|
||||
{ \
|
||||
cur_lpgemm_comms = bli_malloc_intl( jc_ways * n_gemms_in_parallel * \
|
||||
sizeof( thrcomm_t ), &bli_errors ); \
|
||||
} \
|
||||
for( dim_t i = 0; i < n_gemms_in_parallel * jc_ways; i++ ) \
|
||||
{ \
|
||||
bli_thrcomm_init( ic_ways, &cur_lpgemm_comms[i] ); \
|
||||
} \
|
||||
\
|
||||
dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); \
|
||||
\
|
||||
_Pragma( "omp parallel num_threads(n_threads)" ) \
|
||||
{ \
|
||||
/* Create a thread-local copy of the master thread's rntm_t. This is
|
||||
* necessary since we want each thread to be able to track its own
|
||||
* small block pool_t as it executes down the function stack.*/ \
|
||||
rntm_t rntm_l = *rntm_g; \
|
||||
\
|
||||
/* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects
|
||||
* for use in blis mt framework inside the respective mat mul driver
|
||||
* functions.*/ \
|
||||
lpgemm_thrinfo_t thread; \
|
||||
thread.n_threads = n_threads_per_gemm; \
|
||||
thread.tid = omp_get_thread_num() % n_threads_per_gemm; \
|
||||
thread.ic_ways = ic_ways; \
|
||||
thread.jc_ways = jc_ways; \
|
||||
thread.comm = cur_lpgemm_comms + (omp_get_thread_num() / n_threads_per_gemm) * jc_ways; \
|
||||
\
|
||||
dim_t gemm_start; \
|
||||
dim_t gemm_end; \
|
||||
\
|
||||
/* This structure is filled only to calculate workload distribution of GEMM problems
|
||||
among threads. This struct is not passed to 5-loop.
|
||||
*/ \
|
||||
thrinfo_t thrinfo; \
|
||||
thrinfo.n_way = n_gemms_in_parallel; \
|
||||
thrinfo.work_id = omp_get_thread_num() / n_threads_per_gemm; \
|
||||
bli_thread_range_sub( &thrinfo, batch_size, 1, FALSE, &gemm_start, &gemm_end ); \
|
||||
\
|
||||
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
|
||||
{ \
|
||||
/* Decide whether to go with pack-based implementation
|
||||
or kernel-level implementation */ \
|
||||
if( ( m[i] / ic_ways ) > MC ) \
|
||||
{ \
|
||||
mtag_b[i] = PACK_KC; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
mtag_b[i] = UNPACKED; \
|
||||
} \
|
||||
\
|
||||
lpgemm_rowvar_ ## LPGEMM_SFX \
|
||||
( \
|
||||
m[i], n[i], k[i], \
|
||||
a[i], rs_a[i], cs_a[i], mtag_a[i], \
|
||||
b[i], rs_b[i], cs_b[i], mtag_b[i], \
|
||||
c[i], rs_c[i], cs_c[i],\
|
||||
alpha[i], \
|
||||
beta[i], \
|
||||
&rntm_l, \
|
||||
&thread, \
|
||||
lcntx, \
|
||||
pre_op_list[i], \
|
||||
post_op_list[i], c_downscale \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
if ( jc_ways * n_gemms_in_parallel > BLIS_LPGEMM_NUM_STATIC_COMMS ) \
|
||||
{ \
|
||||
bli_free_intl( cur_lpgemm_comms ); \
|
||||
} \
|
||||
} \
|
||||
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(bfloat16, int8_t, float, bf16s4f32of32, bf16bf16f32of32)
|
||||
|
||||
BLIS_INLINE void lpgemm_eltwise_ops_bf16of32_get_threading
|
||||
(
|
||||
dim_t* n_threads,
|
||||
@@ -1468,7 +1812,84 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
} \
|
||||
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
const A_type** a, \
|
||||
const dim_t* rs_a, \
|
||||
const dim_t* cs_a, \
|
||||
const AOCL_MEMORY_TAG* mtag_a, \
|
||||
const B_type** b, \
|
||||
const dim_t* rs_b, \
|
||||
const dim_t* cs_b, \
|
||||
const AOCL_MEMORY_TAG* mtag_b, \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
{ \
|
||||
dim_t n_threads = 1; \
|
||||
\
|
||||
/* Factorization of threads along m and n dimension respectively.*/ \
|
||||
dim_t ic_ways = 1; \
|
||||
dim_t jc_ways = 1; \
|
||||
\
|
||||
/* Set the packing block allocator field of the rntm. This will be
|
||||
* inherited by all of the child threads when they make local copies of
|
||||
* the rntm below.*/ \
|
||||
bli_pba_rntm_set_pba( rntm_g ); \
|
||||
\
|
||||
thrcomm_t static_lpgemm_comm; \
|
||||
thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \
|
||||
\
|
||||
bli_thrcomm_init( ic_ways, cur_lpgemm_comm ); \
|
||||
/* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects
|
||||
* for use in blis mt framework inside the respective mat mul driver
|
||||
* functions.*/ \
|
||||
lpgemm_thrinfo_t thread; \
|
||||
thread.n_threads = n_threads; \
|
||||
thread.tid = 0; \
|
||||
thread.ic_ways = ic_ways; \
|
||||
thread.jc_ways = jc_ways; \
|
||||
thread.comm = cur_lpgemm_comm; \
|
||||
dim_t gemm_start = 0; \
|
||||
dim_t gemm_end = batch_size; \
|
||||
\
|
||||
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
|
||||
{ \
|
||||
lpgemm_rowvar_ ## LPGEMM_SFX \
|
||||
( \
|
||||
m[i], n[i], k[i], \
|
||||
a[i], rs_a[i], cs_a[i], mtag_a[i], \
|
||||
b[i], rs_b[i], cs_b[i], mtag_b[i], \
|
||||
c[i], rs_c[i], cs_c[i],\
|
||||
alpha[i], \
|
||||
beta[i], \
|
||||
rntm_g, \
|
||||
&thread, \
|
||||
lcntx, \
|
||||
pre_op_list[i], \
|
||||
post_op_list[i], c_downscale \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(bfloat16,int8_t,float,bf16s4f32of32)
|
||||
|
||||
#define GEN_UTIL_ELTWISE_OPS_DECORATOR(A_type,B_type,LPGEMM_SFX) \
|
||||
void lpgemm_eltwise_ops_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
|
||||
@@ -101,6 +101,39 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
); \
|
||||
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
|
||||
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN_MXP(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
const A_type** a, \
|
||||
const dim_t* rs_a, \
|
||||
const dim_t* cs_a, \
|
||||
const AOCL_MEMORY_TAG* mtag_a, \
|
||||
const B_type** b, \
|
||||
const dim_t* rs_b, \
|
||||
const dim_t* cs_b, \
|
||||
AOCL_MEMORY_TAG* mtag_b, \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN_MXP(bfloat16,int8_t,float,bf16s4f32of32)
|
||||
|
||||
|
||||
#define GEN_LPGEMM_OPENMP_DECORATOR_FN1(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
@@ -116,7 +149,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
const B_type* b, \
|
||||
const dim_t rs_b, \
|
||||
const dim_t cs_b, \
|
||||
const AOCL_MEMORY_TAG mtag_b, \
|
||||
AOCL_MEMORY_TAG mtag_b, \
|
||||
C_type* c, \
|
||||
const dim_t rs_c, \
|
||||
const dim_t cs_c, \
|
||||
@@ -213,6 +246,38 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
); \
|
||||
|
||||
GEN_BATCH_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
GEN_BATCH_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32)
|
||||
GEN_BATCH_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32)
|
||||
GEN_BATCH_LPGEMM_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
|
||||
#define GEN_BATCH_LPGEMM_DECORATOR_FN_MP(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
( \
|
||||
const dim_t bs, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
const A_type** a, \
|
||||
const dim_t* rs_a, \
|
||||
const dim_t* cs_a, \
|
||||
const AOCL_MEMORY_TAG* mtag_a, \
|
||||
const B_type** b, \
|
||||
const dim_t* rs_b, \
|
||||
const dim_t* cs_b, \
|
||||
const AOCL_MEMORY_TAG* mtag_b, \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
GEN_BATCH_LPGEMM_DECORATOR_FN_MP(bfloat16,int8_t,float,bf16s4f32of32)
|
||||
|
||||
|
||||
#define GEN_LPGEMM_DECORATOR_FN1(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
|
||||
@@ -1,6 +1,17 @@
|
||||
bf16bf16f32obf16:bs=5
|
||||
r t n n p 83 3847 1930 83 3847 3847 scale=vector,zp=scalar,relu,clip
|
||||
c n t n n 50 1297 707 50 1297 50 scale=vector,zp=vector,bias=na,clip
|
||||
c n n n n 12 128 1605 12 1605 12 scale=vector,zp=vector
|
||||
r n t n r 63 95 1319 1319 1319 95 scale=vector,zp=scalar,relu,clip
|
||||
r n n n r 65 2283 911 911 2283 2283 bias=na,swish
|
||||
*:bs=5
|
||||
r t t n n 92 1479 589 92 589 1479 scale=vector,zp=vector,bias=na,clip
|
||||
r n n n r 67 21 1823 1823 21 21 scale=vector,zp=scalar,relu,clip
|
||||
r n t n n 43 2240 1553 1553 1553 2240 scale=vector,zp=scalar,relu,clip
|
||||
r t n n p 143 1943 730 143 1943 1943 bias=na,swish
|
||||
r n n n r 79 2676 1995 1995 2676 2676 bias=na,swish
|
||||
bf16s4f32of32:bs=4
|
||||
r t n n r 43 1110 271 43 1110 1110 scale=vector,zp=vector,bias=na,clip
|
||||
r t n n r 79 1177 1968 79 1177 1177 scale=vector,zp=scalar,relu,clip
|
||||
r n t n r 92 2872 1482 1482 1482 2872 scale=vector,zp=vector,bias=na,clip
|
||||
r n t n r 88 3397 1130 1130 1130 3397 scale=vector,zp=vector
|
||||
bf16s4f32obf16:bs=5
|
||||
r n n n r 17 2714 468 468 2714 2714 scale=vector,zp=vector,bias=na,clip
|
||||
r n n n r 140 3764 1519 1519 3764 3764 scale=vector,zp=vector
|
||||
r n t n r 17 1758 1034 1034 1034 1758 scale=vector,zp=vector,bias=na,clip
|
||||
r n n n r 130 1822 1293 1293 1822 1822 scale=vector,zp=vector
|
||||
r t t n r 21 2771 1882 21 1882 2771 bias=na,swish
|
||||
|
||||
@@ -111,6 +111,13 @@ void mat_mul_ ## BLAS_SFX \
|
||||
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
double get_gflops
|
||||
(
|
||||
@@ -211,6 +218,13 @@ void mat_mul_bench_driver_ ## BLAS_SFX \
|
||||
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
|
||||
@@ -1230,6 +1244,13 @@ void mat_mul_bench_main_ ## BLAS_SFX \
|
||||
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16,bf16bf16f32of32,bf16s4f32of32)
|
||||
|
||||
int main( int argc, char** argv )
|
||||
{
|
||||
@@ -1428,6 +1449,148 @@ int main( int argc, char** argv )
|
||||
post_ops_str_dest, FALSE
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "f32f32f32of32" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
global_can_dscale = 'y';
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, FALSE
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "u8s8s32os32" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
// Copy the original post op str to a temp string buffer.
|
||||
// Done so that strtok can be applied on the same (strtok
|
||||
// is a destructive parser.
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, FALSE
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "u8s8s32os8" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
// Copy the original post op str to a temp string buffer.
|
||||
// Done so that strtok can be applied on the same (strtok
|
||||
// is a destructive parser.
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = -128;
|
||||
DSCALE_CLIP_MAX = +127;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, FALSE
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "s8s8s32os32" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
// Copy the original post op str to a temp string buffer.
|
||||
// Done so that strtok can be applied on the same (strtok
|
||||
// is a destructive parser.
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = INT_MIN;
|
||||
DSCALE_CLIP_MAX = INT_MAX;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, FALSE
|
||||
);
|
||||
}
|
||||
if ( ( strcmp( gemm_type_str, "s8s8s32os8" ) == 0 ) ||
|
||||
( strcmp( gemm_type_str, "*" ) == 0 ) )
|
||||
{
|
||||
// Copy the original post op str to a temp string buffer.
|
||||
// Done so that strtok can be applied on the same (strtok
|
||||
// is a destructive parser.
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'n';
|
||||
DSCALE_CLIP_MIN = -128;
|
||||
DSCALE_CLIP_MAX = +127;
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, FALSE
|
||||
);
|
||||
}
|
||||
if ( strcmp( gemm_type_str, "bf16s4f32of32" ) == 0 )
|
||||
{
|
||||
// Copy the original post op str to a temp string buffer.
|
||||
// Done so that strtok can be applied on the same (strtok
|
||||
// is a destructive parser.
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
{
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
if ( ( op_b[i] != 'r' ) && ( op_b[i] != 'R' ) )
|
||||
{
|
||||
printf("Int4 B matrix only permitted if B reodering "
|
||||
"is enabled.\n");
|
||||
goto skip_exec;
|
||||
}
|
||||
}
|
||||
global_dscale_out = 'n';
|
||||
global_pre_op = 'y';
|
||||
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32of32)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, TRUE
|
||||
);
|
||||
}
|
||||
if ( strcmp( gemm_type_str, "bf16s4f32obf16" ) == 0 )
|
||||
{
|
||||
// Copy the original post op str to a temp string buffer.
|
||||
// Done so that strtok can be applied on the same (strtok
|
||||
// is a destructive parser.
|
||||
for( dim_t i = 0; i < bs; i++ )
|
||||
{
|
||||
strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
|
||||
if ( ( op_b[i] != 'r' ) && ( op_b[i] != 'R' ) )
|
||||
{
|
||||
printf("Int4 B matrix only permitted if B reodering "
|
||||
"is enabled.\n");
|
||||
goto skip_exec;
|
||||
}
|
||||
}
|
||||
global_dscale_out = 'y';
|
||||
global_pre_op = 'y';
|
||||
|
||||
GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32obf16)
|
||||
(
|
||||
fin, fout, stor_order, transa, transb, op_a, op_b,
|
||||
bs, m, n, k, stride_a, stride_b, stride_c,
|
||||
post_ops_str_dest, TRUE
|
||||
);
|
||||
}
|
||||
skip_exec:
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6671,7 +6671,7 @@ POST_OPS_DOWNSCALE_5x2F:
|
||||
post_ops_attr.post_op_c_i + 2 ) );
|
||||
zero_point3 = _mm_set1_ps( *( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 3 ) );
|
||||
zero_point0 = _mm_set1_ps( *( (float* )post_ops_list_temp->op_args1 +
|
||||
zero_point4 = _mm_set1_ps( *( (float* )post_ops_list_temp->op_args1 +
|
||||
post_ops_attr.post_op_c_i + 4 ) );
|
||||
}
|
||||
//c[0, 0-3]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -141,7 +141,7 @@ void packb_nr64_s8s8s32os32
|
||||
else
|
||||
{
|
||||
packb_nr64_s8s8s32os32_col_major(pack_b_buffer_s8s8s32o32,
|
||||
pack_b_column_sum, b,
|
||||
pack_b_column_sum, b,
|
||||
cs_b, NC, KC, rs_p, cs_p);
|
||||
}
|
||||
}
|
||||
@@ -198,7 +198,7 @@ void packb_nr64_s8s8s32os32_row_major
|
||||
{
|
||||
//load the temp buffer to compute column sum of B matrix
|
||||
sum1 = _mm512_loadu_si512( pack_b_column_sum + jc );
|
||||
sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 + jc );
|
||||
sum2 = _mm512_loadu_si512( pack_b_column_sum + 16 + jc );
|
||||
//offset 16- as 16 int32 elements fit in 1 zmm register
|
||||
sum3 = _mm512_loadu_si512( pack_b_column_sum + 32 + jc );
|
||||
sum4 = _mm512_loadu_si512( pack_b_column_sum + 48 + jc );
|
||||
@@ -212,28 +212,28 @@ void packb_nr64_s8s8s32os32_row_major
|
||||
d0 = _mm512_loadu_si512( b + ( ldb * ( kr + 3 ) ) + jc );
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0, d0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)),
|
||||
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 0))))) , mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)),
|
||||
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 1))))) , mul_128));
|
||||
|
||||
sum3 =
|
||||
sum3 =
|
||||
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)),
|
||||
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( d0, 2))))) , mul_128));
|
||||
|
||||
sum4 =
|
||||
sum4 =
|
||||
_mm512_add_epi32 ( sum4, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)),
|
||||
@@ -262,13 +262,13 @@ void packb_nr64_s8s8s32os32_row_major
|
||||
a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1]
|
||||
c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3]
|
||||
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 );
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 );
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 );
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 );
|
||||
}
|
||||
// Handle k remainder.
|
||||
@@ -282,25 +282,25 @@ void packb_nr64_s8s8s32os32_row_major
|
||||
d0 = _mm512_setzero_si512();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 0)))), mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 1)))), mul_128));
|
||||
|
||||
sum3 =
|
||||
sum3 =
|
||||
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( c0, 2)))), mul_128));
|
||||
|
||||
sum4 =
|
||||
sum4 =
|
||||
_mm512_add_epi32 ( sum4, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3)),
|
||||
@@ -315,22 +315,22 @@ void packb_nr64_s8s8s32os32_row_major
|
||||
d0 = _mm512_setzero_si512();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 0))), mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 1)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 1))), mul_128));
|
||||
|
||||
sum3 =
|
||||
sum3 =
|
||||
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 2)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 2))), mul_128));
|
||||
|
||||
sum4 =
|
||||
sum4 =
|
||||
_mm512_add_epi32 ( sum4, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( a0, 3)),
|
||||
_mm512_cvtepi8_epi32(_mm512_extracti32x4_epi32 ( b0, 3))), mul_128));
|
||||
@@ -378,13 +378,13 @@ void packb_nr64_s8s8s32os32_row_major
|
||||
a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1]
|
||||
c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3]
|
||||
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 );
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 );
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 );
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 +
|
||||
( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 );
|
||||
}
|
||||
//store the sum column
|
||||
@@ -506,14 +506,14 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) );
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0, d0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)),
|
||||
@@ -553,11 +553,11 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) + ( 32 ) );
|
||||
|
||||
//add all the columns : sum = add (sum, a0_32, b0_32, c0_32, d0_32)
|
||||
sum3 =
|
||||
_mm512_add_epi32
|
||||
( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
|
||||
sum3 =
|
||||
_mm512_add_epi32
|
||||
( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
|
||||
_mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )
|
||||
);
|
||||
|
||||
@@ -595,7 +595,7 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_setzero_si256();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
|
||||
@@ -612,10 +612,10 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) );
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum3 =
|
||||
_mm512_add_epi32
|
||||
sum3 =
|
||||
_mm512_add_epi32
|
||||
( sum3, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_cvtepi8_epi32( c0_16 ))) , mul_128)
|
||||
);
|
||||
}
|
||||
@@ -627,12 +627,12 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_setzero_si256();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
|
||||
_mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 0) )) , mul_128 ));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
|
||||
_mm512_cvtepi8_epi32( _mm256_extracti32x4_epi32 ( b0_32, 1) )) , mul_128 ));
|
||||
@@ -643,7 +643,7 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum3 =
|
||||
_mm512_add_epi32
|
||||
_mm512_add_epi32
|
||||
( sum3, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_cvtepi8_epi32( b0_16 )) , mul_128)
|
||||
);
|
||||
@@ -656,11 +656,11 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_setzero_si256();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)) , mul_128));
|
||||
|
||||
@@ -669,7 +669,7 @@ void packb_nr48_s8s8s32os32_row_major
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum3 =
|
||||
sum3 =
|
||||
_mm512_add_epi32 ( sum3, _mm512_sllv_epi32 (
|
||||
_mm512_cvtepi8_epi32( a0_16 ) , mul_128));
|
||||
}
|
||||
@@ -767,14 +767,14 @@ void packb_nr32_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_maskz_loadu_epi8( 0xFFFFFFFF, b + ( ldb * ( kr + 3 ) ) );
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0, d0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
|
||||
_mm512_add_epi32 (_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( d0_32, 0))))) , mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)),
|
||||
@@ -822,13 +822,13 @@ void packb_nr32_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_setzero_si256();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( c0_32, 0)))) , mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1)),
|
||||
@@ -843,12 +843,12 @@ void packb_nr32_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_setzero_si256();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)),
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 0))) , mul_128));
|
||||
|
||||
sum2 =
|
||||
sum2 =
|
||||
_mm512_add_epi32 ( sum2, _mm512_sllv_epi32 (
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 1)),
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( b0_32, 1))) , mul_128));
|
||||
@@ -861,7 +861,7 @@ void packb_nr32_s8s8s32os32_row_major
|
||||
d0_32 = _mm256_setzero_si256();
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0)
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32 ( sum1, _mm512_sllv_epi32 (
|
||||
_mm512_cvtepi8_epi32(_mm256_extracti32x4_epi32 ( a0_32, 0)) , mul_128));
|
||||
|
||||
@@ -941,10 +941,10 @@ void packb_nr16_s8s8s32os32_row_major
|
||||
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, b + ( ldb * ( kr + 3 ) ) );
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0, d0)
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
|
||||
_mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )
|
||||
);
|
||||
@@ -983,9 +983,9 @@ void packb_nr16_s8s8s32os32_row_major
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
_mm512_add_epi32
|
||||
( sum1, _mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_cvtepi8_epi32( c0_16 ))) , mul_128)
|
||||
);
|
||||
}
|
||||
@@ -996,8 +996,8 @@ void packb_nr16_s8s8s32os32_row_major
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_cvtepi8_epi32( b0_16 )) , mul_128)
|
||||
);
|
||||
@@ -1009,9 +1009,9 @@ void packb_nr16_s8s8s32os32_row_major
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1,
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1,
|
||||
_mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 )
|
||||
);
|
||||
}
|
||||
@@ -1090,11 +1090,11 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 );
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0, d0)
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1,
|
||||
_mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( c0_16 ),
|
||||
_mm512_cvtepi8_epi32( d0_16 )))) , mul_128 )
|
||||
);
|
||||
@@ -1118,7 +1118,7 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
// Last 4x16 elements.
|
||||
_mm512_storeu_si512( pack_b_buffer_s8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm );
|
||||
|
||||
// The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the
|
||||
// The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the
|
||||
// original data, but is here due to the packing in 4 16byte chunks format.
|
||||
kr_new += 1;
|
||||
}
|
||||
@@ -1138,10 +1138,10 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1,
|
||||
_mm512_add_epi32
|
||||
( sum1,
|
||||
_mm512_sllv_epi32 (_mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_add_epi32 ( _mm512_cvtepi8_epi32( b0_16 ),
|
||||
_mm512_cvtepi8_epi32( c0_16 ))) , mul_128)
|
||||
);
|
||||
|
||||
@@ -1156,8 +1156,8 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1, _mm512_sllv_epi32 ( _mm512_add_epi32 ( _mm512_cvtepi8_epi32( a0_16 ),
|
||||
_mm512_cvtepi8_epi32( b0_16 )) , mul_128)
|
||||
);
|
||||
@@ -1171,7 +1171,7 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
sum1 =
|
||||
_mm512_add_epi32
|
||||
( sum1,
|
||||
_mm512_sllv_epi32 ( _mm512_cvtepi8_epi32( a0_16 ) , mul_128 )
|
||||
@@ -1399,7 +1399,7 @@ void packb_nr64_s8s8s32os32_col_major
|
||||
}
|
||||
|
||||
*rs_p = NR * 4;
|
||||
*cs_p = NR / 4;
|
||||
*cs_p = NR;
|
||||
}
|
||||
|
||||
//Extract 16 8-bit elements from each 128-bit lane of 512-bit register and convert them into
|
||||
@@ -1891,7 +1891,7 @@ void packb_nrlt16_s8s8s32o32_col_major
|
||||
}
|
||||
|
||||
// sum/reduce < 16 (max 15) int32 values into one final sum as int.
|
||||
// insert sum of all columns into one 512 bit, multiply with 128 and
|
||||
// insert sum of all columns into one 512 bit, multiply with 128 and
|
||||
// store into pack_b_column_sum
|
||||
__m512i sum0, sum1;
|
||||
sum0 = _mm512_set_epi32
|
||||
|
||||
Reference in New Issue
Block a user