Merge "SUP GEMM - Enable only block panel (var2m)" into amd-staging-milan-3.1

This commit is contained in:
Kiran Varaganti
2021-05-31 06:46:04 -04:00
committed by Gerrit Code Review

View File

@@ -47,9 +47,20 @@ err_t bli_gemmsup_int
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4);
// AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_4, alpha, a, b, beta, c);
#if 0
#ifdef BLIS_CONFIG_EPYC
const num_t dt = bli_obj_dt( c );
const dim_t m = bli_obj_length( c );
const dim_t n = bli_obj_width( c );
const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
const bool auto_factor = bli_rntm_auto_factor( rntm );
const dim_t n_threads = bli_rntm_num_threads( rntm );
dim_t jc_new;
dim_t ic_new;
//bli_gemmsup_ref_var2
//bli_gemmsup_ref_var1
#if 0
@@ -61,23 +72,82 @@ err_t bli_gemmsup_int
stor_id == BLIS_RRC ||
stor_id == BLIS_RCR ||
stor_id == BLIS_CRR );
#ifdef TRACEVAR
if ( bli_thread_am_ochief( thread ) )
printf( "bli_l3_sup_int(): var2m primary\n" );
#endif
// Don't use the small/unpacked implementation if one of the matrices
// uses general stride.
if ( stor_id == BLIS_XXX ) {
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide.");
return BLIS_FAILURE;
}
if ( is_rrr_rrc_rcr_crr )
{
bli_gemmsup_ref_var2m
(
BLIS_NO_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm
);
// This branch handles:
// - rrr rrc rcr crr for row-preferential kernels
// - rcc crc ccr ccc for column-preferential kernels
// - Currently, only row-preferential kernels are supported.
// Calculate the number of micropanels in the m and n dimensions, then
// recalculate the automatic thread factorization based on those counts.
const dim_t mu = m / MR;
const dim_t nu = n / NR;
// If the parallel thread factorization was automatic, we update it
// with a new factorization based on the matrix dimensions in units
// of micropanels.
if ( auto_factor )
{
// In the block-panel algorithm, the m dimension is parallelized
// with ic_nt and the n dimension is parallelized with jc_nt.
bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new );
// Update the ways of parallelism for the jc and ic loops, and then
// update the current thread's root thrinfo_t node according to the
// new ways of parallelism value for the jc loop.
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm );
bli_l3_sup_thrinfo_update_root( rntm, thread );
}
bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE,
alpha, a, b, beta, c,
stor_id, cntx, rntm, thread );
}
else
{
bli_gemmsup_ref_var2m
(
BLIS_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm
);
// This branch handles:
// - rrr rrc rcr crr for column-preferential kernels
// - rcc crc ccr ccc for row-preferential kernels
// - Currently, only row-preferential kernels are supported.
const dim_t mu = n / MR; // the n becomes m after a transposition
const dim_t nu = m / NR; // the m becomes n after a transposition
if ( auto_factor )
{
// In the block-panel algorithm, the m dimension is parallelized
// with ic_nt and the n dimension is parallelized with jc_nt.
bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new );
// Update the ways of parallelism for the jc and ic loops, and then
// update the current thread's root thrinfo_t node according to the
// new ways of parallelism value for the jc loop.
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm );
bli_l3_sup_thrinfo_update_root( rntm, thread );
}
bli_gemmsup_ref_var2m( BLIS_TRANSPOSE,
alpha, a, b, beta, c,
stor_id, cntx, rntm, thread );
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4);
return BLIS_SUCCESS;
#endif
#else // #ifdef BLIS_CONFIG_EPYC
const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b );
@@ -92,6 +162,7 @@ err_t bli_gemmsup_int
stor_id == BLIS_RRC ||
stor_id == BLIS_RCR ||
stor_id == BLIS_CRR );
const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr;
const num_t dt = bli_obj_dt( c );
@@ -110,7 +181,6 @@ err_t bli_gemmsup_int
dim_t jc_new;
dim_t ic_new;
if ( is_primary )
{
// This branch handles:
@@ -252,6 +322,8 @@ err_t bli_gemmsup_int
// Return success so that the caller knows that we computed the solution.
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
return BLIS_SUCCESS;
#endif
}
// -----------------------------------------------------------------------------