mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Added new Int8 batch_gemm APIs
Details: - Added u8s8s32of32|bf16|u8 batch_gemm APIs. - Fixed some bugs in bench file for bf16 API. Change-Id: I55380238869350a848f2deec0641d7b9b416b192
This commit is contained in:
@@ -531,3 +531,735 @@ err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
// Batch GEMM entry point for s8 x s8 -> s32 accumulate, f32 output.
// Validates every batch entry, derives per-entry strides/memory tags
// (handling column major via the implicit C^T = B^T * A^T swap), converts
// post-ops to linked-list form, and dispatches to the s8s8s32o32 batch
// thread decorator with an F32 downconvert tag.
AOCL_BGEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32)
{
	LPGEMM_START_LOGGER();
	BATCH_LPGEMM_WRITE_LOGGER \
	(
	  "s8s8s32of32", \
	  order, transa, transb, \
	  batch_size, m, n, k, \
	  alpha, \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  beta, \
	  ldc, post_op_unparsed \
	);

	// Per-entry row/column strides for A, B and C.
	inc_t rs_a[batch_size];
	inc_t cs_a[batch_size];
	inc_t rs_b[batch_size];
	inc_t cs_b[batch_size];
	inc_t rs_c[batch_size];
	inc_t cs_c[batch_size];

	// Per-entry memory tags (unpacked / packed / reordered).
	AOCL_MEMORY_TAG mtag_a[batch_size];
	AOCL_MEMORY_TAG mtag_b[batch_size];

	// Effective operands and dimensions after the column-major swap.
	int8_t* a_use[batch_size];
	int8_t* b_use[batch_size];
	dim_t m_use[batch_size], n_use[batch_size];

	// Post-ops converted from the unparsed struct to linked-list format.
	lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];

	// lpgemm int8 matmul only works with the avx512_vnni ISA.
	if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
	{
		bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
				"cannot perform s8s8s32of32 gemm.", __FILE__, __LINE__ );
		goto err_hndl;
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	trans_t blis_transa;
	trans_t blis_transb;

	int err_no = 0;

	for ( dim_t i = 0; i < batch_size; i++ )
	{
		// Validate this batch entry's parameters.
		AOCL_BATCH_GEMM_CHECK
		(
		  "batch_s8s8s32of32",
		  order[i], transa[i], transb[i],
		  i,
		  m[i], n[i], k[i],
		  a[i], lda[i], mem_format_a[i],
		  b[i], ldb[i], mem_format_b[i],
		  c[i], ldc[i],
		  err_no
		);

		if ( err_no != 0 )
		{
			goto err_hndl;
		}

		/* Map BLAS chars to their corresponding BLIS enumerated type value. */
		bli_param_map_netlib_to_blis_trans( transa[i], &blis_transa );
		bli_param_map_netlib_to_blis_trans( transb[i], &blis_transb );

		bool col_major = ( ( order[i] == 'c' ) || ( order[i] == 'C' ) );

		if ( col_major == TRUE )
		{
			// Post-ops are disabled for column major int8 inputs until the
			// micro-kernel post-ops account for column major.
			if ( post_op_unparsed[i] != NULL )
			{
				bli_print_msg("Column major inputs not supported with Post-ops.",
						__FILE__, __LINE__);
				goto err_hndl;
			}

			// Column major is computed as C^T = B^T * A^T; from the kernel's
			// point of view A is the user's B and vice versa.
			rs_a[i] = ( bli_is_trans( blis_transb ) ) ? 1 : ldb[i];
			cs_a[i] = ( bli_is_trans( blis_transb ) ) ? ldb[i] : 1;

			rs_b[i] = ( bli_is_trans( blis_transa ) ) ? 1 : lda[i];
			cs_b[i] = ( bli_is_trans( blis_transa ) ) ? lda[i] : 1;

			// Memory formats follow the operand swap as well.
			bli_param_map_char_to_lpmtag( mem_format_a[i], &(mtag_b[i]) );
			bli_param_map_char_to_lpmtag( mem_format_b[i], &(mtag_a[i]) );

			// Reorder is not supported for column major matrices.
			if ( ( mtag_b[i] == REORDERED ) || ( mtag_a[i] == REORDERED ) )
			{
				bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
				goto err_hndl;
			}

			// The kernel expects its A operand in row-major storage, so a
			// transposed operand (user's B here) must be packed at runtime.
			if ( bli_is_trans( blis_transb ) )
			{
				mtag_a[i] = PACK;
			}

			// Swap m/n and the a/b pointers for the implicit transpose.
			m_use[i] = n[i];
			n_use[i] = m[i];
			a_use[i] = (int8_t*)(b[i]);
			b_use[i] = (int8_t*)(a[i]);
		}
		else // row-major
		{
			rs_a[i] = ( bli_is_trans( blis_transa ) ) ? 1 : lda[i];
			cs_a[i] = ( bli_is_trans( blis_transa ) ) ? lda[i] : 1;

			rs_b[i] = ( bli_is_trans( blis_transb ) ) ? 1 : ldb[i];
			cs_b[i] = ( bli_is_trans( blis_transb ) ) ? ldb[i] : 1;

			bli_param_map_char_to_lpmtag( mem_format_a[i], &(mtag_a[i]) );
			bli_param_map_char_to_lpmtag( mem_format_b[i], &(mtag_b[i]) );

			// Reorder is not supported for the A matrix.
			if ( mtag_a[i] == REORDERED )
			{
				bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
				goto err_hndl;
			}

			// A transposed A must be packed to the row-major layout the
			// kernel expects.
			if ( bli_is_trans( blis_transa ) )
			{
				mtag_a[i] = PACK;
			}

			m_use[i] = m[i];
			n_use[i] = n[i];
			a_use[i] = (int8_t*)(a[i]);
			b_use[i] = (int8_t*)(b[i]);
		}

		rs_c[i] = ldc[i];
		cs_c[i] = 1;

		// B must always be packed (or pre-reordered) into the layout the
		// int8 VNNI kernels load from; an unpacked B is packed at runtime.
		if ( mtag_b[i] == UNPACKED )
		{
			mtag_b[i] = PACK;
		}

		err_t err = lpgemm_translate_to_post_ops_list
		(
		  post_op_unparsed[i], post_op_list[i],
		  ( void* )c[i], ( void* )( (order + i) ),
		  m[i], n[i]
		);

		if ( err != BLIS_SUCCESS ) goto err_hndl;
	}

	// Local runtime copy initialized from global settings.
	rntm_t rntm_g;
	bli_rntm_init_from_global( &rntm_g );
	bli_pba_rntm_set_pba( &rntm_g );

	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );

#ifdef BLIS_ENABLE_OPENMP
	batch_lpgemm_s8s8s32o32_openmp_thread_decorator
	(
	  batch_size, m_use, n_use, k,
	  (const int8_t**)a_use, rs_a, cs_a, mtag_a,
	  (const int8_t**)b_use, rs_b, cs_b, mtag_b,
	  (int32_t**)c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, F32
	);
#else
	batch_lpgemm_s8s8s32o32_thread_decorator
	(
	  batch_size, m_use, n_use, k,
	  (const int8_t**)a_use, rs_a, cs_a, mtag_a,
	  (const int8_t**)b_use, rs_b, cs_b, mtag_b,
	  (int32_t**)c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, F32
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s32obf16", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
int8_t *a_local[batch_size];
|
||||
int8_t *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform s8s8s32obf16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_s8s8s32obf16",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
// Column major support disabled for int API's till micro-kernel
|
||||
// post-ops are updated to account for column major.
|
||||
if (post_op_unparsed[bs_i] != NULL )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported with Post-ops.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_s8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"s8s8s32ou8", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
int8_t *a_local[batch_size];
|
||||
int8_t *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform s8s8s32ou8 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_s8s8s32ou8",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
// Column major support disabled for int API's till micro-kernel
|
||||
// post-ops are updated to account for column major.
|
||||
if (post_op_unparsed[bs_i] != NULL )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported with Post-ops.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (int8_t*)(a[bs_i]);
|
||||
b_local[bs_i] = (int8_t*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_s8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, U8
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_s8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const int8_t**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, U8
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
|
||||
@@ -409,3 +409,551 @@ err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
// Batch GEMM entry point for u8 x s8 -> s32 accumulate, f32 output.
// Row major only: column major entries are rejected. Derives per-entry
// strides and memory tags, converts post-ops to linked-list form and
// dispatches to the u8s8s32o32 batch thread decorator with an F32 tag.
AOCL_BGEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32)
{
	LPGEMM_START_LOGGER();
	BATCH_LPGEMM_WRITE_LOGGER \
	(
	  "u8s8s32of32", \
	  order, transa, transb, \
	  batch_size, m, n, k, \
	  alpha, \
	  lda, mem_format_a, \
	  ldb, mem_format_b, \
	  beta, \
	  ldc, post_op_unparsed \
	);

	// Per-entry row/column strides for A, B and C.
	inc_t rs_a[batch_size];
	inc_t cs_a[batch_size];
	inc_t rs_b[batch_size];
	inc_t cs_b[batch_size];
	inc_t rs_c[batch_size];
	inc_t cs_c[batch_size];

	// Per-entry memory tags (unpacked / packed / reordered).
	AOCL_MEMORY_TAG mtag_a[batch_size];
	AOCL_MEMORY_TAG mtag_b[batch_size];

	// Post-ops converted from the unparsed struct to linked-list format.
	lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];

	// lpgemm int8 matmul only works with the avx512_vnni ISA.
	if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
	{
		bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
				"cannot perform u8s8s32of32 gemm.", __FILE__, __LINE__ );
		goto err_hndl; // Error.
	}

	/* Initialize BLIS. */
	bli_init_auto();

	// Set MC, NC, KC, NR, MR.
	aocl_lpgemm_init_global_cntx();

	trans_t blis_transa;
	trans_t blis_transb;

	int err_no = 0;

	for ( dim_t i = 0; i < batch_size; i++ )
	{
		// Validate this batch entry's parameters.
		AOCL_BATCH_GEMM_CHECK
		(
		  "batch_u8s8s32of32",
		  order[i], transa[i], transb[i],
		  i,
		  m[i], n[i], k[i],
		  a[i], lda[i], mem_format_a[i],
		  b[i], ldb[i], mem_format_b[i],
		  c[i], ldc[i],
		  err_no
		);

		if ( err_no != 0 )
		{
			goto err_hndl;
		}

		/* Map BLAS chars to their corresponding BLIS enumerated type value. */
		bli_param_map_netlib_to_blis_trans( transa[i], &blis_transa );
		bli_param_map_netlib_to_blis_trans( transb[i], &blis_transb );

		bool col_major = ( ( order[i] == 'c' ) || ( order[i] == 'C' ) );

		// Only row major inputs are supported by this API.
		if ( col_major == TRUE )
		{
			bli_print_msg("Column major inputs not supported.",
					__FILE__, __LINE__);
			goto err_hndl;
		}

		// Row-major stride setup; a transposed operand reads down columns.
		rs_a[i] = ( bli_is_trans( blis_transa ) ) ? 1 : lda[i];
		cs_a[i] = ( bli_is_trans( blis_transa ) ) ? lda[i] : 1;

		rs_b[i] = ( bli_is_trans( blis_transb ) ) ? 1 : ldb[i];
		cs_b[i] = ( bli_is_trans( blis_transb ) ) ? ldb[i] : 1;

		bli_param_map_char_to_lpmtag( mem_format_a[i], &(mtag_a[i]) );
		bli_param_map_char_to_lpmtag( mem_format_b[i], &(mtag_b[i]) );

		// Reorder is not supported for the A matrix.
		if ( mtag_a[i] == REORDERED )
		{
			bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
			goto err_hndl;
		}

		// A transposed A must be packed to the row-major layout the
		// kernel expects.
		if ( bli_is_trans( blis_transa ) )
		{
			mtag_a[i] = PACK;
		}

		rs_c[i] = ldc[i];
		cs_c[i] = 1;

		// B must always be packed (or pre-reordered) into the layout the
		// int8 VNNI kernels load from; an unpacked B is packed at runtime.
		if ( mtag_b[i] == UNPACKED )
		{
			mtag_b[i] = PACK;
		}

		err_t err = lpgemm_translate_to_post_ops_list
		(
		  post_op_unparsed[i], post_op_list[i],
		  ( void* )c[i], ( void* )( (order + i) ),
		  m[i], n[i]
		);

		if ( err != BLIS_SUCCESS ) goto err_hndl;
	}

	// Local runtime copy initialized from global settings.
	rntm_t rntm_g;
	bli_rntm_init_from_global( &rntm_g );
	bli_pba_rntm_set_pba( &rntm_g );

	lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );

#ifdef BLIS_ENABLE_OPENMP
	batch_lpgemm_u8s8s32o32_openmp_thread_decorator
	(
	  batch_size, m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  (int32_t**)c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, F32
	);
#else
	batch_lpgemm_u8s8s32o32_thread_decorator
	(
	  batch_size, m, n, k,
	  a, rs_a, cs_a, mtag_a,
	  b, rs_b, cs_b, mtag_b,
	  (int32_t**)c, rs_c, cs_c,
	  alpha, beta,
	  &rntm_g, lcntx_g,
	  post_op_list, F32
	);
#endif

err_hndl:;
	LPGEMM_STOP_LOGGER();
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"u8s8s32obf16", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform u8s8s32obf16 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_u8s8s32obf16",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_u8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
|
||||
{
|
||||
LPGEMM_START_LOGGER();
|
||||
BATCH_LPGEMM_WRITE_LOGGER \
|
||||
(
|
||||
"u8s8s32ou8", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
alpha, \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
beta, \
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
|
||||
"cannot perform u8s8s32ou8 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl; // Error.
|
||||
}
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_u8s8s32ou8",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_u8s8s32o32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, U8
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_u8s8s32o32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(int32_t**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, U8
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
@@ -173,15 +173,25 @@ BLIS_EXPORT_ADDON void aocl_batch_gemm_ ## LP_SFX \
|
||||
aocl_post_op** post_op_unparsed \
|
||||
) \
|
||||
|
||||
// bf16 APIs
|
||||
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
|
||||
AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
|
||||
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
|
||||
AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32);
|
||||
AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16);
|
||||
// f32 APIs
|
||||
AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32);
|
||||
// u8s8 APIs
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,float,int32_t,u8s8s32of32);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,bfloat16,int32_t,u8s8s32obf16);
|
||||
AOCL_BGEMM_MATMUL(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8);
|
||||
// s8s8 APIs
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
|
||||
AOCL_BGEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);
|
||||
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -215,36 +215,6 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
|
||||
(\
|
||||
ACCUM_type temp_accum,\
|
||||
aocl_post_op* post_op, \
|
||||
dim_t j \
|
||||
)\
|
||||
{ \
|
||||
dim_t j_scale = j; \
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 ) \
|
||||
{ \
|
||||
j_scale = 0; \
|
||||
} \
|
||||
\
|
||||
dim_t j_zp = j; \
|
||||
if ( ( post_op->sum )->zero_point_len == 1 ) \
|
||||
{ \
|
||||
j_zp = 0; \
|
||||
} \
|
||||
\
|
||||
ACCUM_type out_temp_accum = \
|
||||
( ACCUM_type )min( \
|
||||
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
|
||||
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
|
||||
*( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
return out_temp_accum; \
|
||||
}\
|
||||
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int32_t,float,u8s8s32ou8)
|
||||
@@ -257,92 +227,6 @@ GEN_MAT_MUL_ACC_CHK_DOWNSCALE(float,int32_t,float,s8s8s32of32)
|
||||
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(bfloat16,int32_t,float,s8s8s32obf16)
|
||||
|
||||
|
||||
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
|
||||
float zp_float = 0.0;
|
||||
bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
float out_temp_accum = ( temp_accum *
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
|
||||
zp_float );
|
||||
return out_temp_accum;
|
||||
}
|
||||
static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
float out_temp_accum = ( temp_accum *
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
|
||||
*( ( float* )( post_op->sum )->zero_point + j_zp ) );
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
|
||||
(\
|
||||
A_type* a, \
|
||||
B_type* b, \
|
||||
C_type* c_ref, \
|
||||
ACCUM_type temp_accum,\
|
||||
ACCUM_type alpha, \
|
||||
ACCUM_type beta, \
|
||||
dim_t rs_a, \
|
||||
dim_t rs_b, \
|
||||
dim_t cs_a, \
|
||||
dim_t cs_b, \
|
||||
dim_t rs_c_ref, \
|
||||
dim_t cs_c_ref, \
|
||||
dim_t i, \
|
||||
dim_t j, \
|
||||
dim_t k, \
|
||||
dim_t pre_op_ld, /* Ignored */ \
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
) \
|
||||
{ \
|
||||
( void ) pre_op; \
|
||||
temp_accum = (ACCUM_type) 0; \
|
||||
for ( dim_t p = 0; p < k; ++p) \
|
||||
{ \
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) ); \
|
||||
} \
|
||||
\
|
||||
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \
|
||||
+ ( alpha * temp_accum ); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
|
||||
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
|
||||
@@ -353,370 +237,6 @@ GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int32_t,u8s8s32ou8)
|
||||
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
|
||||
|
||||
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32obf16
|
||||
(
|
||||
uint8_t* a,
|
||||
int8_t* b,
|
||||
bfloat16* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float;
|
||||
bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32obf16
|
||||
(
|
||||
int8_t* a,
|
||||
int8_t* b,
|
||||
bfloat16* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float;
|
||||
bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32of32
|
||||
(
|
||||
uint8_t* a,
|
||||
int8_t* b,
|
||||
float* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32of32
|
||||
(
|
||||
int8_t* a,
|
||||
int8_t* b,
|
||||
float* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32
|
||||
(
|
||||
bfloat16* a,
|
||||
bfloat16* b,
|
||||
float* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */ \
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
|
||||
bfloat16_to_float( *( b + p * rs_b + j * cs_b ) , &b_float);
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
|
||||
+ ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16
|
||||
(
|
||||
bfloat16* a,
|
||||
bfloat16* b,
|
||||
bfloat16* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
|
||||
bfloat16_to_float( *( b + p*rs_b + j*cs_b ), &b_float );
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
float c_ref_float;
|
||||
bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
|
||||
temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
|
||||
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float get_s4_to_f32_scale_val
|
||||
(
|
||||
int8_t* b,
|
||||
dim_t p,
|
||||
dim_t j,
|
||||
dim_t n,
|
||||
dim_t b_inc,
|
||||
aocl_pre_op* pre_op
|
||||
)
|
||||
{
|
||||
float b_float = 0.0;
|
||||
int8_t b_val = 0;
|
||||
|
||||
dim_t group_size = pre_op->group_size;
|
||||
|
||||
/* Even index will have data at low 4 bits, and odd at hi 4 bits.
|
||||
* B matrix increments has to be halved to account for 4 bit
|
||||
* traversal. */
|
||||
if ( ( b_inc % 2 ) != 0 )
|
||||
{
|
||||
b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F;
|
||||
}
|
||||
|
||||
/* Signed scale. */
|
||||
if ( b_val & 0x08 )
|
||||
{
|
||||
b_val = b_val | 0xF0;
|
||||
}
|
||||
|
||||
if ( ( pre_op != NULL ) && ( pre_op->seq_length > 0 ) )
|
||||
{
|
||||
dim_t j_zp=0, j_scale=0;
|
||||
if(group_size!=0)
|
||||
{
|
||||
j_zp = ( ( p / group_size ) * n ) + j;
|
||||
if ( ( pre_op->b_zp != NULL ) &&
|
||||
( ( pre_op->b_zp )->zero_point_len == 1 ) )
|
||||
{
|
||||
j_zp = p / group_size;
|
||||
}
|
||||
|
||||
j_scale = ( ( p / group_size ) * n ) + j;
|
||||
if ( ( pre_op->b_scl != NULL ) &&
|
||||
( ( pre_op->b_scl )->scale_factor_len == 1 ) )
|
||||
{
|
||||
j_scale = (p / group_size);
|
||||
}
|
||||
}
|
||||
|
||||
// Assuming only 1 scale and zp.
|
||||
int8_t zp = 0;
|
||||
if ( ( pre_op->b_zp != NULL ) &&
|
||||
( ( pre_op->b_zp )->zero_point != NULL ) )
|
||||
{
|
||||
zp = *( ( int8_t* )( pre_op->b_zp )->zero_point + j_zp );
|
||||
}
|
||||
|
||||
float scale_factor = 1.0;
|
||||
if ( ( pre_op->b_scl != NULL ) &&
|
||||
( ( pre_op->b_scl )->scale_factor != NULL ) )
|
||||
{
|
||||
if( pre_op->b_scl->scale_factor_type == AOCL_GEMM_F32 )
|
||||
{
|
||||
scale_factor = *( ( float* )( pre_op->b_scl )->scale_factor + j_scale );
|
||||
}
|
||||
else
|
||||
{
|
||||
bfloat16_to_float( *( ( bfloat16* )( pre_op->b_scl )->scale_factor + j_scale ) , &scale_factor);
|
||||
}
|
||||
|
||||
}
|
||||
b_float = (float)( b_val - zp ) * scale_factor;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_float = (float)( b_val);
|
||||
}
|
||||
|
||||
return b_float;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16s4f32of32
|
||||
(
|
||||
bfloat16* a,
|
||||
int8_t* b,
|
||||
float* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld,
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
|
||||
|
||||
/* Get B matrix int4_t value and upscale it to float. */
|
||||
dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
|
||||
b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
|
||||
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
|
||||
+ ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
|
||||
(
|
||||
bfloat16* a,
|
||||
int8_t* b,
|
||||
bfloat16* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld,
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
|
||||
|
||||
/* Get B matrix int4_t value and upscale it to float. */
|
||||
dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
|
||||
b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
|
||||
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
float c_ref_float;
|
||||
bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
|
||||
temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
|
||||
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32ou8)
|
||||
GEN_GELU_TANH_POSTOP_FLOAT(u8s8s32os32)
|
||||
@@ -873,8 +393,8 @@ GEN_CLIP_POST_OP_VAL_INT(s8s8s32ou8)
|
||||
GEN_CLIP_POST_OP_VAL_INT(s8s8s32os32)
|
||||
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32obf16)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32obf16)
|
||||
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,u8s8s32ou8)
|
||||
@@ -1176,8 +696,8 @@ GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,uint8_t,float,int32_t,s8s8s32ou8)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,float,float,int32_t,s8s8s32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,bfloat16,float,int32_t,s8s8s32obf16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16bf16f32obf16)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,float,bf16bf16f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,float,bf16bf16f32obf16)
|
||||
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16s4f32of32)
|
||||
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,bfloat16,float,bfloat16,bf16s4f32obf16)
|
||||
|
||||
@@ -469,9 +469,8 @@ static inline ACCUM_type SWISH_post_op_ ## BLAS_SFX \
|
||||
float alpha_val; \
|
||||
int32_t_to_float(*( ( int32_t* )alpha), &alpha_val); \
|
||||
float swish_reference = ( temp_accum / ( 1 + \
|
||||
expf( ( double )(alpha_val) * temp_accum * -1 ) ) ); \
|
||||
temp_accum = round (swish_reference); \
|
||||
return temp_accum; \
|
||||
expf( ( double )((alpha_val) * temp_accum * -1 )) ) ); \
|
||||
return swish_reference; \
|
||||
} \
|
||||
|
||||
#define GEN_SWISH_POSTOP_FLOAT(BLAS_SFX) \
|
||||
@@ -766,11 +765,493 @@ void print_matrix_bfloat16
|
||||
{
|
||||
float temp;
|
||||
bfloat16_to_float(*(a + i*(rs_a) + j *cs_a), &temp);
|
||||
printf("%f ", temp);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
|
||||
(\
|
||||
ACCUM_type temp_accum,\
|
||||
aocl_post_op* post_op, \
|
||||
dim_t j \
|
||||
)\
|
||||
{ \
|
||||
dim_t j_scale = j; \
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 ) \
|
||||
{ \
|
||||
j_scale = 0; \
|
||||
} \
|
||||
\
|
||||
dim_t j_zp = j; \
|
||||
if ( ( post_op->sum )->zero_point_len == 1 ) \
|
||||
{ \
|
||||
j_zp = 0; \
|
||||
} \
|
||||
\
|
||||
ACCUM_type out_temp_accum = \
|
||||
( ACCUM_type )min( \
|
||||
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
|
||||
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
|
||||
*( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
return out_temp_accum; \
|
||||
}\
|
||||
|
||||
|
||||
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
|
||||
float zp_float = 0.0;
|
||||
bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
float out_temp_accum = ( temp_accum *
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
|
||||
zp_float );
|
||||
return out_temp_accum;
|
||||
}
|
||||
static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
)
|
||||
{
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
j_scale = 0;
|
||||
}
|
||||
|
||||
dim_t j_zp = j;
|
||||
if ( ( post_op->sum )->zero_point_len == 1 )
|
||||
{
|
||||
j_zp = 0;
|
||||
}
|
||||
float out_temp_accum = ( temp_accum *
|
||||
( *( ( float* )( post_op->sum )->scale_factor + j_scale ) ) +
|
||||
*( ( float* )( post_op->sum )->zero_point + j_zp ) );
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
|
||||
(\
|
||||
A_type* a, \
|
||||
B_type* b, \
|
||||
C_type* c_ref, \
|
||||
ACCUM_type temp_accum,\
|
||||
ACCUM_type alpha, \
|
||||
ACCUM_type beta, \
|
||||
dim_t rs_a, \
|
||||
dim_t rs_b, \
|
||||
dim_t cs_a, \
|
||||
dim_t cs_b, \
|
||||
dim_t rs_c_ref, \
|
||||
dim_t cs_c_ref, \
|
||||
dim_t i, \
|
||||
dim_t j, \
|
||||
dim_t k, \
|
||||
dim_t pre_op_ld, /* Ignored */ \
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
) \
|
||||
{ \
|
||||
( void ) pre_op; \
|
||||
temp_accum = (ACCUM_type) 0; \
|
||||
for ( dim_t p = 0; p < k; ++p) \
|
||||
{ \
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) ); \
|
||||
} \
|
||||
\
|
||||
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \
|
||||
+ ( alpha * temp_accum ); \
|
||||
return temp_accum; \
|
||||
} \
|
||||
|
||||
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32obf16
|
||||
(
|
||||
uint8_t* a,
|
||||
int8_t* b,
|
||||
bfloat16* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float;
|
||||
bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32obf16
|
||||
(
|
||||
int8_t* a,
|
||||
int8_t* b,
|
||||
bfloat16* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float;
|
||||
bfloat16_to_float(( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ), &c_ref_float);
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline int32_t mat_mul_accuracy_check_accum_u8s8s32of32
|
||||
(
|
||||
uint8_t* a,
|
||||
int8_t* b,
|
||||
float* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline int32_t mat_mul_accuracy_check_accum_s8s8s32of32
|
||||
(
|
||||
int8_t* a,
|
||||
int8_t* b,
|
||||
float* c_ref,
|
||||
int32_t temp_accum,
|
||||
int32_t alpha,
|
||||
int32_t beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) *
|
||||
*( b + ( rs_b * p ) + ( cs_b * j ) ) );
|
||||
}
|
||||
|
||||
float c_ref_float = *(c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) );
|
||||
temp_accum = ( beta * c_ref_float ) + ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32
|
||||
(
|
||||
bfloat16* a,
|
||||
bfloat16* b,
|
||||
float* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */ \
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
|
||||
bfloat16_to_float( *( b + p * rs_b + j * cs_b ) , &b_float);
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
|
||||
+ ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16
|
||||
(
|
||||
bfloat16* a,
|
||||
bfloat16* b,
|
||||
bfloat16* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld, /* Ignored */
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
( void ) pre_op;
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
|
||||
bfloat16_to_float( *( b + p*rs_b + j*cs_b ), &b_float );
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
float c_ref_float;
|
||||
bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
|
||||
temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
|
||||
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
static inline float get_s4_to_f32_scale_val
|
||||
(
|
||||
int8_t* b,
|
||||
dim_t p,
|
||||
dim_t j,
|
||||
dim_t n,
|
||||
dim_t b_inc,
|
||||
aocl_pre_op* pre_op
|
||||
)
|
||||
{
|
||||
float b_float = 0.0;
|
||||
int8_t b_val = 0;
|
||||
|
||||
dim_t group_size = pre_op->group_size;
|
||||
|
||||
/* Even index will have data at low 4 bits, and odd at hi 4 bits.
|
||||
* B matrix increments has to be halved to account for 4 bit
|
||||
* traversal. */
|
||||
if ( ( b_inc % 2 ) != 0 )
|
||||
{
|
||||
b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F;
|
||||
}
|
||||
|
||||
/* Signed scale. */
|
||||
if ( b_val & 0x08 )
|
||||
{
|
||||
b_val = b_val | 0xF0;
|
||||
}
|
||||
|
||||
if ( ( pre_op != NULL ) && ( pre_op->seq_length > 0 ) )
|
||||
{
|
||||
dim_t j_zp=0, j_scale=0;
|
||||
if(group_size!=0)
|
||||
{
|
||||
j_zp = ( ( p / group_size ) * n ) + j;
|
||||
if ( ( pre_op->b_zp != NULL ) &&
|
||||
( ( pre_op->b_zp )->zero_point_len == 1 ) )
|
||||
{
|
||||
j_zp = p / group_size;
|
||||
}
|
||||
|
||||
j_scale = ( ( p / group_size ) * n ) + j;
|
||||
if ( ( pre_op->b_scl != NULL ) &&
|
||||
( ( pre_op->b_scl )->scale_factor_len == 1 ) )
|
||||
{
|
||||
j_scale = (p / group_size);
|
||||
}
|
||||
}
|
||||
|
||||
// Assuming only 1 scale and zp.
|
||||
int8_t zp = 0;
|
||||
if ( ( pre_op->b_zp != NULL ) &&
|
||||
( ( pre_op->b_zp )->zero_point != NULL ) )
|
||||
{
|
||||
zp = *( ( int8_t* )( pre_op->b_zp )->zero_point + j_zp );
|
||||
}
|
||||
|
||||
float scale_factor = 1.0;
|
||||
if ( ( pre_op->b_scl != NULL ) &&
|
||||
( ( pre_op->b_scl )->scale_factor != NULL ) )
|
||||
{
|
||||
if( pre_op->b_scl->scale_factor_type == AOCL_GEMM_F32 )
|
||||
{
|
||||
scale_factor = *( ( float* )( pre_op->b_scl )->scale_factor + j_scale );
|
||||
}
|
||||
else
|
||||
{
|
||||
bfloat16_to_float( *( ( bfloat16* )( pre_op->b_scl )->scale_factor + j_scale ) , &scale_factor);
|
||||
}
|
||||
|
||||
}
|
||||
b_float = (float)( b_val - zp ) * scale_factor;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_float = (float)( b_val);
|
||||
}
|
||||
|
||||
return b_float;
|
||||
}
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16s4f32of32
|
||||
(
|
||||
bfloat16* a,
|
||||
int8_t* b,
|
||||
float* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld,
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float);
|
||||
|
||||
/* Get B matrix int4_t value and upscale it to float. */
|
||||
dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
|
||||
b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
|
||||
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) )
|
||||
+ ( alpha * temp_accum );
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
|
||||
static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
|
||||
(
|
||||
bfloat16* a,
|
||||
int8_t* b,
|
||||
bfloat16* c_ref,
|
||||
float temp_accum,
|
||||
float alpha,
|
||||
float beta,
|
||||
dim_t rs_a,
|
||||
dim_t rs_b,
|
||||
dim_t cs_a,
|
||||
dim_t cs_b,
|
||||
dim_t rs_c_ref,
|
||||
dim_t cs_c_ref,
|
||||
dim_t i,
|
||||
dim_t j,
|
||||
dim_t k,
|
||||
dim_t pre_op_ld,
|
||||
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
|
||||
)
|
||||
{
|
||||
for ( dim_t p = 0; p < k; ++p)
|
||||
{
|
||||
float a_float, b_float;
|
||||
bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float );
|
||||
|
||||
/* Get B matrix int4_t value and upscale it to float. */
|
||||
dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
|
||||
b_float = get_s4_to_f32_scale_val( b, p, j, pre_op_ld, b_inc, pre_op );
|
||||
|
||||
temp_accum += ( ( a_float ) * ( b_float ) );
|
||||
}
|
||||
float c_ref_float;
|
||||
bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float );
|
||||
temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum );
|
||||
|
||||
return temp_accum;
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BIAS_type,BLAS_SFX) \
|
||||
static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
|
||||
Reference in New Issue
Block a user