mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Updated aocl_batch_gemm_ APIs aligning to CBLAS batch API. (#58)
* Updated aocl_batch_gemm_ APIs aligning to CBLAS batch API. - Modified Batch-Gemm API to align with cblas_?gemm_batch_ API, and added a parameter group_size to the existing APIs. - Updated bench batch_gemm code to align to the new API definition. - Modified the hardcoded number in lpgemm_postop file. - Added necessary early return condition to account for group_count/group_size < 0. AMD-Internal: [ SWLCSG - 3592 ]
This commit is contained in:
@@ -50,7 +50,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
(
|
||||
"bf16bf16f32of32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
@@ -58,29 +58,11 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
bfloat16 *a_local[batch_size], *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
// Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
bli_print_msg(" AVX2 ISA not supported by processor, "
|
||||
"cannot perform f32f32f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
@@ -90,194 +72,225 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( get_jit_kernels_generated() == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#endif
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
if( get_jit_kernels_generated() == FALSE )
|
||||
{
|
||||
bli_print_msg(" Could not generate bf16bf16f32of32 "
|
||||
" kernels using JIT.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
// offset to get subsequent matrix when group_count > 1
|
||||
dim_t mat_idx = 0;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
|
||||
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16bf16f32of32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
"batch_bf16bf16f32of32",
|
||||
order[gc_i], transa[gc_i], transb[gc_i],
|
||||
group_count, group_size[gc_i],
|
||||
m[gc_i], n[gc_i], k[gc_i],
|
||||
lda[gc_i], mem_format_a[gc_i],
|
||||
ldb[gc_i], mem_format_b[gc_i],
|
||||
ldc[gc_i],
|
||||
err_no
|
||||
);
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
// Group_size is used across
|
||||
dim_t g_sz = group_size[gc_i];
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
inc_t rs_a[g_sz];
|
||||
inc_t cs_a[g_sz];
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
inc_t rs_b[g_sz];
|
||||
inc_t cs_b[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
inc_t rs_c[g_sz];
|
||||
inc_t cs_c[g_sz];
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
AOCL_MEMORY_TAG mtag_a[g_sz];
|
||||
AOCL_MEMORY_TAG mtag_b[g_sz];
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
bfloat16 *a_local[g_sz], *b_local[g_sz];
|
||||
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (bfloat16*)(b[bs_i]);
|
||||
b_local[bs_i] = (bfloat16*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (bfloat16*)(a[bs_i]);
|
||||
b_local[bs_i] = (bfloat16*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
(
|
||||
post_op_unparsed[gc_i], post_op_list,
|
||||
( void* )c[gc_i], ( void* )( (order + gc_i) ),
|
||||
m[gc_i], n[gc_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
|
||||
|
||||
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
|
||||
{
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[gs_i] = ldb[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = lda[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_b[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_a[gs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[gs_i] == REORDERED ) || ( mtag_a[gs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[gs_i] = n[gc_i];
|
||||
n_local[gs_i] = m[gc_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[gs_i] = lda[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = ldb[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[gs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[gs_i] = m[gc_i];
|
||||
n_local[gs_i] = n[gc_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
|
||||
}
|
||||
|
||||
k_local[gs_i] = k[gc_i];
|
||||
|
||||
rs_c[gs_i] = ldc[gc_i];
|
||||
cs_c[gs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[gs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[gs_i] = PACK;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16bf16f32of32_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
mat_idx+=g_sz;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16bf16f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
@@ -289,7 +302,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
(
|
||||
"bf16bf16f32obf16", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
@@ -297,24 +310,6 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
bfloat16 *a_local[batch_size], *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
@@ -338,25 +333,25 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
}
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
// offset to get subsequent matrix when group_count > 1
|
||||
dim_t mat_idx = 0;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16bf16f32obf16",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
"batch_bf16bf16f32obf16",
|
||||
order[gc_i], transa[gc_i], transb[gc_i],
|
||||
group_count, group_size[gc_i],
|
||||
m[gc_i], n[gc_i], k[gc_i],
|
||||
lda[gc_i], mem_format_a[gc_i],
|
||||
ldb[gc_i], mem_format_b[gc_i],
|
||||
ldc[gc_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
@@ -364,163 +359,190 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
// Group_size is used across
|
||||
dim_t g_sz = group_size[gc_i];
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
inc_t rs_a[g_sz];
|
||||
inc_t cs_a[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
inc_t rs_b[g_sz];
|
||||
inc_t cs_b[g_sz];
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
inc_t rs_c[g_sz];
|
||||
inc_t cs_c[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
AOCL_MEMORY_TAG mtag_a[g_sz];
|
||||
AOCL_MEMORY_TAG mtag_b[g_sz];
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
bfloat16 *a_local[g_sz], *b_local[g_sz];
|
||||
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (bfloat16*)(b[bs_i]);
|
||||
b_local[bs_i] = (bfloat16*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (bfloat16*)(a[bs_i]);
|
||||
b_local[bs_i] = (bfloat16*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
post_op_unparsed[gc_i], post_op_list,
|
||||
( void* )c[gc_i], ( void* )( order + gc_i ),
|
||||
m[gc_i], n[gc_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
|
||||
|
||||
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
|
||||
{
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[gs_i] = ldb[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = lda[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_b[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_a[gs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[gs_i] == REORDERED ) || ( mtag_a[gs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[gs_i] = n[gc_i];
|
||||
n_local[gs_i] = m[gc_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[gs_i] = lda[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = ldb[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[gs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[gs_i] = m[gc_i];
|
||||
n_local[gs_i] = n[gc_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
|
||||
}
|
||||
|
||||
k_local[gs_i] = k[gc_i];
|
||||
|
||||
rs_c[gs_i] = ldc[gc_i];
|
||||
cs_c[gs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[gs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[gs_i] = PACK;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
(float**)&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16bf16f32of32_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
(float**)&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
mat_idx += g_sz;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
(float**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16bf16f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
|
||||
(float**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
(
|
||||
"bf16s4f32of32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
@@ -58,21 +58,6 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
@@ -87,154 +72,191 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
|
||||
return;
|
||||
#endif
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
|
||||
return;
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
// offset to get subsequent matrix when group_count > 1
|
||||
dim_t mat_idx = 0;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16s4f32of32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
"batch_bf16s4f32of32",
|
||||
order[gc_i], transa[gc_i], transb[gc_i],
|
||||
group_count, group_size[gc_i],
|
||||
m[gc_i], n[gc_i], k[gc_i],
|
||||
lda[gc_i], mem_format_a[gc_i],
|
||||
ldb[gc_i], mem_format_b[gc_i],
|
||||
ldc[gc_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
dim_t g_sz = group_size[gc_i];
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
inc_t rs_a[g_sz];
|
||||
inc_t cs_a[g_sz];
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
inc_t rs_b[g_sz];
|
||||
inc_t cs_b[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
inc_t rs_c[g_sz];
|
||||
inc_t cs_c[g_sz];
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
AOCL_MEMORY_TAG mtag_a[g_sz];
|
||||
AOCL_MEMORY_TAG mtag_b[g_sz];
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
|
||||
|
||||
// Convert pre op struct to pre op linked list format.
|
||||
err_t err = lpgemm_translate_to_pre_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i]->pre_ops,
|
||||
pre_op_list[bs_i],
|
||||
m[bs_i], n[bs_i], k[bs_i]
|
||||
);
|
||||
(
|
||||
post_op_unparsed[gc_i]->pre_ops,
|
||||
pre_op_list,
|
||||
m[gc_i], n[gc_i], k[gc_i]
|
||||
);
|
||||
if (err != BLIS_SUCCESS) goto err_hndl;
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
post_op_unparsed[gc_i], post_op_list,
|
||||
( void* )c[gc_i], ( void* )( order + gc_i ),
|
||||
m[gc_i], n[gc_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
bfloat16 *a_local[g_sz];
|
||||
int8_t* b_local[g_sz];
|
||||
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
|
||||
|
||||
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
|
||||
{
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[gs_i] = lda[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = ldb[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[gs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[gs_i] = m[gc_i];
|
||||
n_local[gs_i] = n[gc_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (int8_t*)(b[mat_idx + gs_i]);
|
||||
}
|
||||
|
||||
k_local[gs_i] = k[gc_i];
|
||||
|
||||
rs_c[gs_i] = ldc[gc_i];
|
||||
cs_c[gs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[gs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[gs_i] = PACK;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
mat_idx += g_sz;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
@@ -246,7 +268,7 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
(
|
||||
"bf16s4f32obf16", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
@@ -254,26 +276,11 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];
|
||||
|
||||
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
|
||||
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
|
||||
{
|
||||
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
|
||||
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
|
||||
"cannot perform bf16s4f32 gemm.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
@@ -283,30 +290,30 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
|
||||
return;
|
||||
#endif
|
||||
#ifdef LPGEMM_BF16_JIT
|
||||
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
|
||||
return;
|
||||
#endif
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
// offset to get subsequent matrix when group_count > 1
|
||||
dim_t mat_idx = 0;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16s4f32obf16",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
"batch_bf16s4f32obf16",
|
||||
order[gc_i], transa[gc_i], transb[gc_i],
|
||||
group_count, group_size[gc_i],
|
||||
m[gc_i], n[gc_i], k[gc_i],
|
||||
lda[gc_i], mem_format_a[gc_i],
|
||||
ldb[gc_i], mem_format_b[gc_i],
|
||||
ldc[gc_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
@@ -314,125 +321,177 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
dim_t g_sz = group_size[gc_i];
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
inc_t rs_a[g_sz];
|
||||
inc_t cs_a[g_sz];
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
inc_t rs_b[g_sz];
|
||||
inc_t cs_b[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
inc_t rs_c[g_sz];
|
||||
inc_t cs_c[g_sz];
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
AOCL_MEMORY_TAG mtag_a[g_sz];
|
||||
AOCL_MEMORY_TAG mtag_b[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
bfloat16 *a_local[g_sz];
|
||||
int8_t *b_local[g_sz];
|
||||
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[bs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
|
||||
|
||||
// Convert pre op struct to pre op linked list format.
|
||||
err_t err = lpgemm_translate_to_pre_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i]->pre_ops,
|
||||
pre_op_list[bs_i],
|
||||
m[bs_i], n[bs_i], k[bs_i]
|
||||
);
|
||||
(
|
||||
post_op_unparsed[gc_i]->pre_ops,
|
||||
pre_op_list,
|
||||
m[gc_i], n[gc_i], k[gc_i]
|
||||
);
|
||||
if (err != BLIS_SUCCESS) goto err_hndl;
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
(
|
||||
post_op_unparsed[gc_i], post_op_list,
|
||||
( void* )c[gc_i], ( void* )( order+gc_i ),
|
||||
m[gc_i], n[gc_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
|
||||
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_bf16s4f32obf16",
|
||||
order[gc_i], transa[gc_i], transb[gc_i],
|
||||
group_count, group_size[gc_i],
|
||||
m[gc_i], n[gc_i], k[gc_i],
|
||||
lda[gc_i], mem_format_a[gc_i],
|
||||
ldb[gc_i], mem_format_b[gc_i],
|
||||
ldc[gc_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
{
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
bli_print_msg("Column major inputs not supported.",
|
||||
__FILE__, __LINE__);
|
||||
goto err_hndl;
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[gs_i] = lda[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = ldb[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gs_i], &(mtag_a[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gs_i], &(mtag_b[gs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[gs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[gs_i] = m[gc_i];
|
||||
n_local[gs_i] = n[gc_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (int8_t*)(b[mat_idx + gs_i]);
|
||||
}
|
||||
|
||||
k_local[gs_i] = k[gc_i];
|
||||
|
||||
rs_c[gs_i] = ldc[gc_i];
|
||||
cs_c[gs_i] = 1;
|
||||
|
||||
// From 5-loop function point of view
|
||||
// B matrix needs to be packed in a certain format in order to be loaded
|
||||
// and used in bf16 instrution. As such the mtag_b always needs to be either
|
||||
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
|
||||
// the mtag_b is set to packed to enable runtime packing.
|
||||
if ( mtag_b[gs_i] == UNPACKED )
|
||||
{
|
||||
mtag_b[gs_i] = PACK;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(float**)&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, BF16
|
||||
);
|
||||
#else
|
||||
batch_lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
|
||||
(float**)&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
mat_idx += g_sz;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, BF16
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_bf16s4f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m, n, k,
|
||||
a, rs_a, cs_a, mtag_a,
|
||||
b, rs_b, cs_b, mtag_b,
|
||||
(float**)c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
pre_op_list, post_op_list, BF16
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
(
|
||||
"f32f32f32of32", \
|
||||
order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
( ( float* ) alpha ), \
|
||||
lda, mem_format_a, \
|
||||
ldb, mem_format_b, \
|
||||
@@ -58,27 +58,6 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
ldc, post_op_unparsed \
|
||||
);
|
||||
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
inc_t rs_a[batch_size];
|
||||
inc_t cs_a[batch_size];
|
||||
|
||||
inc_t rs_b[batch_size];
|
||||
inc_t cs_b[batch_size];
|
||||
|
||||
inc_t rs_c[batch_size];
|
||||
inc_t cs_c[batch_size];
|
||||
|
||||
AOCL_MEMORY_TAG mtag_a[batch_size];
|
||||
AOCL_MEMORY_TAG mtag_b[batch_size];
|
||||
|
||||
float *a_local[batch_size], *b_local[batch_size];
|
||||
dim_t m_local[batch_size], n_local[batch_size];
|
||||
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
|
||||
|
||||
// Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it.
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
|
||||
{
|
||||
@@ -93,22 +72,25 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
// Set MC, NC, KC, NR, MR.
|
||||
aocl_lpgemm_init_global_cntx();
|
||||
|
||||
// offset to get subsequent matrix when group_count > 1
|
||||
dim_t mat_idx = 0;
|
||||
|
||||
// check for validity of params.
|
||||
int err_no = 0;
|
||||
|
||||
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
|
||||
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
|
||||
{
|
||||
// check for validity of params.
|
||||
AOCL_BATCH_GEMM_CHECK
|
||||
(
|
||||
"batch_f32f32f32of32",
|
||||
order[bs_i], transa[bs_i], transb[bs_i],
|
||||
bs_i,
|
||||
m[bs_i], n[bs_i], k[bs_i],
|
||||
a[bs_i], lda[bs_i], mem_format_a[bs_i],
|
||||
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
|
||||
c[bs_i], ldc[bs_i],
|
||||
err_no
|
||||
"batch_f32f32f32of32",
|
||||
order[gc_i], transa[gc_i], transb[gc_i],
|
||||
group_count, group_size[gc_i],
|
||||
m[gc_i], n[gc_i], k[gc_i],
|
||||
lda[gc_i], mem_format_a[gc_i],
|
||||
ldb[gc_i], mem_format_b[gc_i],
|
||||
ldc[gc_i],
|
||||
err_no
|
||||
);
|
||||
|
||||
if ( err_no != 0 )
|
||||
@@ -116,163 +98,188 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
|
||||
goto err_hndl;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
|
||||
// Group_size is used across
|
||||
dim_t g_sz = group_size[gc_i];
|
||||
|
||||
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
|
||||
// Convert post op struct to post op linked list format.
|
||||
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
|
||||
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[bs_i] = ldb[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = ldb[bs_i];
|
||||
}
|
||||
inc_t rs_a[g_sz];
|
||||
inc_t cs_a[g_sz];
|
||||
|
||||
rs_b[bs_i] = lda[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
inc_t rs_b[g_sz];
|
||||
inc_t cs_b[g_sz];
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = lda[bs_i];
|
||||
}
|
||||
inc_t rs_c[g_sz];
|
||||
inc_t cs_c[g_sz];
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
|
||||
AOCL_MEMORY_TAG mtag_a[g_sz];
|
||||
AOCL_MEMORY_TAG mtag_b[g_sz];
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[bs_i] = n[bs_i];
|
||||
n_local[bs_i] = m[bs_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[bs_i] = (float*)(b[bs_i]);
|
||||
b_local[bs_i] = (float*)(a[bs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[bs_i] = lda[bs_i];
|
||||
cs_a[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[bs_i] = 1;
|
||||
cs_a[bs_i] = lda[bs_i];
|
||||
}
|
||||
|
||||
rs_b[bs_i] = ldb[bs_i];
|
||||
cs_b[bs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[bs_i] = 1;
|
||||
cs_b[bs_i] = ldb[bs_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[bs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[bs_i] = PACK;
|
||||
}
|
||||
|
||||
if( bli_is_trans(blis_transb ) && ( mtag_b[bs_i] == UNPACKED ) )
|
||||
{
|
||||
mtag_b[bs_i] = PACK;
|
||||
}
|
||||
|
||||
// copy the values of m & n
|
||||
m_local[bs_i] = m[bs_i];
|
||||
n_local[bs_i] = n[bs_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[bs_i] = (float*)(a[bs_i]);
|
||||
b_local[bs_i] = (float*)(b[bs_i]);
|
||||
}
|
||||
|
||||
rs_c[bs_i] = ldc[bs_i];
|
||||
cs_c[bs_i] = 1;
|
||||
float *a_local[g_sz], *b_local[g_sz];
|
||||
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
|
||||
|
||||
err_t err = lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
post_op_unparsed[bs_i], post_op_list[bs_i],
|
||||
( void* )c[bs_i], ( void* )( (order + bs_i) ),
|
||||
m[bs_i], n[bs_i]
|
||||
);
|
||||
(
|
||||
post_op_unparsed[gc_i], post_op_list,
|
||||
( void* )c[gc_i], ( void* )( order + gc_i ),
|
||||
m[gc_i], n[gc_i]
|
||||
);
|
||||
|
||||
if( err != BLIS_SUCCESS ) goto err_hndl;
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
|
||||
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
|
||||
|
||||
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
|
||||
|
||||
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
|
||||
{
|
||||
if( is_column_major == TRUE )
|
||||
{
|
||||
rs_a[gs_i] = ldb[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = lda[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_b[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_a[gs_i]) );
|
||||
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
// Reorder is not supported for column major matrices.
|
||||
if ( ( ( mtag_b[gs_i] == REORDERED ) || ( mtag_a[gs_i] == REORDERED ) ) )
|
||||
{
|
||||
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
// Inputs swapped in column major, A becomes B from kernel point of view.
|
||||
if ( bli_is_trans(blis_transb ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_b[gs_i] = PACK;
|
||||
}
|
||||
// swap m & n in case of col-major matrices
|
||||
m_local[gs_i] = n[gc_i];
|
||||
n_local[gs_i] = m[gc_i];
|
||||
|
||||
// swap a & b pointers in case of col-major matrices
|
||||
a_local[gs_i] = (float*)(b[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (float*)(a[mat_idx + gs_i]);
|
||||
}
|
||||
else // row-major
|
||||
{
|
||||
rs_a[gs_i] = lda[gc_i];
|
||||
cs_a[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transa ) )
|
||||
{
|
||||
rs_a[gs_i] = 1;
|
||||
cs_a[gs_i] = lda[gc_i];
|
||||
}
|
||||
|
||||
rs_b[gs_i] = ldb[gc_i];
|
||||
cs_b[gs_i] = 1;
|
||||
|
||||
if( bli_is_trans( blis_transb ) )
|
||||
{
|
||||
rs_b[gs_i] = 1;
|
||||
cs_b[gs_i] = ldb[gc_i];
|
||||
}
|
||||
|
||||
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
|
||||
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
|
||||
|
||||
// Reorder is not supported for A matrix
|
||||
if( mtag_a[gs_i] == REORDERED )
|
||||
{
|
||||
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
|
||||
goto err_hndl;
|
||||
}
|
||||
// From 5-loop function point of view,
|
||||
// A matrix when in column major storage needs to be packed to row-major
|
||||
// storage as kernel expects A matrix to be in row-major format.
|
||||
if( bli_is_trans(blis_transa ) )
|
||||
{
|
||||
mtag_a[gs_i] = PACK;
|
||||
}
|
||||
|
||||
if( bli_is_trans(blis_transb ) && ( mtag_b[gs_i] == UNPACKED ) )
|
||||
{
|
||||
mtag_b[gs_i] = PACK;
|
||||
}
|
||||
// copy the values of m & n
|
||||
m_local[gs_i] = m[gc_i];
|
||||
n_local[gs_i] = n[gc_i];
|
||||
|
||||
// copy the values of a & b pointers
|
||||
a_local[gs_i] = (float*)(a[mat_idx + gs_i]);
|
||||
b_local[gs_i] = (float*)(b[mat_idx + gs_i]);
|
||||
}
|
||||
|
||||
k_local[gs_i] = k[gc_i];
|
||||
|
||||
rs_c[gs_i] = ldc[gc_i];
|
||||
cs_c[gs_i] = 1;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_f32f32f32of32_openmp_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const float**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const float**)b_local, rs_b, cs_b, mtag_b,
|
||||
&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
#else
|
||||
batch_lpgemm_f32f32f32of32_thread_decorator
|
||||
(
|
||||
g_sz, m_local, n_local, k_local,
|
||||
(const float**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const float**)b_local, rs_b, cs_b, mtag_b,
|
||||
&c[mat_idx], rs_c, cs_c,
|
||||
alpha[gc_i], beta[gc_i],
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
mat_idx += g_sz;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
// that in the case that a runtime is passed in, we make a local copy.
|
||||
rntm_t rntm_g;
|
||||
bli_rntm_init_from_global( &rntm_g );
|
||||
bli_pba_rntm_set_pba( &rntm_g );
|
||||
|
||||
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
batch_lpgemm_f32f32f32of32_openmp_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const float**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const float**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
|
||||
|
||||
#else
|
||||
batch_lpgemm_f32f32f32of32_thread_decorator
|
||||
(
|
||||
batch_size, m_local, n_local, k,
|
||||
(const float**)a_local, rs_a, cs_a, mtag_a,
|
||||
(const float**)b_local, rs_b, cs_b, mtag_b,
|
||||
c, rs_c, cs_c,
|
||||
alpha, beta,
|
||||
&rntm_g, lcntx_g,
|
||||
post_op_list, F32
|
||||
);
|
||||
#endif
|
||||
|
||||
err_hndl:;
|
||||
LPGEMM_STOP_LOGGER();
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -105,11 +105,11 @@
|
||||
|
||||
#define AOCL_BATCH_GEMM_CHECK( op_str, \
|
||||
order, transa, transb, \
|
||||
gemm_no, \
|
||||
group_count, group_size, \
|
||||
m, n, k, \
|
||||
a, lda, mtag_a, \
|
||||
b, ldb, mtag_b, \
|
||||
c, ldc, \
|
||||
lda, mtag_a, \
|
||||
ldb, mtag_b, \
|
||||
ldc, \
|
||||
err_no \
|
||||
) \
|
||||
{ \
|
||||
@@ -164,12 +164,14 @@
|
||||
info = 17; \
|
||||
else if ( col_stored && ( ldc < m ) ) \
|
||||
info = 17; \
|
||||
else if ( group_count < 0 || group_size < 0 ) \
|
||||
info = 18; \
|
||||
\
|
||||
if( info != 0 ) \
|
||||
{ \
|
||||
char print_msg[ 150 ]; \
|
||||
\
|
||||
sprintf( print_msg, "** On entry to %6s, parameter number %2i of problem %ld had an illegal value", op_str, info, (long int) gemm_no); \
|
||||
sprintf( print_msg, "** On entry to %6s, parameter number %2i of problem %ld had an illegal value", op_str, info, (long int) group_count); \
|
||||
bli_print_msg(print_msg, __FILE__, __LINE__); \
|
||||
err_no = info; \
|
||||
} \
|
||||
|
||||
@@ -192,20 +192,21 @@ BLIS_EXPORT_ADDON void aocl_batch_gemm_ ## LP_SFX \
|
||||
const char* order, \
|
||||
const char* transa, \
|
||||
const char* transb, \
|
||||
const dim_t batch_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
const Sum_type* alpha, \
|
||||
const A_type** a, \
|
||||
const dim_t* lda, \
|
||||
const char* mem_format_a, \
|
||||
const B_type** b, \
|
||||
const dim_t* ldb, \
|
||||
const char* mem_format_b, \
|
||||
const Sum_type* beta, \
|
||||
C_type** c, \
|
||||
const dim_t* ldc, \
|
||||
const dim_t group_count, \
|
||||
const dim_t* group_size, \
|
||||
const char* mem_format_a, \
|
||||
const char* mem_format_b, \
|
||||
aocl_post_op** post_op_unparsed \
|
||||
) \
|
||||
|
||||
|
||||
@@ -320,7 +320,7 @@ void batch_lpgemm_write_logger_gemm_fn
|
||||
const char* order,
|
||||
const char* transa,
|
||||
const char* transb,
|
||||
const dim_t batch_size,
|
||||
const dim_t group_count,
|
||||
const dim_t* m,
|
||||
const dim_t* n,
|
||||
const dim_t* k,
|
||||
@@ -340,8 +340,8 @@ void batch_lpgemm_write_logger_gemm_fn
|
||||
|
||||
char post_ops_str[2048] = {0};
|
||||
|
||||
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size);
|
||||
for( dim_t i = 0; i < batch_size; i++ )
|
||||
fprintf(fd, "%s:group_count=%ld\n", op_type, group_count);
|
||||
for( dim_t i = 0; i < group_count; i++ )
|
||||
{
|
||||
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str );
|
||||
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str );
|
||||
|
||||
@@ -71,7 +71,7 @@ void batch_lpgemm_write_logger_gemm_fn
|
||||
const char* order,
|
||||
const char* transa,
|
||||
const char* transb,
|
||||
const dim_t batch_size,
|
||||
const dim_t group_count,
|
||||
const dim_t* m,
|
||||
const dim_t* n,
|
||||
const dim_t* k,
|
||||
@@ -105,7 +105,7 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
|
||||
lpgemm_write_logger_gemm_fn( fd, __VA_ARGS__ ); \
|
||||
|
||||
#define BATCH_LPGEMM_WRITE_LOGGER( op_type, order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
alpha, lda, mem_format_a, \
|
||||
ldb, mem_format_b, beta, \
|
||||
ldc, post_op_unparsed ) \
|
||||
@@ -116,14 +116,14 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
|
||||
\
|
||||
char post_ops_str[2048] = {0}; \
|
||||
\
|
||||
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size); \
|
||||
for( dim_t i = 0; i < batch_size; i++ ) \
|
||||
fprintf(fd, "%s:group_count=%ld\n", op_type, group_count); \
|
||||
for( dim_t i = 0; i < group_count; i++ ) \
|
||||
{ \
|
||||
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str ); \
|
||||
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str ); \
|
||||
fprintf( fd, "%c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
|
||||
fprintf( fd, "%ld %c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
|
||||
":pre_ops=[%s]:post_ops=[%s] %f %f\n", \
|
||||
order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i], \
|
||||
group_size[i], order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i], \
|
||||
m[i], n[i], k[i], lda[i], ldb[i], ldc[i], \
|
||||
pre_ops_str, post_ops_str, \
|
||||
(float)(alpha[i]), (float)(beta[i]) ); \
|
||||
@@ -140,7 +140,7 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
|
||||
#define LPGEMM_WRITE_LOGGER(...)
|
||||
|
||||
#define BATCH_LPGEMM_WRITE_LOGGER(op_type, order, transa, transb, \
|
||||
batch_size, m, n, k, \
|
||||
group_count, group_size, m, n, k, \
|
||||
alpha, lda, mem_format_a, \
|
||||
ldb, mem_format_b, beta, \
|
||||
ldc, post_op_unparsed)
|
||||
|
||||
@@ -361,7 +361,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
NONE, NONE
|
||||
);
|
||||
|
||||
bli_print_msg(" Max supported post-ops is 5, supplied input post-ops" \
|
||||
bli_print_msg(" Max supported post-ops is 8, supplied input post-ops" \
|
||||
" are more. Exiting..", __FILE__, __LINE__ );
|
||||
return BLIS_UNEXPECTED_VECTOR_DIM; //Error, seq length exceeds max post ops permitted.
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@
|
||||
|
||||
BLIS_INLINE void calculate_n_threads_per_gemm
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t group_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
@@ -61,7 +61,7 @@ BLIS_INLINE void calculate_n_threads_per_gemm
|
||||
}
|
||||
else if( *n_gemms_in_parallel < 1 )
|
||||
{
|
||||
( *n_gemms_in_parallel ) = bli_min( ( *n_threads ), batch_size );
|
||||
( *n_gemms_in_parallel ) = bli_min( ( *n_threads ), group_size );
|
||||
}
|
||||
/* ToDo: All the leftover thrads might go under-utilized. Could be optimized further. */
|
||||
( *n_threads_per_gemm ) = ( *n_threads ) / ( *n_gemms_in_parallel );
|
||||
@@ -376,7 +376,7 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_s32o32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t group_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
@@ -390,7 +390,7 @@ BLIS_INLINE void batch_lpgemm_s32o32_get_threading
|
||||
)
|
||||
{
|
||||
|
||||
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
calculate_n_threads_per_gemm(group_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
|
||||
if ( ( *n_threads_per_gemm ) > 1 )
|
||||
{
|
||||
@@ -441,7 +441,7 @@ BLIS_INLINE void batch_lpgemm_s32o32_get_threading
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t group_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
@@ -455,7 +455,7 @@ BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
|
||||
{
|
||||
batch_lpgemm_s32o32_get_threading
|
||||
(
|
||||
batch_size,
|
||||
group_size,
|
||||
n_threads, n_gemms_in_parallel, n_threads_per_gemm,
|
||||
ic_ways, jc_ways,
|
||||
m, n, k, rntm_g,
|
||||
@@ -465,7 +465,7 @@ BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_s8s8s32o32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t group_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
@@ -479,7 +479,7 @@ BLIS_INLINE void batch_lpgemm_s8s8s32o32_get_threading
|
||||
{
|
||||
batch_lpgemm_s32o32_get_threading
|
||||
(
|
||||
batch_size,
|
||||
group_size,
|
||||
n_threads, n_gemms_in_parallel, n_threads_per_gemm,
|
||||
ic_ways, jc_ways,
|
||||
m, n, k, rntm_g,
|
||||
@@ -608,7 +608,7 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_bf16bf16f32of32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t group_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
@@ -621,7 +621,7 @@ BLIS_INLINE void batch_lpgemm_bf16bf16f32of32_get_threading
|
||||
)
|
||||
{
|
||||
|
||||
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
calculate_n_threads_per_gemm(group_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
|
||||
/* The user is not allowed to set ic_ways or jc_ways */
|
||||
if ( ( *n_threads_per_gemm ) > 1 )
|
||||
@@ -783,7 +783,7 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
|
||||
|
||||
BLIS_INLINE void batch_lpgemm_f32f32f32of32_get_threading
|
||||
(
|
||||
dim_t batch_size,
|
||||
dim_t group_size,
|
||||
dim_t* n_threads,
|
||||
dim_t* n_gemms_in_parallel,
|
||||
dim_t* n_threads_per_gemm,
|
||||
@@ -796,7 +796,7 @@ BLIS_INLINE void batch_lpgemm_f32f32f32of32_get_threading
|
||||
)
|
||||
{
|
||||
|
||||
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
calculate_n_threads_per_gemm(group_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
|
||||
|
||||
// Query the context for SUP limits.
|
||||
const dim_t MT = lpgemm_get_sup_thres_MT_global_cntx( F32F32F32OF32 );
|
||||
@@ -1073,7 +1073,7 @@ GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t group_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
@@ -1088,11 +1088,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
{ \
|
||||
@@ -1110,7 +1110,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
/* Assuming all the problems in GEMM are of same size */ \
|
||||
batch_lpgemm_ ## LPGEMM_SFX ## _get_threading \
|
||||
( \
|
||||
batch_size, \
|
||||
group_size, \
|
||||
&n_threads, \
|
||||
&n_gemms_in_parallel, \
|
||||
&n_threads_per_gemm, \
|
||||
@@ -1163,7 +1163,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
thrinfo_t thrinfo; \
|
||||
thrinfo.n_way = n_gemms_in_parallel; \
|
||||
thrinfo.work_id = omp_get_thread_num() / n_threads_per_gemm; \
|
||||
bli_thread_range_sub( &thrinfo, batch_size, 1, FALSE, &gemm_start, &gemm_end ); \
|
||||
bli_thread_range_sub( &thrinfo, group_size, 1, FALSE, &gemm_start, &gemm_end ); \
|
||||
\
|
||||
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
|
||||
{ \
|
||||
@@ -1173,12 +1173,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
a[i], rs_a[i], cs_a[i], mtag_a[i], \
|
||||
b[i], rs_b[i], cs_b[i], mtag_b[i], \
|
||||
c[i], rs_c[i], cs_c[i],\
|
||||
alpha[i], \
|
||||
beta[i], \
|
||||
alpha, \
|
||||
beta, \
|
||||
&rntm_l, \
|
||||
&thread, \
|
||||
lcntx, \
|
||||
post_op_list[i], c_downscale \
|
||||
post_op_list, c_downscale \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
@@ -1420,7 +1420,7 @@ GEN_LPGEMM_OPENMP_DECORATOR_GRP(int8_t, int8_t, int32_t, s8s8s32o32_sym_quant)
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX, LPGEMM_PARENT_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t group_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
@@ -1435,12 +1435,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_pre_op(*pre_op_list), \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
{ \
|
||||
@@ -1458,7 +1458,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
/* Assuming all the problems in GEMM are of same size */ \
|
||||
batch_lpgemm_ ## LPGEMM_PARENT_SFX ## _get_threading \
|
||||
( \
|
||||
batch_size, \
|
||||
group_size, \
|
||||
&n_threads, \
|
||||
&n_gemms_in_parallel, \
|
||||
&n_threads_per_gemm, \
|
||||
@@ -1513,7 +1513,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
thrinfo_t thrinfo; \
|
||||
thrinfo.n_way = n_gemms_in_parallel; \
|
||||
thrinfo.work_id = omp_get_thread_num() / n_threads_per_gemm; \
|
||||
bli_thread_range_sub( &thrinfo, batch_size, 1, FALSE, &gemm_start, &gemm_end ); \
|
||||
bli_thread_range_sub( &thrinfo, group_size, 1, FALSE, &gemm_start, &gemm_end ); \
|
||||
\
|
||||
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
|
||||
{ \
|
||||
@@ -1534,13 +1534,13 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
a[i], rs_a[i], cs_a[i], mtag_a[i], \
|
||||
b[i], rs_b[i], cs_b[i], mtag_b[i], \
|
||||
c[i], rs_c[i], cs_c[i],\
|
||||
alpha[i], \
|
||||
beta[i], \
|
||||
alpha, \
|
||||
beta, \
|
||||
&rntm_l, \
|
||||
&thread, \
|
||||
lcntx, \
|
||||
pre_op_list[i], \
|
||||
post_op_list[i], c_downscale \
|
||||
pre_op_list, \
|
||||
post_op_list, c_downscale \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
@@ -1972,7 +1972,7 @@ GEN_LPGEMM_DECORATOR2(int8_t, int8_t, int32_t, s8s8s32o32_sym_quant)
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t group_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
@@ -1987,11 +1987,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
{ \
|
||||
@@ -2020,7 +2020,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
thread.jc_ways = jc_ways; \
|
||||
thread.comm = cur_lpgemm_comm; \
|
||||
dim_t gemm_start = 0; \
|
||||
dim_t gemm_end = batch_size; \
|
||||
dim_t gemm_end = group_size; \
|
||||
\
|
||||
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
|
||||
{ \
|
||||
@@ -2030,12 +2030,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
a[i], rs_a[i], cs_a[i], mtag_a[i], \
|
||||
b[i], rs_b[i], cs_b[i], mtag_b[i], \
|
||||
c[i], rs_c[i], cs_c[i],\
|
||||
alpha[i], \
|
||||
beta[i], \
|
||||
alpha, \
|
||||
beta, \
|
||||
rntm_g, \
|
||||
&thread, \
|
||||
lcntx, \
|
||||
post_op_list[i], c_downscale \
|
||||
post_op_list, c_downscale \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
@@ -2048,7 +2048,7 @@ GEN_BATCH_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t group_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
@@ -2063,12 +2063,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_pre_op(*pre_op_list), \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
) \
|
||||
{ \
|
||||
@@ -2097,7 +2097,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
thread.jc_ways = jc_ways; \
|
||||
thread.comm = cur_lpgemm_comm; \
|
||||
dim_t gemm_start = 0; \
|
||||
dim_t gemm_end = batch_size; \
|
||||
dim_t gemm_end = group_size; \
|
||||
\
|
||||
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
|
||||
{ \
|
||||
@@ -2107,13 +2107,13 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
a[i], rs_a[i], cs_a[i], mtag_a[i], \
|
||||
b[i], rs_b[i], cs_b[i], mtag_b[i], \
|
||||
c[i], rs_c[i], cs_c[i],\
|
||||
alpha[i], \
|
||||
beta[i], \
|
||||
alpha, \
|
||||
beta, \
|
||||
rntm_g, \
|
||||
&thread, \
|
||||
lcntx, \
|
||||
pre_op_list[i], \
|
||||
post_op_list[i], c_downscale \
|
||||
pre_op_list, \
|
||||
post_op_list, c_downscale \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
|
||||
@@ -75,7 +75,7 @@ GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t group_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
@@ -90,11 +90,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
@@ -107,7 +107,7 @@ GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN_MXP(A_type,B_type,C_type,LPGEMM_SFX) \
|
||||
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
( \
|
||||
const dim_t batch_size, \
|
||||
const dim_t group_size, \
|
||||
const dim_t* m, \
|
||||
const dim_t* n, \
|
||||
const dim_t* k, \
|
||||
@@ -122,12 +122,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_pre_op(*pre_op_list), \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
@@ -261,11 +261,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
@@ -292,12 +292,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
|
||||
C_type** c, \
|
||||
const dim_t* rs_c, \
|
||||
const dim_t* cs_c, \
|
||||
const C_type* alpha, \
|
||||
const C_type* beta, \
|
||||
const C_type alpha, \
|
||||
const C_type beta, \
|
||||
rntm_t* rntm_g, \
|
||||
lpgemm_cntx_t* lcntx, \
|
||||
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
|
||||
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
|
||||
lpgemm_pre_op(*pre_op_list), \
|
||||
lpgemm_post_op(*post_op_list), \
|
||||
AOCL_STORAGE_TYPE c_downscale \
|
||||
); \
|
||||
|
||||
|
||||
Reference in New Issue
Block a user