Updated aocl_batch_gemm_ APIs aligning to CBLAS batch API. (#58)

* Updated aocl_batch_gemm_ APIs aligning to CBLAS batch API.

 - Modified Batch-Gemm API to align with cblas_?gemm_batch_ API,
 and added a parameter group_size to the existing APIs.
 - Updated bench batch_gemm code to align to the new API definition.
 - Modified the hardcoded number in lpgemm_postop file.
 - Added necessary early return condition to account for group_count/group_size < 0.

AMD-Internal: [ SWLCSG - 3592 ]
This commit is contained in:
V, Varsha
2025-06-30 11:16:04 +05:30
committed by GitHub
parent 5193433141
commit 1f9d1a85d3
14 changed files with 2980 additions and 2707 deletions

View File

@@ -50,7 +50,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
(
"bf16bf16f32of32", \
order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
@@ -58,29 +58,11 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
bfloat16 *a_local[batch_size], *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
// Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
bli_print_msg(" AVX2 ISA not supported by processor, "
"cannot perform f32f32f32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
@@ -90,194 +72,225 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32)
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
if( get_jit_kernels_generated() == FALSE )
{
bli_print_msg(" Could not generate bf16bf16f32of32 "
" kernels using JIT.", __FILE__, __LINE__ );
goto err_hndl;
}
#endif
#ifdef LPGEMM_BF16_JIT
if( get_jit_kernels_generated() == FALSE )
{
bli_print_msg(" Could not generate bf16bf16f32of32 "
" kernels using JIT.", __FILE__, __LINE__ );
goto err_hndl;
}
#endif
trans_t blis_transa;
trans_t blis_transb;
// offset to get subsequent matrix when group_count > 1
dim_t mat_idx = 0;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_bf16bf16f32of32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
"batch_bf16bf16f32of32",
order[gc_i], transa[gc_i], transb[gc_i],
group_count, group_size[gc_i],
m[gc_i], n[gc_i], k[gc_i],
lda[gc_i], mem_format_a[gc_i],
ldb[gc_i], mem_format_b[gc_i],
ldc[gc_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
// Group_size is used across all matrices within this group.
dim_t g_sz = group_size[gc_i];
if( is_column_major == TRUE )
{
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
trans_t blis_transa;
trans_t blis_transb;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
inc_t rs_a[g_sz];
inc_t cs_a[g_sz];
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
inc_t rs_b[g_sz];
inc_t cs_b[g_sz];
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
inc_t rs_c[g_sz];
inc_t cs_c[g_sz];
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
AOCL_MEMORY_TAG mtag_a[g_sz];
AOCL_MEMORY_TAG mtag_b[g_sz];
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
bfloat16 *a_local[g_sz], *b_local[g_sz];
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (bfloat16*)(b[bs_i]);
b_local[bs_i] = (bfloat16*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (bfloat16*)(a[bs_i]);
b_local[bs_i] = (bfloat16*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
(
post_op_unparsed[gc_i], post_op_list,
( void* )c[gc_i], ( void* )( (order + gc_i) ),
m[gc_i], n[gc_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
{
if( is_column_major == TRUE )
{
rs_a[gs_i] = ldb[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = ldb[gc_i];
}
rs_b[gs_i] = lda[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = lda[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_b[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_a[gs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[gs_i] == REORDERED ) || ( mtag_a[gs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[gs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[gs_i] = n[gc_i];
n_local[gs_i] = m[gc_i];
// swap a & b pointers in case of col-major matrices
a_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
b_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
}
else // row-major
{
rs_a[gs_i] = lda[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = lda[gc_i];
}
rs_b[gs_i] = ldb[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = ldb[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[gs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[gs_i] = PACK;
}
// copy the values of m & n
m_local[gs_i] = m[gc_i];
n_local[gs_i] = n[gc_i];
// copy the values of a & b pointers
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
b_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
}
k_local[gs_i] = k[gc_i];
rs_c[gs_i] = ldc[gc_i];
cs_c[gs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[gs_i] == UNPACKED )
{
mtag_b[gs_i] = PACK;
}
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
post_op_list, F32
);
#else
batch_lpgemm_bf16bf16f32of32_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
post_op_list, F32
);
#endif
mat_idx+=g_sz;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#else
batch_lpgemm_bf16bf16f32of32_thread_decorator
(
batch_size, m_local, n_local, k,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}
@@ -289,7 +302,7 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
(
"bf16bf16f32obf16", \
order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
@@ -297,24 +310,6 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
bfloat16 *a_local[batch_size], *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
@@ -338,25 +333,25 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
}
#endif
trans_t blis_transa;
trans_t blis_transb;
// offset to get subsequent matrix when group_count > 1
dim_t mat_idx = 0;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_bf16bf16f32obf16",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
"batch_bf16bf16f32obf16",
order[gc_i], transa[gc_i], transb[gc_i],
group_count, group_size[gc_i],
m[gc_i], n[gc_i], k[gc_i],
lda[gc_i], mem_format_a[gc_i],
ldb[gc_i], mem_format_b[gc_i],
ldc[gc_i],
err_no
);
if ( err_no != 0 )
@@ -364,163 +359,190 @@ AOCL_BGEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
// Group_size is used across all matrices within this group.
dim_t g_sz = group_size[gc_i];
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
trans_t blis_transa;
trans_t blis_transb;
if( is_column_major == TRUE )
{
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
inc_t rs_a[g_sz];
inc_t cs_a[g_sz];
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
inc_t rs_b[g_sz];
inc_t cs_b[g_sz];
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
inc_t rs_c[g_sz];
inc_t cs_c[g_sz];
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
AOCL_MEMORY_TAG mtag_a[g_sz];
AOCL_MEMORY_TAG mtag_b[g_sz];
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
bfloat16 *a_local[g_sz], *b_local[g_sz];
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (bfloat16*)(b[bs_i]);
b_local[bs_i] = (bfloat16*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (bfloat16*)(a[bs_i]);
b_local[bs_i] = (bfloat16*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
post_op_unparsed[gc_i], post_op_list,
( void* )c[gc_i], ( void* )( order + gc_i ),
m[gc_i], n[gc_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
{
if( is_column_major == TRUE )
{
rs_a[gs_i] = ldb[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = ldb[gc_i];
}
rs_b[gs_i] = lda[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = lda[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_b[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_a[gs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[gs_i] == REORDERED ) || ( mtag_a[gs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[gs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[gs_i] = n[gc_i];
n_local[gs_i] = m[gc_i];
// swap a & b pointers in case of col-major matrices
a_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
b_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
}
else // row-major
{
rs_a[gs_i] = lda[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = lda[gc_i];
}
rs_b[gs_i] = ldb[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = ldb[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[gs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[gs_i] = PACK;
}
// copy the values of m & n
m_local[gs_i] = m[gc_i];
n_local[gs_i] = n[gc_i];
// copy the values of a & b pointers
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
b_local[gs_i] = (bfloat16*)(b[mat_idx + gs_i]);
}
k_local[gs_i] = k[gc_i];
rs_c[gs_i] = ldc[gc_i];
cs_c[gs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[gs_i] == UNPACKED )
{
mtag_b[gs_i] = PACK;
}
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
(float**)&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
post_op_list, BF16
);
#else
batch_lpgemm_bf16bf16f32of32_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
(float**)&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
post_op_list, BF16
);
#endif
mat_idx += g_sz;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16bf16f32of32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
(float**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, BF16
);
#else
batch_lpgemm_bf16bf16f32of32_thread_decorator
(
batch_size, m_local, n_local, k,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const bfloat16**)b_local, rs_b, cs_b, mtag_b,
(float**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, BF16
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -50,7 +50,7 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
(
"bf16s4f32of32", \
order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
@@ -58,21 +58,6 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
@@ -87,154 +72,191 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,float,float,bf16s4f32of32)
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
return;
#endif
#ifdef LPGEMM_BF16_JIT
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
return;
#endif
trans_t blis_transa;
trans_t blis_transb;
// offset to get subsequent matrix when group_count > 1
dim_t mat_idx = 0;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_bf16s4f32of32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
"batch_bf16s4f32of32",
order[gc_i], transa[gc_i], transb[gc_i],
group_count, group_size[gc_i],
m[gc_i], n[gc_i], k[gc_i],
lda[gc_i], mem_format_a[gc_i],
ldb[gc_i], mem_format_b[gc_i],
ldc[gc_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
dim_t g_sz = group_size[gc_i];
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
trans_t blis_transa;
trans_t blis_transb;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
inc_t rs_a[g_sz];
inc_t cs_a[g_sz];
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
inc_t rs_b[g_sz];
inc_t cs_b[g_sz];
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
inc_t rs_c[g_sz];
inc_t cs_c[g_sz];
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
AOCL_MEMORY_TAG mtag_a[g_sz];
AOCL_MEMORY_TAG mtag_b[g_sz];
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
// Convert pre op struct to pre op linked list format.
err_t err = lpgemm_translate_to_pre_ops_list
(
post_op_unparsed[bs_i]->pre_ops,
pre_op_list[bs_i],
m[bs_i], n[bs_i], k[bs_i]
);
(
post_op_unparsed[gc_i]->pre_ops,
pre_op_list,
m[gc_i], n[gc_i], k[gc_i]
);
if (err != BLIS_SUCCESS) goto err_hndl;
// Convert post op struct to post op linked list format.
err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
post_op_unparsed[gc_i], post_op_list,
( void* )c[gc_i], ( void* )( order + gc_i ),
m[gc_i], n[gc_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
bfloat16 *a_local[g_sz];
int8_t* b_local[g_sz];
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
{
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[gs_i] = lda[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = lda[gc_i];
}
rs_b[gs_i] = ldb[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = ldb[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[gs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[gs_i] = PACK;
}
// copy the values of m & n
m_local[gs_i] = m[gc_i];
n_local[gs_i] = n[gc_i];
// copy the values of a & b pointers
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
b_local[gs_i] = (int8_t*)(b[mat_idx + gs_i]);
}
k_local[gs_i] = k[gc_i];
rs_c[gs_i] = ldc[gc_i];
cs_c[gs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[gs_i] == UNPACKED )
{
mtag_b[gs_i] = PACK;
}
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
pre_op_list, post_op_list, F32
);
#else
batch_lpgemm_bf16s4f32of32_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
pre_op_list, post_op_list, F32
);
#endif
mat_idx += g_sz;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
pre_op_list, post_op_list, F32
);
#else
batch_lpgemm_bf16s4f32of32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
pre_op_list, post_op_list, F32
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}
@@ -246,7 +268,7 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
(
"bf16s4f32obf16", \
order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
@@ -254,26 +276,11 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
ldc, post_op_unparsed \
);
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
lpgemm_pre_op pre_op_list[batch_size][AOCL_MAX_PRE_OPS];
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
if ( bli_cpuid_is_avx512bf16_supported() == FALSE )
{
bli_print_msg(" AVX512_BF16 ISA not supported by processor, "
"cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ );
"cannot perform bf16s4f32 gemm.", __FILE__, __LINE__ );
goto err_hndl;
}
@@ -283,30 +290,30 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
#ifdef LPGEMM_BF16_JIT
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
return;
#endif
#ifdef LPGEMM_BF16_JIT
bli_print_msg(" WOQ is not supported by JIT kernels.", __FILE__, __LINE__ );
return;
#endif
trans_t blis_transa;
trans_t blis_transb;
// offset to get subsequent matrix when group_count > 1
dim_t mat_idx = 0;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_bf16s4f32obf16",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
"batch_bf16s4f32obf16",
order[gc_i], transa[gc_i], transb[gc_i],
group_count, group_size[gc_i],
m[gc_i], n[gc_i], k[gc_i],
lda[gc_i], mem_format_a[gc_i],
ldb[gc_i], mem_format_b[gc_i],
ldc[gc_i],
err_no
);
if ( err_no != 0 )
@@ -314,125 +321,177 @@ AOCL_BGEMM_MATMUL(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
dim_t g_sz = group_size[gc_i];
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
inc_t rs_a[g_sz];
inc_t cs_a[g_sz];
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
inc_t rs_b[g_sz];
inc_t cs_b[g_sz];
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
inc_t rs_c[g_sz];
inc_t cs_c[g_sz];
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
AOCL_MEMORY_TAG mtag_a[g_sz];
AOCL_MEMORY_TAG mtag_b[g_sz];
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bfloat16 *a_local[g_sz];
int8_t *b_local[g_sz];
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
trans_t blis_transa;
trans_t blis_transb;
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[bs_i] == UNPACKED )
{
mtag_b[bs_i] = PACK;
}
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
lpgemm_pre_op pre_op_list[AOCL_MAX_PRE_OPS];
// Convert pre op struct to pre op linked list format.
err_t err = lpgemm_translate_to_pre_ops_list
(
post_op_unparsed[bs_i]->pre_ops,
pre_op_list[bs_i],
m[bs_i], n[bs_i], k[bs_i]
);
(
post_op_unparsed[gc_i]->pre_ops,
pre_op_list,
m[gc_i], n[gc_i], k[gc_i]
);
if (err != BLIS_SUCCESS) goto err_hndl;
// Convert post op struct to post op linked list format.
err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
(
post_op_unparsed[gc_i], post_op_list,
( void* )c[gc_i], ( void* )( order+gc_i ),
m[gc_i], n[gc_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_bf16s4f32obf16",
order[gc_i], transa[gc_i], transb[gc_i],
group_count, group_size[gc_i],
m[gc_i], n[gc_i], k[gc_i],
lda[gc_i], mem_format_a[gc_i],
ldb[gc_i], mem_format_b[gc_i],
ldc[gc_i],
err_no
);
if ( err_no != 0 )
{
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
if( is_column_major == TRUE )
{
bli_print_msg("Column major inputs not supported.",
__FILE__, __LINE__);
goto err_hndl;
}
else // row-major
{
rs_a[gs_i] = lda[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = lda[gc_i];
}
rs_b[gs_i] = ldb[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = ldb[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gs_i], &(mtag_a[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gs_i], &(mtag_b[gs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[gs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[gs_i] = PACK;
}
// copy the values of m & n
m_local[gs_i] = m[gc_i];
n_local[gs_i] = n[gc_i];
// copy the values of a & b pointers
a_local[gs_i] = (bfloat16*)(a[mat_idx + gs_i]);
b_local[gs_i] = (int8_t*)(b[mat_idx + gs_i]);
}
k_local[gs_i] = k[gc_i];
rs_c[gs_i] = ldc[gc_i];
cs_c[gs_i] = 1;
// From 5-loop function point of view
// B matrix needs to be packed in a certain format in order to be loaded
// and used in bf16 instruction. As such the mtag_b always needs to be either
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
// the mtag_b is set to packed to enable runtime packing.
if ( mtag_b[gs_i] == UNPACKED )
{
mtag_b[gs_i] = PACK;
}
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(float**)&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
pre_op_list, post_op_list, BF16
);
#else
batch_lpgemm_bf16s4f32of32_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const bfloat16**)a_local, rs_a, cs_a, mtag_a,
(const int8_t**)b_local, rs_b, cs_b, mtag_b,
(float**)&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
pre_op_list, post_op_list, BF16
);
#endif
mat_idx += g_sz;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16S4F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_bf16s4f32of32_openmp_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(float**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
pre_op_list, post_op_list, BF16
);
#else
batch_lpgemm_bf16s4f32of32_thread_decorator
(
batch_size, m, n, k,
a, rs_a, cs_a, mtag_a,
b, rs_b, cs_b, mtag_b,
(float**)c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
pre_op_list, post_op_list, BF16
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

View File

@@ -50,7 +50,7 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
(
"f32f32f32of32", \
order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
( ( float* ) alpha ), \
lda, mem_format_a, \
ldb, mem_format_b, \
@@ -58,27 +58,6 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
ldc, post_op_unparsed \
);
trans_t blis_transa;
trans_t blis_transb;
inc_t rs_a[batch_size];
inc_t cs_a[batch_size];
inc_t rs_b[batch_size];
inc_t cs_b[batch_size];
inc_t rs_c[batch_size];
inc_t cs_c[batch_size];
AOCL_MEMORY_TAG mtag_a[batch_size];
AOCL_MEMORY_TAG mtag_b[batch_size];
float *a_local[batch_size], *b_local[batch_size];
dim_t m_local[batch_size], n_local[batch_size];
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[batch_size][AOCL_MAX_POST_OPS];
// Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it.
if ( bli_cpuid_is_avx2fma3_supported() == FALSE )
{
@@ -93,22 +72,25 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
// Set MC, NC, KC, NR, MR.
aocl_lpgemm_init_global_cntx();
// offset to get subsequent matrix when group_count > 1
dim_t mat_idx = 0;
// check for validity of params.
int err_no = 0;
for( dim_t bs_i = 0; bs_i < batch_size; bs_i++ )
for( dim_t gc_i = 0; gc_i < group_count; gc_i++ )
{
// check for validity of params.
AOCL_BATCH_GEMM_CHECK
(
"batch_f32f32f32of32",
order[bs_i], transa[bs_i], transb[bs_i],
bs_i,
m[bs_i], n[bs_i], k[bs_i],
a[bs_i], lda[bs_i], mem_format_a[bs_i],
b[bs_i], ldb[bs_i], mem_format_b[bs_i],
c[bs_i], ldc[bs_i],
err_no
"batch_f32f32f32of32",
order[gc_i], transa[gc_i], transb[gc_i],
group_count, group_size[gc_i],
m[gc_i], n[gc_i], k[gc_i],
lda[gc_i], mem_format_a[gc_i],
ldb[gc_i], mem_format_b[gc_i],
ldc[gc_i],
err_no
);
if ( err_no != 0 )
@@ -116,163 +98,188 @@ AOCL_BGEMM_MATMUL(float,float,float,float,f32f32f32of32)
goto err_hndl;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[bs_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[bs_i], &blis_transb );
// Group_size is used across
dim_t g_sz = group_size[gc_i];
bool is_column_major = ( ( order[bs_i] == 'c' ) || ( order[bs_i] == 'C' ) );
// Convert post op struct to post op linked list format.
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
if( is_column_major == TRUE )
{
rs_a[bs_i] = ldb[bs_i];
cs_a[bs_i] = 1;
trans_t blis_transa;
trans_t blis_transb;
if( bli_is_trans( blis_transb ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = ldb[bs_i];
}
inc_t rs_a[g_sz];
inc_t cs_a[g_sz];
rs_b[bs_i] = lda[bs_i];
cs_b[bs_i] = 1;
inc_t rs_b[g_sz];
inc_t cs_b[g_sz];
if( bli_is_trans( blis_transa ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = lda[bs_i];
}
inc_t rs_c[g_sz];
inc_t cs_c[g_sz];
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_b[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_a[bs_i]) );
AOCL_MEMORY_TAG mtag_a[g_sz];
AOCL_MEMORY_TAG mtag_b[g_sz];
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[bs_i] == REORDERED ) || ( mtag_a[bs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[bs_i] = PACK;
}
if( bli_is_trans(blis_transa ) )
{
mtag_b[bs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[bs_i] = n[bs_i];
n_local[bs_i] = m[bs_i];
// swap a & b pointers in case of col-major matrices
a_local[bs_i] = (float*)(b[bs_i]);
b_local[bs_i] = (float*)(a[bs_i]);
}
else // row-major
{
rs_a[bs_i] = lda[bs_i];
cs_a[bs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[bs_i] = 1;
cs_a[bs_i] = lda[bs_i];
}
rs_b[bs_i] = ldb[bs_i];
cs_b[bs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[bs_i] = 1;
cs_b[bs_i] = ldb[bs_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[bs_i], &(mtag_a[bs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[bs_i], &(mtag_b[bs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[bs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[bs_i] = PACK;
}
if( bli_is_trans(blis_transb ) && ( mtag_b[bs_i] == UNPACKED ) )
{
mtag_b[bs_i] = PACK;
}
// copy the values of m & n
m_local[bs_i] = m[bs_i];
n_local[bs_i] = n[bs_i];
// copy the values of a & b pointers
a_local[bs_i] = (float*)(a[bs_i]);
b_local[bs_i] = (float*)(b[bs_i]);
}
rs_c[bs_i] = ldc[bs_i];
cs_c[bs_i] = 1;
float *a_local[g_sz], *b_local[g_sz];
dim_t m_local[g_sz], n_local[g_sz], k_local[g_sz];
err_t err = lpgemm_translate_to_post_ops_list
(
post_op_unparsed[bs_i], post_op_list[bs_i],
( void* )c[bs_i], ( void* )( (order + bs_i) ),
m[bs_i], n[bs_i]
);
(
post_op_unparsed[gc_i], post_op_list,
( void* )c[gc_i], ( void* )( order + gc_i ),
m[gc_i], n[gc_i]
);
if( err != BLIS_SUCCESS ) goto err_hndl;
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( transa[gc_i], &blis_transa );
bli_param_map_netlib_to_blis_trans( transb[gc_i], &blis_transb );
bool is_column_major = ( ( order[gc_i] == 'c' ) || ( order[gc_i] == 'C' ) );
for( dim_t gs_i = 0; gs_i < g_sz; gs_i++ )
{
if( is_column_major == TRUE )
{
rs_a[gs_i] = ldb[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = ldb[gc_i];
}
rs_b[gs_i] = lda[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = lda[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_b[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_a[gs_i]) );
// Inputs swapped in column major, A becomes B from kernel point of view.
// Reorder is not supported for column major matrices.
if ( ( ( mtag_b[gs_i] == REORDERED ) || ( mtag_a[gs_i] == REORDERED ) ) )
{
bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
// Inputs swapped in column major, A becomes B from kernel point of view.
if ( bli_is_trans(blis_transb ) )
{
mtag_a[gs_i] = PACK;
}
if( bli_is_trans(blis_transa ) )
{
mtag_b[gs_i] = PACK;
}
// swap m & n in case of col-major matrices
m_local[gs_i] = n[gc_i];
n_local[gs_i] = m[gc_i];
// swap a & b pointers in case of col-major matrices
a_local[gs_i] = (float*)(b[mat_idx + gs_i]);
b_local[gs_i] = (float*)(a[mat_idx + gs_i]);
}
else // row-major
{
rs_a[gs_i] = lda[gc_i];
cs_a[gs_i] = 1;
if( bli_is_trans( blis_transa ) )
{
rs_a[gs_i] = 1;
cs_a[gs_i] = lda[gc_i];
}
rs_b[gs_i] = ldb[gc_i];
cs_b[gs_i] = 1;
if( bli_is_trans( blis_transb ) )
{
rs_b[gs_i] = 1;
cs_b[gs_i] = ldb[gc_i];
}
bli_param_map_char_to_lpmtag( mem_format_a[gc_i], &(mtag_a[gs_i]) );
bli_param_map_char_to_lpmtag( mem_format_b[gc_i], &(mtag_b[gs_i]) );
// Reorder is not supported for A matrix
if( mtag_a[gs_i] == REORDERED )
{
bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ );
goto err_hndl;
}
// From 5-loop function point of view,
// A matrix when in column major storage needs to be packed to row-major
// storage as kernel expects A matrix to be in row-major format.
if( bli_is_trans(blis_transa ) )
{
mtag_a[gs_i] = PACK;
}
if( bli_is_trans(blis_transb ) && ( mtag_b[gs_i] == UNPACKED ) )
{
mtag_b[gs_i] = PACK;
}
// copy the values of m & n
m_local[gs_i] = m[gc_i];
n_local[gs_i] = n[gc_i];
// copy the values of a & b pointers
a_local[gs_i] = (float*)(a[mat_idx + gs_i]);
b_local[gs_i] = (float*)(b[mat_idx + gs_i]);
}
k_local[gs_i] = k[gc_i];
rs_c[gs_i] = ldc[gc_i];
cs_c[gs_i] = 1;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_f32f32f32of32_openmp_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const float**)a_local, rs_a, cs_a, mtag_a,
(const float**)b_local, rs_b, cs_b, mtag_b,
&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
post_op_list, F32
);
#else
batch_lpgemm_f32f32f32of32_thread_decorator
(
g_sz, m_local, n_local, k_local,
(const float**)a_local, rs_a, cs_a, mtag_a,
(const float**)b_local, rs_b, cs_b, mtag_b,
&c[mat_idx], rs_c, cs_c,
alpha[gc_i], beta[gc_i],
&rntm_g, lcntx_g,
post_op_list, F32
);
#endif
mat_idx += g_sz;
}
// Initialize a local runtime with global settings if necessary. Note
// that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g;
bli_rntm_init_from_global( &rntm_g );
bli_pba_rntm_set_pba( &rntm_g );
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 );
#ifdef BLIS_ENABLE_OPENMP
batch_lpgemm_f32f32f32of32_openmp_thread_decorator
(
batch_size, m_local, n_local, k,
(const float**)a_local, rs_a, cs_a, mtag_a,
(const float**)b_local, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#else
batch_lpgemm_f32f32f32of32_thread_decorator
(
batch_size, m_local, n_local, k,
(const float**)a_local, rs_a, cs_a, mtag_a,
(const float**)b_local, rs_b, cs_b, mtag_b,
c, rs_c, cs_c,
alpha, beta,
&rntm_g, lcntx_g,
post_op_list, F32
);
#endif
err_hndl:;
LPGEMM_STOP_LOGGER();
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -105,11 +105,11 @@
#define AOCL_BATCH_GEMM_CHECK( op_str, \
order, transa, transb, \
gemm_no, \
group_count, group_size, \
m, n, k, \
a, lda, mtag_a, \
b, ldb, mtag_b, \
c, ldc, \
lda, mtag_a, \
ldb, mtag_b, \
ldc, \
err_no \
) \
{ \
@@ -164,12 +164,14 @@
info = 17; \
else if ( col_stored && ( ldc < m ) ) \
info = 17; \
else if ( group_count < 0 || group_size < 0 ) \
info = 18; \
\
if( info != 0 ) \
{ \
char print_msg[ 150 ]; \
\
sprintf( print_msg, "** On entry to %6s, parameter number %2i of problem %ld had an illegal value", op_str, info, (long int) gemm_no); \
sprintf( print_msg, "** On entry to %6s, parameter number %2i of problem %ld had an illegal value", op_str, info, (long int) group_count); \
bli_print_msg(print_msg, __FILE__, __LINE__); \
err_no = info; \
} \

View File

@@ -192,20 +192,21 @@ BLIS_EXPORT_ADDON void aocl_batch_gemm_ ## LP_SFX \
const char* order, \
const char* transa, \
const char* transb, \
const dim_t batch_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
const Sum_type* alpha, \
const A_type** a, \
const dim_t* lda, \
const char* mem_format_a, \
const B_type** b, \
const dim_t* ldb, \
const char* mem_format_b, \
const Sum_type* beta, \
C_type** c, \
const dim_t* ldc, \
const dim_t group_count, \
const dim_t* group_size, \
const char* mem_format_a, \
const char* mem_format_b, \
aocl_post_op** post_op_unparsed \
) \

View File

@@ -320,7 +320,7 @@ void batch_lpgemm_write_logger_gemm_fn
const char* order,
const char* transa,
const char* transb,
const dim_t batch_size,
const dim_t group_count,
const dim_t* m,
const dim_t* n,
const dim_t* k,
@@ -340,8 +340,8 @@ void batch_lpgemm_write_logger_gemm_fn
char post_ops_str[2048] = {0};
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size);
for( dim_t i = 0; i < batch_size; i++ )
fprintf(fd, "%s:group_count=%ld\n", op_type, group_count);
for( dim_t i = 0; i < group_count; i++ )
{
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str );
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str );

View File

@@ -71,7 +71,7 @@ void batch_lpgemm_write_logger_gemm_fn
const char* order,
const char* transa,
const char* transb,
const dim_t batch_size,
const dim_t group_count,
const dim_t* m,
const dim_t* n,
const dim_t* k,
@@ -105,7 +105,7 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
lpgemm_write_logger_gemm_fn( fd, __VA_ARGS__ ); \
#define BATCH_LPGEMM_WRITE_LOGGER( op_type, order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
alpha, lda, mem_format_a, \
ldb, mem_format_b, beta, \
ldc, post_op_unparsed ) \
@@ -116,14 +116,14 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
\
char post_ops_str[2048] = {0}; \
\
fprintf(fd, "%s:bs=%ld\n", op_type, batch_size); \
for( dim_t i = 0; i < batch_size; i++ ) \
fprintf(fd, "%s:group_count=%ld\n", op_type, group_count); \
for( dim_t i = 0; i < group_count; i++ ) \
{ \
lpgemm_get_pre_ops_str( post_op_unparsed[i], pre_ops_str ); \
lpgemm_get_post_ops_str( post_op_unparsed[i], post_ops_str ); \
fprintf( fd, "%c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
fprintf( fd, "%ld %c %c %c %c %c %ld %ld %ld %ld %ld %ld "\
":pre_ops=[%s]:post_ops=[%s] %f %f\n", \
order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i], \
group_size[i], order[i], transa[i], transb[i], mem_format_a[i], mem_format_b[i], \
m[i], n[i], k[i], lda[i], ldb[i], ldc[i], \
pre_ops_str, post_ops_str, \
(float)(alpha[i]), (float)(beta[i]) ); \
@@ -140,7 +140,7 @@ void lpgemm_write_logger_time_break_fn( FILE* fd, double stime );
#define LPGEMM_WRITE_LOGGER(...)
#define BATCH_LPGEMM_WRITE_LOGGER(op_type, order, transa, transb, \
batch_size, m, n, k, \
group_count, group_size, m, n, k, \
alpha, lda, mem_format_a, \
ldb, mem_format_b, beta, \
ldc, post_op_unparsed)

View File

@@ -361,7 +361,7 @@ err_t lpgemm_translate_to_post_ops_list
NONE, NONE
);
bli_print_msg(" Max supported post-ops is 5, supplied input post-ops" \
bli_print_msg(" Max supported post-ops is 8, supplied input post-ops" \
" are more. Exiting..", __FILE__, __LINE__ );
return BLIS_UNEXPECTED_VECTOR_DIM; //Error, seq length exceeds max post ops permitted.
}

View File

@@ -46,7 +46,7 @@
BLIS_INLINE void calculate_n_threads_per_gemm
(
dim_t batch_size,
dim_t group_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
@@ -61,7 +61,7 @@ BLIS_INLINE void calculate_n_threads_per_gemm
}
else if( *n_gemms_in_parallel < 1 )
{
( *n_gemms_in_parallel ) = bli_min( ( *n_threads ), batch_size );
( *n_gemms_in_parallel ) = bli_min( ( *n_threads ), group_size );
}
/* ToDo: All the leftover threads might go under-utilized. Could be optimized further. */
( *n_threads_per_gemm ) = ( *n_threads ) / ( *n_gemms_in_parallel );
@@ -376,7 +376,7 @@ BLIS_INLINE void lpgemm_s32o32_get_threading
BLIS_INLINE void batch_lpgemm_s32o32_get_threading
(
dim_t batch_size,
dim_t group_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
@@ -390,7 +390,7 @@ BLIS_INLINE void batch_lpgemm_s32o32_get_threading
)
{
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
calculate_n_threads_per_gemm(group_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
if ( ( *n_threads_per_gemm ) > 1 )
{
@@ -441,7 +441,7 @@ BLIS_INLINE void batch_lpgemm_s32o32_get_threading
BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
(
dim_t batch_size,
dim_t group_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
@@ -455,7 +455,7 @@ BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
{
batch_lpgemm_s32o32_get_threading
(
batch_size,
group_size,
n_threads, n_gemms_in_parallel, n_threads_per_gemm,
ic_ways, jc_ways,
m, n, k, rntm_g,
@@ -465,7 +465,7 @@ BLIS_INLINE void batch_lpgemm_u8s8s32o32_get_threading
BLIS_INLINE void batch_lpgemm_s8s8s32o32_get_threading
(
dim_t batch_size,
dim_t group_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
@@ -479,7 +479,7 @@ BLIS_INLINE void batch_lpgemm_s8s8s32o32_get_threading
{
batch_lpgemm_s32o32_get_threading
(
batch_size,
group_size,
n_threads, n_gemms_in_parallel, n_threads_per_gemm,
ic_ways, jc_ways,
m, n, k, rntm_g,
@@ -608,7 +608,7 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading
BLIS_INLINE void batch_lpgemm_bf16bf16f32of32_get_threading
(
dim_t batch_size,
dim_t group_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
@@ -621,7 +621,7 @@ BLIS_INLINE void batch_lpgemm_bf16bf16f32of32_get_threading
)
{
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
calculate_n_threads_per_gemm(group_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
/* The user is not allowed to set ic_ways or jc_ways */
if ( ( *n_threads_per_gemm ) > 1 )
@@ -783,7 +783,7 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading
BLIS_INLINE void batch_lpgemm_f32f32f32of32_get_threading
(
dim_t batch_size,
dim_t group_size,
dim_t* n_threads,
dim_t* n_gemms_in_parallel,
dim_t* n_threads_per_gemm,
@@ -796,7 +796,7 @@ BLIS_INLINE void batch_lpgemm_f32f32f32of32_get_threading
)
{
calculate_n_threads_per_gemm(batch_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
calculate_n_threads_per_gemm(group_size, n_threads, n_gemms_in_parallel, n_threads_per_gemm, rntm_g );
// Query the context for SUP limits.
const dim_t MT = lpgemm_get_sup_thres_MT_global_cntx( F32F32F32OF32 );
@@ -1073,7 +1073,7 @@ GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t batch_size, \
const dim_t group_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
@@ -1088,11 +1088,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
) \
{ \
@@ -1110,7 +1110,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
/* Assuming all the problems in GEMM are of the same size */ \
batch_lpgemm_ ## LPGEMM_SFX ## _get_threading \
( \
batch_size, \
group_size, \
&n_threads, \
&n_gemms_in_parallel, \
&n_threads_per_gemm, \
@@ -1163,7 +1163,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
thrinfo_t thrinfo; \
thrinfo.n_way = n_gemms_in_parallel; \
thrinfo.work_id = omp_get_thread_num() / n_threads_per_gemm; \
bli_thread_range_sub( &thrinfo, batch_size, 1, FALSE, &gemm_start, &gemm_end ); \
bli_thread_range_sub( &thrinfo, group_size, 1, FALSE, &gemm_start, &gemm_end ); \
\
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
{ \
@@ -1173,12 +1173,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
a[i], rs_a[i], cs_a[i], mtag_a[i], \
b[i], rs_b[i], cs_b[i], mtag_b[i], \
c[i], rs_c[i], cs_c[i],\
alpha[i], \
beta[i], \
alpha, \
beta, \
&rntm_l, \
&thread, \
lcntx, \
post_op_list[i], c_downscale \
post_op_list, c_downscale \
); \
} \
} \
@@ -1420,7 +1420,7 @@ GEN_LPGEMM_OPENMP_DECORATOR_GRP(int8_t, int8_t, int32_t, s8s8s32o32_sym_quant)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX, LPGEMM_PARENT_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t batch_size, \
const dim_t group_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
@@ -1435,12 +1435,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_pre_op(*pre_op_list), \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
) \
{ \
@@ -1458,7 +1458,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
/* Assuming all the problems in GEMM are of the same size */ \
batch_lpgemm_ ## LPGEMM_PARENT_SFX ## _get_threading \
( \
batch_size, \
group_size, \
&n_threads, \
&n_gemms_in_parallel, \
&n_threads_per_gemm, \
@@ -1513,7 +1513,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
thrinfo_t thrinfo; \
thrinfo.n_way = n_gemms_in_parallel; \
thrinfo.work_id = omp_get_thread_num() / n_threads_per_gemm; \
bli_thread_range_sub( &thrinfo, batch_size, 1, FALSE, &gemm_start, &gemm_end ); \
bli_thread_range_sub( &thrinfo, group_size, 1, FALSE, &gemm_start, &gemm_end ); \
\
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
{ \
@@ -1534,13 +1534,13 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
a[i], rs_a[i], cs_a[i], mtag_a[i], \
b[i], rs_b[i], cs_b[i], mtag_b[i], \
c[i], rs_c[i], cs_c[i],\
alpha[i], \
beta[i], \
alpha, \
beta, \
&rntm_l, \
&thread, \
lcntx, \
pre_op_list[i], \
post_op_list[i], c_downscale \
pre_op_list, \
post_op_list, c_downscale \
); \
} \
} \
@@ -1972,7 +1972,7 @@ GEN_LPGEMM_DECORATOR2(int8_t, int8_t, int32_t, s8s8s32o32_sym_quant)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
( \
const dim_t batch_size, \
const dim_t group_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
@@ -1987,11 +1987,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
) \
{ \
@@ -2020,7 +2020,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
thread.jc_ways = jc_ways; \
thread.comm = cur_lpgemm_comm; \
dim_t gemm_start = 0; \
dim_t gemm_end = batch_size; \
dim_t gemm_end = group_size; \
\
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
{ \
@@ -2030,12 +2030,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
a[i], rs_a[i], cs_a[i], mtag_a[i], \
b[i], rs_b[i], cs_b[i], mtag_b[i], \
c[i], rs_c[i], cs_c[i],\
alpha[i], \
beta[i], \
alpha, \
beta, \
rntm_g, \
&thread, \
lcntx, \
post_op_list[i], c_downscale \
post_op_list, c_downscale \
); \
} \
} \
@@ -2048,7 +2048,7 @@ GEN_BATCH_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_MP(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
( \
const dim_t batch_size, \
const dim_t group_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
@@ -2063,12 +2063,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_pre_op(*pre_op_list), \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
) \
{ \
@@ -2097,7 +2097,7 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
thread.jc_ways = jc_ways; \
thread.comm = cur_lpgemm_comm; \
dim_t gemm_start = 0; \
dim_t gemm_end = batch_size; \
dim_t gemm_end = group_size; \
\
for( dim_t i = gemm_start; i < gemm_end; i++ ) \
{ \
@@ -2107,13 +2107,13 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
a[i], rs_a[i], cs_a[i], mtag_a[i], \
b[i], rs_b[i], cs_b[i], mtag_b[i], \
c[i], rs_c[i], cs_c[i],\
alpha[i], \
beta[i], \
alpha, \
beta, \
rntm_g, \
&thread, \
lcntx, \
pre_op_list[i], \
post_op_list[i], c_downscale \
pre_op_list, \
post_op_list, c_downscale \
); \
} \
} \

View File

@@ -75,7 +75,7 @@ GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t batch_size, \
const dim_t group_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
@@ -90,11 +90,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
); \
@@ -107,7 +107,7 @@ GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32)
#define GEN_BATCH_LPGEMM_OPENMP_DECORATOR_FN_MXP(A_type,B_type,C_type,LPGEMM_SFX) \
void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
( \
const dim_t batch_size, \
const dim_t group_size, \
const dim_t* m, \
const dim_t* n, \
const dim_t* k, \
@@ -122,12 +122,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_pre_op(*pre_op_list), \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
); \
@@ -261,11 +261,11 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
); \
@@ -292,12 +292,12 @@ void batch_lpgemm_ ## LPGEMM_SFX ## _thread_decorator \
C_type** c, \
const dim_t* rs_c, \
const dim_t* cs_c, \
const C_type* alpha, \
const C_type* beta, \
const C_type alpha, \
const C_type beta, \
rntm_t* rntm_g, \
lpgemm_cntx_t* lcntx, \
lpgemm_pre_op(*pre_op_list)[AOCL_MAX_PRE_OPS], \
lpgemm_post_op(*post_op_list)[AOCL_MAX_POST_OPS], \
lpgemm_pre_op(*pre_op_list), \
lpgemm_post_op(*post_op_list), \
AOCL_STORAGE_TYPE c_downscale \
); \