Added new Int8 batch_gemm APIs

Details:
- Added u8s8s32o{f32,bf16,u8} and s8s8s32o{f32,bf16,u8} batch_gemm APIs.
- Fixed some bugs in bench file for bf16 API.

Change-Id: I55380238869350a848f2deec0641d7b9b416b192
This commit is contained in:
Meghana Vankadari
2025-02-07 05:38:52 +05:30
parent 3a7523b51b
commit da3d0c6034
6 changed files with 2341 additions and 948 deletions

View File

@@ -531,3 +531,735 @@ err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batched GEMM entry point: A = int8_t, B = int8_t, accumulation in int32_t,
// C downscaled to float (suffix s8s8s32of32). The signature comes from the
// AOCL_BGEMM_MATMUL macro, so every per-matrix parameter (order, transa,
// transb, m, n, k, a, lda, mem_format_a, b, ldb, mem_format_b, c, ldc,
// post_op_unparsed) is an array indexed by batch entry in [0, batch_size).
AOCL_BGEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"s8s8s32of32", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-batch strides and pack/reorder tags. These are VLAs sized by
// batch_size; NOTE(review): a very large batch_size could overflow the
// stack -- confirm callers bound it.
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
int8_t *a_local[batch_size];
int8_t *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform s8s8s32of32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_s8s8s32of32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
// Column-major input is mapped onto the row-major kernel by swapping the
// roles of A and B (and of m and n) below, i.e. computing C transposed.
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
// Column major support disabled for int API's till micro-kernel
// post-ops are updated to account for column major.
if (post_op_unparsed[bs_i] != NULL )
{
bli_print_msg("Column major inputs not supported with Post-ops.",
__FILE__, __LINE__);
goto err_hndl;
}
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (int8_t*)(b[bs_i]);
b_local[bs_i] = (int8_t*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (int8_t*)(a[bs_i]);
b_local[bs_i] = (int8_t*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function's point of view, the B matrix must be in a
// kernel-specific packed layout before it can be loaded by the compute
// kernel, so mtag_b must end up as PACK or REORDERED. An UNPACKED B
// cannot be consumed directly; enable runtime packing here.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#else
batch_lpgemm_s8s8s32o32_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#endif
// Single exit point for all error paths; the logger is stopped on both
// success and failure.
err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batched GEMM entry point: A = int8_t, B = int8_t, accumulation in int32_t,
// C downscaled to bfloat16 (suffix s8s8s32obf16). The signature comes from
// the AOCL_BGEMM_MATMUL macro, so every per-matrix parameter (order, transa,
// transb, m, n, k, a, lda, mem_format_a, b, ldb, mem_format_b, c, ldc,
// post_op_unparsed) is an array indexed by batch entry in [0, batch_size).
AOCL_BGEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"s8s8s32obf16", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-batch strides and pack/reorder tags. These are VLAs sized by
// batch_size; NOTE(review): a very large batch_size could overflow the
// stack -- confirm callers bound it.
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
int8_t *a_local[batch_size];
int8_t *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform s8s8s32obf16 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_s8s8s32obf16",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
// Column-major input is mapped onto the row-major kernel by swapping the
// roles of A and B (and of m and n) below, i.e. computing C transposed.
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
// Column major support disabled for int API's till micro-kernel
// post-ops are updated to account for column major.
if (post_op_unparsed[bs_i] != NULL )
{
bli_print_msg("Column major inputs not supported with Post-ops.",
__FILE__, __LINE__);
goto err_hndl;
}
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (int8_t*)(b[bs_i]);
b_local[bs_i] = (int8_t*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (int8_t*)(a[bs_i]);
b_local[bs_i] = (int8_t*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function's point of view, the B matrix must be in a
// kernel-specific packed layout before it can be loaded by the compute
// kernel, so mtag_b must end up as PACK or REORDERED. An UNPACKED B
// cannot be consumed directly; enable runtime packing here.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
// Same s8s8s32 compute context as the os32 variant; only the output
// downscale type (BF16 below) differs.
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, BF16
);
#else
batch_lpgemm_s8s8s32o32_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, BF16
);
#endif
// Single exit point for all error paths; the logger is stopped on both
// success and failure.
err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batched GEMM entry point: A = int8_t, B = int8_t, accumulation in int32_t,
// C downscaled to uint8_t (suffix s8s8s32ou8). The signature comes from the
// AOCL_BGEMM_MATMUL macro, so every per-matrix parameter (order, transa,
// transb, m, n, k, a, lda, mem_format_a, b, ldb, mem_format_b, c, ldc,
// post_op_unparsed) is an array indexed by batch entry in [0, batch_size).
AOCL_BGEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"s8s8s32ou8", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-batch strides and pack/reorder tags. These are VLAs sized by
// batch_size; NOTE(review): a very large batch_size could overflow the
// stack -- confirm callers bound it.
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
int8_t *a_local[batch_size];
int8_t *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform s8s8s32ou8 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_s8s8s32ou8",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
// Column-major input is mapped onto the row-major kernel by swapping the
// roles of A and B (and of m and n) below, i.e. computing C transposed.
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
// Column major support disabled for int API's till micro-kernel
// post-ops are updated to account for column major.
if (post_op_unparsed[bs_i] != NULL )
{
bli_print_msg("Column major inputs not supported with Post-ops.",
__FILE__, __LINE__);
goto err_hndl;
}
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (int8_t*)(b[bs_i]);
b_local[bs_i] = (int8_t*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (int8_t*)(a[bs_i]);
b_local[bs_i] = (int8_t*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function's point of view, the B matrix must be in a
// kernel-specific packed layout before it can be loaded by the compute
// kernel, so mtag_b must end up as PACK or REORDERED. An UNPACKED B
// cannot be consumed directly; enable runtime packing here.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
// Same s8s8s32 compute context as the os32 variant; only the output
// downscale type (U8 below) differs.
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, U8
);
#else
batch_lpgemm_s8s8s32o32_thread_decorator
(
batch_size, m_local, n_local, k,
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, U8
);
#endif
// Single exit point for all error paths; the logger is stopped on both
// success and failure.
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -409,3 +409,551 @@ err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batched GEMM entry point: A = uint8_t, B = int8_t, accumulation in
// int32_t, C downscaled to float (suffix u8s8s32of32). The signature comes
// from the AOCL_BGEMM_MATMUL macro, so every per-matrix parameter is an
// array indexed by batch entry in [0, batch_size). Unlike the s8s8
// variants, column-major input is rejected outright here.
AOCL_BGEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"u8s8s32of32", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-batch strides and pack/reorder tags. These are VLAs sized by
// batch_size; NOTE(review): a very large batch_size could overflow the
// stack -- confirm callers bound it.
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform u8s8s32of32 gemm.", __FILE__, __LINE__ );
goto err_hndl; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_u8s8s32of32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// A transposed input is not directly consumable by the row-major
// kernel, so force runtime packing of A in that case.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function's point of view, the B matrix must be in a
// kernel-specific packed layout before it can be loaded by the compute
// kernel, so mtag_b must end up as PACK or REORDERED. An UNPACKED B
// cannot be consumed directly; enable runtime packing here.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#else
batch_lpgemm_u8s8s32o32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#endif
// Single exit point for all error paths; the logger is stopped on both
// success and failure.
err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batched GEMM entry point: A = uint8_t, B = int8_t, accumulation in
// int32_t, C downscaled to bfloat16 (suffix u8s8s32obf16). The signature
// comes from the AOCL_BGEMM_MATMUL macro, so every per-matrix parameter is
// an array indexed by batch entry in [0, batch_size). Column-major input is
// rejected outright here.
AOCL_BGEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"u8s8s32obf16", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-batch strides and pack/reorder tags. These are VLAs sized by
// batch_size; NOTE(review): a very large batch_size could overflow the
// stack -- confirm callers bound it.
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform u8s8s32obf16 gemm.", __FILE__, __LINE__ );
goto err_hndl; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_u8s8s32obf16",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// A transposed input is not directly consumable by the row-major
// kernel, so force runtime packing of A in that case.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function's point of view, the B matrix must be in a
// kernel-specific packed layout before it can be loaded by the compute
// kernel, so mtag_b must end up as PACK or REORDERED. An UNPACKED B
// cannot be consumed directly; enable runtime packing here.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
// Same u8s8s32 compute context as the os32 variant; only the output
// downscale type (BF16 below) differs.
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, BF16
);
#else
batch_lpgemm_u8s8s32o32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, BF16
);
#endif
// Single exit point for all error paths; the logger is stopped on both
// success and failure.
err_hndl:;
LPGEMM_STOP_LOGGER();
}
// Batched GEMM entry point: A = uint8_t, B = int8_t, accumulation in
// int32_t, C downscaled to uint8_t (suffix u8s8s32ou8). The signature comes
// from the AOCL_BGEMM_MATMUL macro, so every per-matrix parameter is an
// array indexed by batch entry in [0, batch_size). Column-major input is
// rejected outright here.
AOCL_BGEMM_MATMUL(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
{
LPGEMM_START_LOGGER();
BATCH_LPGEMM_WRITE_LOGGER \
(
"u8s8s32ou8", \
order, transa, transb, \
batch_size, m, n, k, \
alpha, \
lda, mem_format_a, \
ldb, mem_format_b, \
beta, \
ldc, post_op_unparsed \
);
// Per-batch strides and pack/reorder tags. These are VLAs sized by
// batch_size; NOTE(review): a very large batch_size could overflow the
// stack -- confirm callers bound it.
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
{
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
"cannot perform u8s8s32ou8 gemm.", __FILE__, __LINE__ );
goto err_hndl; // Error.
}
/* Initialize BLIS. */
bli_init_auto();
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
trans_t blis_transa;
trans_t blis_transb;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_u8s8s32ou8",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// A transposed input is not directly consumable by the row-major
// kernel, so force runtime packing of A in that case.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From the 5-loop function's point of view, the B matrix must be in a
// kernel-specific packed layout before it can be loaded by the compute
// kernel, so mtag_b must end up as PACK or REORDERED. An UNPACKED B
// cannot be consumed directly; enable runtime packing here.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
// Same u8s8s32 compute context as the os32 variant; only the output
// downscale type (U8 below) differs.
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, U8
);
#else
batch_lpgemm_u8s8s32o32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(int32_t**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, U8
);
#endif
// Single exit point for all error paths; the logger is stopped on both
// success and failure.
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -173,15 +173,25 @@ BLIS_EXPORT_ADDON void aocl_batch_gemm_ ## LP_SFX \
aocl_post_op** post_op_unparsed \
) \
// Batch GEMM API declarations. Each line expands AOCL_BGEMM_MATMUL
// (A_type, B_type, C_type, ACCUM_type, suffix) into a prototype for
// aocl_batch_gemm_<suffix>.
// NOTE(review): several suffixes appear twice in this list
// (f32f32f32of32, u8s8s32os32, u8s8s32os8, s8s8s32os32, s8s8s32os8).
// Duplicate declarations are legal C, but this looks like diff-render
// residue -- confirm the intended final list against the repository.
// bf16 APIs
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32);
AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16);
// f32 APIs
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32);
// u8s8 APIs
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16);
AOCL_BGEMM_MATMUL(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8);
// s8s8 APIs
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
AOCL_BGEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
AOCL_BGEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
AOCL_BGEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);

File diff suppressed because it is too large Load Diff

View File

@@ -215,36 +215,6 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
/*
 * Reference downscale used by the accuracy checker: for one accumulator
 * value at column j it computes
 *   out = clip( nearbyintf( accum * scale_factor[j] ) + zero_point[j],
 *               DSCALE_CLIP_MIN, DSCALE_CLIP_MAX )
 * Length-1 scale_factor / zero_point vectors broadcast across all columns
 * (index 0 is used for every j). Comments are kept outside the macro body
 * because `//` inside a backslash-continued macro would swallow the
 * continuation.
 */
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
(\
ACCUM_type temp_accum,\
aocl_post_op* post_op, \
dim_t j \
)\
{ \
dim_t j_scale = j; \
if ( ( post_op->sum )->scale_factor_len == 1 ) \
{ \
j_scale = 0; \
} \
\
dim_t j_zp = j; \
if ( ( post_op->sum )->zero_point_len == 1 ) \
{ \
j_zp = 0; \
} \
\
ACCUM_type out_temp_accum = \
( ACCUM_type )min( \
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
*( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
return out_temp_accum; \
}\
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
@@ -257,92 +227,6 @@ GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
/* Scalar reference downscale for bf16bf16f32obf16:
 * out = temp_accum * scale[j] + zero_point[j], where the zero point is
 * stored as bf16 and upconverted to float before use. Length-1
 * scale/zero-point arrays are broadcast across all columns. */
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
(
  float temp_accum,
  aocl_post_op* post_op,
  dim_t j
)
{
  /* Collapse the column index when the post-op value is a scalar. */
  dim_t scale_idx = ( ( post_op->sum )->scale_factor_len == 1 ) ? 0 : j;
  dim_t zp_idx = ( ( post_op->sum )->zero_point_len == 1 ) ? 0 : j;

  float zp_f32 = 0.0;
  bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + zp_idx ),
                     &zp_f32 );

  float scale = *( ( float* )( post_op->sum )->scale_factor + scale_idx );
  return ( temp_accum * scale ) + zp_f32;
}
/* Scalar reference downscale for f32f32f32of32:
 * out = temp_accum * scale[j] + zero_point[j], all in float.
 * Length-1 scale/zero-point arrays are broadcast across all columns. */
static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
(
  float temp_accum,
  aocl_post_op* post_op,
  dim_t j
)
{
  /* Collapse the column index when the post-op value is a scalar. */
  dim_t scale_idx = ( ( post_op->sum )->scale_factor_len == 1 ) ? 0 : j;
  dim_t zp_idx = ( ( post_op->sum )->zero_point_len == 1 ) ? 0 : j;

  float scale = *( ( float* )( post_op->sum )->scale_factor + scale_idx );
  float zp = *( ( float* )( post_op->sum )->zero_point + zp_idx );

  return ( temp_accum * scale ) + zp;
}
/* Generates a scalar reference GEMM accumulation routine:
 * acc = sum_p A(i,p) * B(p,j); result = beta * C_ref(i,j) + alpha * acc.
 * temp_accum is reset to 0 before the dot product. pre_op_ld and pre_op
 * are accepted (and ignored) only to keep a uniform signature with the
 * pre-op-enabled variants. */
#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
(\
  A_type* a, \
  B_type* b, \
  C_type* c_ref, \
  ACCUM_type temp_accum,\
  ACCUM_type alpha, \
  ACCUM_type beta, \
  dim_t rs_a, \
  dim_t rs_b, \
  dim_t cs_a, \
  dim_t cs_b, \
  dim_t rs_c_ref, \
  dim_t cs_c_ref, \
  dim_t i, \
  dim_t j, \
  dim_t k, \
  dim_t pre_op_ld, /* Ignored */ \
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
) \
{ \
  ( void ) pre_op; \
  temp_accum = (ACCUM_type) 0; \
  for ( dim_t p = 0; p < k; ++p) \
  { \
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \
            *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \
  } \
 \
  temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \
          + ( alpha * temp_accum ); \
  return temp_accum; \
} \
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
@@ -353,370 +237,6 @@ GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
/* Scalar reference accumulation for u8s8s32obf16: integer dot product of a
 * u8 A row with an s8 B column, then beta * C_ref(i,j) + alpha * acc, with
 * the bf16 C reference upconverted to float for the beta term.
 * NOTE(review): unlike the GEN_MAT_MUL_ACC_CHK_ACCUM variants, temp_accum
 * is not reset to 0 here — presumably the caller passes 0; confirm. The
 * float beta/alpha result is truncated back to int32_t on return. */
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32obf16
(
  uint8_t* a,
  int8_t* b,
  bfloat16* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float;
  bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for s8s8s32obf16: integer dot product of an
 * s8 A row with an s8 B column, then beta * C_ref(i,j) + alpha * acc, with
 * the bf16 C reference upconverted to float for the beta term.
 * NOTE(review): temp_accum is not reset to 0 here (macro-generated variants
 * reset it) — presumably the caller passes 0; confirm. The float result is
 * truncated back to int32_t on return. */
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32obf16
(
  int8_t* a,
  int8_t* b,
  bfloat16* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float;
  bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for u8s8s32of32: integer dot product of a
 * u8 A row with an s8 B column, then beta * C_ref(i,j) (f32) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here, and the float
 * beta/alpha result is truncated back to int32_t on return even though the
 * output type is f32 — verify both are intended for this path. */
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32of32
(
  uint8_t* a,
  int8_t* b,
  float* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for s8s8s32of32: integer dot product of an
 * s8 A row with an s8 B column, then beta * C_ref(i,j) (f32) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here, and the float
 * beta/alpha result is truncated back to int32_t on return even though the
 * output type is f32 — verify both are intended for this path. */
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32of32
(
  int8_t* a,
  int8_t* b,
  float* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for bf16bf16f32of32: bf16 A and B elements
 * are upconverted to float, accumulated in float, then
 * beta * C_ref(i,j) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32
(
  bfloat16* a,
  bfloat16* b,
  float* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */ \
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  ( void ) pre_op;
  for ( dim_t p = 0; p < k; ++p)
  {
    /* Upconvert each bf16 operand before the multiply-accumulate. */
    float a_float, b_float;
    bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
    bfloat16_to_float( *( b + p * rs_b + j * cs_b ) , &b_float);
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
          + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for bf16bf16f32obf16: bf16 A and B elements
 * are upconverted to float, accumulated in float, then
 * beta * C_ref(i,j) + alpha * acc with the bf16 C reference upconverted.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16
(
  bfloat16* a,
  bfloat16* b,
  bfloat16* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  ( void ) pre_op;
  for ( dim_t p = 0; p < k; ++p)
  {
    /* Upconvert each bf16 operand before the multiply-accumulate. */
    float a_float, b_float;
    bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
    bfloat16_to_float( *( b + p*rs_b + j*cs_b ), &b_float );
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  float c_ref_float;
  bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
  temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Unpacks one int4 element of the packed B matrix (two nibbles per byte) at
 * linear element offset b_inc, sign-extends it to int8_t, and — when
 * pre-ops are configured — applies the group-wise zero point and scale
 * factor (f32 or bf16) before returning the value as float.
 * n is the pre-op leading dimension used for group-wise indexing. */
static inline float get_s4_to_f32_scale_val
(
  int8_t* b,
  dim_t p,
  dim_t j,
  dim_t n,
  dim_t b_inc,
  aocl_pre_op* pre_op
)
{
  float b_float = 0.0;
  int8_t b_val = 0;
  dim_t group_size = pre_op->group_size;
  /* Even index will have data at low 4 bits, and odd at hi 4 bits.
   * B matrix increments has to be halved to account for 4 bit
   * traversal. */
  if ( ( b_inc % 2 ) != 0 )
  {
    b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F;
  }
  else
  {
    b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F;
  }
  /* Signed scale. */
  if ( b_val & 0x08 )
  {
    /* Sign-extend the 4-bit value into the full int8_t. */
    b_val = b_val | 0xF0;
  }
  if ( ( pre_op != NULL ) && ( pre_op->seq_length > 0 ) )
  {
    /* Group-wise indexing: each group of group_size k-values shares one
     * zero point / scale factor; a *_len of 1 collapses the j index. */
    dim_t j_zp=0, j_scale=0;
    if(group_size!=0)
    {
      j_zp = ( ( p / group_size ) * n ) + j;
      if ( ( pre_op->b_zp != NULL ) &&
         ( ( pre_op->b_zp )->zero_point_len == 1 ) )
      {
        j_zp = p / group_size;
      }
      j_scale = ( ( p / group_size ) * n ) + j;
      if ( ( pre_op->b_scl != NULL ) &&
         ( ( pre_op->b_scl )->scale_factor_len == 1 ) )
      {
        j_scale = (p / group_size);
      }
    }
    // Assuming only 1 scale and zp.
    int8_t zp = 0;
    if ( ( pre_op->b_zp != NULL ) &&
       ( ( pre_op->b_zp )->zero_point != NULL ) )
    {
      zp = *( ( int8_t* )( pre_op->b_zp )->zero_point + j_zp );
    }
    float scale_factor = 1.0;
    if ( ( pre_op->b_scl != NULL ) &&
       ( ( pre_op->b_scl )->scale_factor != NULL ) )
    {
      /* Scale factors may be stored as f32 or bf16. */
      if( pre_op->b_scl->scale_factor_type == AOCL_GEMM_F32 )
      {
        scale_factor = *( ( float* )( pre_op->b_scl )->scale_factor + j_scale );
      }
      else
      {
        bfloat16_to_float( *( ( bfloat16* )( pre_op->b_scl )->scale_factor + j_scale ) , &scale_factor);
      }
    }
    b_float = (float)( b_val - zp ) * scale_factor;
  }
  else
  {
    b_float = (float)( b_val);
  }
  return b_float;
}
/* Scalar reference accumulation for bf16s4f32of32: bf16 A upconverted to
 * float, int4 B unpacked/dequantized via get_s4_to_f32_scale_val (pre_op_ld
 * is its leading dimension), then beta * C_ref(i,j) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16s4f32of32
(
  bfloat16* a,
  int8_t* b,
  float* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld,
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  for ( dim_t p = 0; p < k; ++p)
  {
    float a_float, b_float;
    bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
    /* Get B matrix int4_t value and upscale it to float. */
    dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
    b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
          + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for bf16s4f32obf16: bf16 A upconverted to
 * float, int4 B unpacked/dequantized via get_s4_to_f32_scale_val, then
 * beta * C_ref(i,j) + alpha * acc with the bf16 C reference upconverted.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
(
  bfloat16* a,
  int8_t* b,
  bfloat16* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld,
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  for ( dim_t p = 0; p < k; ++p)
  {
    float a_float, b_float;
    bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
    /* Get B matrix int4_t value and upscale it to float. */
    dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
    b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  float c_ref_float;
  bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
  temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
  return temp_accum;
}
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os8)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32ou8)
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
@@ -873,8 +393,8 @@ GEN_CLIP_POST_OP_VAL_INT(s8s8s32ou8)
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32obf16)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32obf16)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32ou8)
@@ -1176,8 +696,8 @@ GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,s8s8s32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,s8s8s32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16bf16f32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,float,bf16bf16f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,float,bf16bf16f32obf16)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16s4f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16s4f32obf16)

View File

@@ -469,9 +469,8 @@ static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \
float alpha_val; \
int32_t_to_float(*( ( int32_t* )alpha), &alpha_val); \
float swish_reference = ( temp_accum / ( 1 + \
expf( ( double )(alpha_val) * temp_accum * -1 ) ) ); \
temp_accum = round (swish_reference); \
return temp_accum; \
expf( ( double )((alpha_val) * temp_accum * -1 )) ) ); \
return swish_reference; \
} \
#define GEN_SWISH_POSTOP_FLOAT(BLAS_SFX) \
@@ -766,11 +765,493 @@ void print_matrix_bfloat16
{
float temp;
bfloat16_to_float(*(a + i*(rs_a) + j *cs_a), &temp);
printf("%f ", temp);
}
printf("\n");
}
}
/* Generates a scalar reference downscale routine for integer-output GEMM
 * APIs: out = clip( nearbyintf( acc * scale[j] ) + zero_point[j] ),
 * clamped to [DSCALE_CLIP_MIN, DSCALE_CLIP_MAX].
 * A scale_factor_len / zero_point_len of 1 broadcasts that single value
 * across all columns (j index collapses to 0). */
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
(\
  ACCUM_type temp_accum,\
  aocl_post_op* post_op, \
  dim_t j \
)\
{ \
  dim_t j_scale = j; \
  if ( ( post_op->sum )->scale_factor_len == 1 ) \
  { \
    j_scale = 0; \
  } \
 \
  dim_t j_zp = j; \
  if ( ( post_op->sum )->zero_point_len == 1 ) \
  { \
    j_zp = 0; \
  } \
 \
  ACCUM_type out_temp_accum = \
    ( ACCUM_type )min( \
        max( nearbyintf( ( SCALE_type )( temp_accum ) * \
          ( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
          *( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
        DSCALE_CLIP_MIN ), \
      DSCALE_CLIP_MAX ); \
  return out_temp_accum; \
}\
/* Scalar reference downscale for bf16bf16f32obf16:
 * out = temp_accum * scale[j] + zero_point[j], where the zero point is
 * stored as bf16 and upconverted to float before use. Length-1
 * scale/zero-point arrays are broadcast across all columns. */
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
(
  float temp_accum,
  aocl_post_op* post_op,
  dim_t j
)
{
  /* Collapse the column index when the post-op value is a scalar. */
  dim_t scale_idx = ( ( post_op->sum )->scale_factor_len == 1 ) ? 0 : j;
  dim_t zp_idx = ( ( post_op->sum )->zero_point_len == 1 ) ? 0 : j;

  float zp_f32 = 0.0;
  bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + zp_idx ),
                     &zp_f32 );

  float scale = *( ( float* )( post_op->sum )->scale_factor + scale_idx );
  return ( temp_accum * scale ) + zp_f32;
}
/* Scalar reference downscale for f32f32f32of32:
 * out = temp_accum * scale[j] + zero_point[j], all in float.
 * Length-1 scale/zero-point arrays are broadcast across all columns. */
static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
(
  float temp_accum,
  aocl_post_op* post_op,
  dim_t j
)
{
  /* Collapse the column index when the post-op value is a scalar. */
  dim_t scale_idx = ( ( post_op->sum )->scale_factor_len == 1 ) ? 0 : j;
  dim_t zp_idx = ( ( post_op->sum )->zero_point_len == 1 ) ? 0 : j;

  float scale = *( ( float* )( post_op->sum )->scale_factor + scale_idx );
  float zp = *( ( float* )( post_op->sum )->zero_point + zp_idx );

  return ( temp_accum * scale ) + zp;
}
/* Generates a scalar reference GEMM accumulation routine:
 * acc = sum_p A(i,p) * B(p,j); result = beta * C_ref(i,j) + alpha * acc.
 * temp_accum is reset to 0 before the dot product. pre_op_ld and pre_op
 * are accepted (and ignored) only to keep a uniform signature with the
 * pre-op-enabled variants. */
#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
(\
  A_type* a, \
  B_type* b, \
  C_type* c_ref, \
  ACCUM_type temp_accum,\
  ACCUM_type alpha, \
  ACCUM_type beta, \
  dim_t rs_a, \
  dim_t rs_b, \
  dim_t cs_a, \
  dim_t cs_b, \
  dim_t rs_c_ref, \
  dim_t cs_c_ref, \
  dim_t i, \
  dim_t j, \
  dim_t k, \
  dim_t pre_op_ld, /* Ignored */ \
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
) \
{ \
  ( void ) pre_op; \
  temp_accum = (ACCUM_type) 0; \
  for ( dim_t p = 0; p < k; ++p) \
  { \
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \
            *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \
  } \
 \
  temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \
          + ( alpha * temp_accum ); \
  return temp_accum; \
} \
/* Scalar reference accumulation for u8s8s32obf16: integer dot product of a
 * u8 A row with an s8 B column, then beta * C_ref(i,j) + alpha * acc, with
 * the bf16 C reference upconverted to float for the beta term.
 * NOTE(review): unlike the GEN_MAT_MUL_ACC_CHK_ACCUM variants, temp_accum
 * is not reset to 0 here — presumably the caller passes 0; confirm. The
 * float beta/alpha result is truncated back to int32_t on return. */
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32obf16
(
  uint8_t* a,
  int8_t* b,
  bfloat16* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float;
  bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for s8s8s32obf16: integer dot product of an
 * s8 A row with an s8 B column, then beta * C_ref(i,j) + alpha * acc, with
 * the bf16 C reference upconverted to float for the beta term.
 * NOTE(review): temp_accum is not reset to 0 here (macro-generated variants
 * reset it) — presumably the caller passes 0; confirm. The float result is
 * truncated back to int32_t on return. */
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32obf16
(
  int8_t* a,
  int8_t* b,
  bfloat16* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float;
  bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for u8s8s32of32: integer dot product of a
 * u8 A row with an s8 B column, then beta * C_ref(i,j) (f32) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here, and the float
 * beta/alpha result is truncated back to int32_t on return even though the
 * output type is f32 — verify both are intended for this path. */
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32of32
(
  uint8_t* a,
  int8_t* b,
  float* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for s8s8s32of32: integer dot product of an
 * s8 A row with an s8 B column, then beta * C_ref(i,j) (f32) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here, and the float
 * beta/alpha result is truncated back to int32_t on return even though the
 * output type is f32 — verify both are intended for this path. */
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32of32
(
  int8_t* a,
  int8_t* b,
  float* c_ref,
  int32_t temp_accum,
  int32_t alpha,
  int32_t beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
)
{
  ( void ) pre_op;
  /* Integer dot product over the k dimension. */
  for ( dim_t p = 0; p < k; ++p)
  {
    temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
            *( b + ( rs_b * p ) + ( cs_b * j ) ) );
  }
  float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
  temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for bf16bf16f32of32: bf16 A and B elements
 * are upconverted to float, accumulated in float, then
 * beta * C_ref(i,j) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32
(
  bfloat16* a,
  bfloat16* b,
  float* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */ \
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  ( void ) pre_op;
  for ( dim_t p = 0; p < k; ++p)
  {
    /* Upconvert each bf16 operand before the multiply-accumulate. */
    float a_float, b_float;
    bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
    bfloat16_to_float( *( b + p * rs_b + j * cs_b ) , &b_float);
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
          + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for bf16bf16f32obf16: bf16 A and B elements
 * are upconverted to float, accumulated in float, then
 * beta * C_ref(i,j) + alpha * acc with the bf16 C reference upconverted.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16
(
  bfloat16* a,
  bfloat16* b,
  bfloat16* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld, /* Ignored */
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  ( void ) pre_op;
  for ( dim_t p = 0; p < k; ++p)
  {
    /* Upconvert each bf16 operand before the multiply-accumulate. */
    float a_float, b_float;
    bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
    bfloat16_to_float( *( b + p*rs_b + j*cs_b ), &b_float );
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  float c_ref_float;
  bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
  temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
  return temp_accum;
}
/* Unpacks one int4 element of the packed B matrix (two nibbles per byte) at
 * linear element offset b_inc, sign-extends it to int8_t, and — when
 * pre-ops are configured — applies the group-wise zero point and scale
 * factor (f32 or bf16) before returning the value as float.
 * n is the pre-op leading dimension used for group-wise indexing. */
static inline float get_s4_to_f32_scale_val
(
  int8_t* b,
  dim_t p,
  dim_t j,
  dim_t n,
  dim_t b_inc,
  aocl_pre_op* pre_op
)
{
  float b_float = 0.0;
  int8_t b_val = 0;
  dim_t group_size = pre_op->group_size;
  /* Even index will have data at low 4 bits, and odd at hi 4 bits.
   * B matrix increments has to be halved to account for 4 bit
   * traversal. */
  if ( ( b_inc % 2 ) != 0 )
  {
    b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F;
  }
  else
  {
    b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F;
  }
  /* Signed scale. */
  if ( b_val & 0x08 )
  {
    /* Sign-extend the 4-bit value into the full int8_t. */
    b_val = b_val | 0xF0;
  }
  if ( ( pre_op != NULL ) && ( pre_op->seq_length > 0 ) )
  {
    /* Group-wise indexing: each group of group_size k-values shares one
     * zero point / scale factor; a *_len of 1 collapses the j index. */
    dim_t j_zp=0, j_scale=0;
    if(group_size!=0)
    {
      j_zp = ( ( p / group_size ) * n ) + j;
      if ( ( pre_op->b_zp != NULL ) &&
         ( ( pre_op->b_zp )->zero_point_len == 1 ) )
      {
        j_zp = p / group_size;
      }
      j_scale = ( ( p / group_size ) * n ) + j;
      if ( ( pre_op->b_scl != NULL ) &&
         ( ( pre_op->b_scl )->scale_factor_len == 1 ) )
      {
        j_scale = (p / group_size);
      }
    }
    // Assuming only 1 scale and zp.
    int8_t zp = 0;
    if ( ( pre_op->b_zp != NULL ) &&
       ( ( pre_op->b_zp )->zero_point != NULL ) )
    {
      zp = *( ( int8_t* )( pre_op->b_zp )->zero_point + j_zp );
    }
    float scale_factor = 1.0;
    if ( ( pre_op->b_scl != NULL ) &&
       ( ( pre_op->b_scl )->scale_factor != NULL ) )
    {
      /* Scale factors may be stored as f32 or bf16. */
      if( pre_op->b_scl->scale_factor_type == AOCL_GEMM_F32 )
      {
        scale_factor = *( ( float* )( pre_op->b_scl )->scale_factor + j_scale );
      }
      else
      {
        bfloat16_to_float( *( ( bfloat16* )( pre_op->b_scl )->scale_factor + j_scale ) , &scale_factor);
      }
    }
    b_float = (float)( b_val - zp ) * scale_factor;
  }
  else
  {
    b_float = (float)( b_val);
  }
  return b_float;
}
/* Scalar reference accumulation for bf16s4f32of32: bf16 A upconverted to
 * float, int4 B unpacked/dequantized via get_s4_to_f32_scale_val (pre_op_ld
 * is its leading dimension), then beta * C_ref(i,j) + alpha * acc.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16s4f32of32
(
  bfloat16* a,
  int8_t* b,
  float* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld,
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  for ( dim_t p = 0; p < k; ++p)
  {
    float a_float, b_float;
    bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
    /* Get B matrix int4_t value and upscale it to float. */
    dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
    b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
          + ( alpha * temp_accum );
  return temp_accum;
}
/* Scalar reference accumulation for bf16s4f32obf16: bf16 A upconverted to
 * float, int4 B unpacked/dequantized via get_s4_to_f32_scale_val, then
 * beta * C_ref(i,j) + alpha * acc with the bf16 C reference upconverted.
 * NOTE(review): temp_accum is not reset to 0 here — presumably the caller
 * passes 0; confirm. */
static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
(
  bfloat16* a,
  int8_t* b,
  bfloat16* c_ref,
  float temp_accum,
  float alpha,
  float beta,
  dim_t rs_a,
  dim_t rs_b,
  dim_t cs_a,
  dim_t cs_b,
  dim_t rs_c_ref,
  dim_t cs_c_ref,
  dim_t i,
  dim_t j,
  dim_t k,
  dim_t pre_op_ld,
  aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
)
{
  for ( dim_t p = 0; p < k; ++p)
  {
    float a_float, b_float;
    bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
    /* Get B matrix int4_t value and upscale it to float. */
    dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
    b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
    temp_accum += ( ( a_float ) * ( b_float ) );
  }
  float c_ref_float;
  bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
  temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
  return temp_accum;
}
#define GEN_MAT_MUL_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BIAS_type,BLAS_SFX) \
static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \