Avoid gemmsup barriers when not packing A or B. (#622)

Details:
- Implemented a multithreaded optimization for the special (and common)
  case in which the gemmsup code path is used and the user requests
  (implicitly or explicitly) that neither A nor B be packed during
  computation. The optimization is a much simpler code branch in
  bli_thrinfo_sup_create_for_cntl() that avoids a broadcast and two
  barriers, which yields higher performance whenever two-way or higher
  parallelism is obtained within BLIS. Thanks to Bhaskar Nallani of AMD
  for proposing this change via issue #605.
- Added an early-return branch to bli_thrinfo_create_for_cntl() that
  detects and quickly handles the case where no parallelism is being
  obtained within BLIS (i.e., single-threaded execution). Note that
  bli_thrinfo_sup_create_for_cntl() already handled this special case
  before this commit and continues to do so.
- CREDITS file update.
This commit is contained in:
Field G. Van Zee
2022-03-11 13:28:50 -06:00
committed by GitHub
parent cad10410b2
commit 7c07b477e4
3 changed files with 133 additions and 71 deletions

View File

@@ -64,6 +64,7 @@ but many others have contributed code and feedback, including
Simon Lukas Märtens @ACSimon33 (RWTH Aachen University)
Devin Matthews @devinamatthews (The University of Texas at Austin)
Stefanos Mavros @smavros
Mithun Mohan @MithunMohanKadavil (AMD)
Ilknur Mustafazade @Runkli
@nagsingh
Bhaskar Nallani @BhaskarNallani (AMD)

View File

@@ -298,6 +298,24 @@ thrinfo_t* bli_thrinfo_create_for_cntl
thrinfo_t* thread_par
)
{
// If we are running with a single thread, all of the code can be reduced
// and simplified to this.
if ( bli_rntm_calc_num_threads( rntm ) == 1 )
{
thrinfo_t* thread_chl = bli_thrinfo_create
(
rntm, // rntm
&BLIS_SINGLE_COMM, // ocomm
0, // ocomm_id
1, // n_way
0, // work_id
FALSE, // free_comm
BLIS_NO_PART, // bszid
NULL // sub_node
);
return thread_chl;
}
thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ];
thrcomm_t** new_comms = NULL;

View File

@@ -145,7 +145,6 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl
thrinfo_t* thread_par
)
{
#if 1
// If we are running with a single thread, all of the code can be reduced
// and simplified to this.
if ( bli_rntm_calc_num_threads( rntm ) == 1 )
@@ -163,84 +162,128 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl
);
return thread_chl;
}
#endif
thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ];
thrcomm_t** new_comms = NULL;
// The remainder of this function handles the cases involving the use of
// multiple BLIS threads.
const dim_t parent_nt_in = bli_thread_num_threads( thread_par );
const dim_t parent_n_way = bli_thread_n_way( thread_par );
const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par );
const dim_t parent_work_id = bli_thread_work_id( thread_par );
// Sanity check: make sure the number of threads in the parent's
// communicator is divisible by the number of new sub-groups.
if ( parent_nt_in % parent_n_way != 0 )
if ( bli_rntm_pack_a( rntm ) == FALSE &&
bli_rntm_pack_b( rntm ) == FALSE )
{
printf( "Assertion failed: parent_nt_in <mod> parent_n_way != 0\n" );
bli_abort();
}
// If we are packing neither A nor B, there are no broadcasts or barriers
// needed to synchronize threads (since all threads can work completely
// independently). In this special case situation, the thrinfo_t can be
// created with much simpler logic.
// Compute:
// - the number of threads inside the new child comm,
// - the current thread's id within the new communicator,
// - the current thread's work id, given the ways of parallelism
// to be obtained within the next loop.
const dim_t child_nt_in = bli_rntm_calc_num_threads_in( bszid_chl, rntm );
const dim_t child_n_way = bli_rntm_ways_for( *bszid_chl, rntm );
const dim_t child_comm_id = parent_comm_id % child_nt_in;
const dim_t child_work_id = child_comm_id / ( child_nt_in / child_n_way );
const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par );
// Compute:
// - the number of threads inside the new child comm,
// - the current thread's id within the new communicator,
// - the current thread's work id, given the ways of parallelism
// to be obtained within the next loop.
const dim_t child_nt_in = bli_rntm_calc_num_threads_in( bszid_chl, rntm );
const dim_t child_n_way = bli_rntm_ways_for( *bszid_chl, rntm );
const dim_t child_comm_id = parent_comm_id % child_nt_in;
const dim_t child_work_id = child_comm_id / ( child_nt_in / child_n_way );
// All threads create a new thrinfo_t node. No communicator is created
// or attached in this branch (ocomm is NULL) since no broadcasts or
// barriers are needed when neither A nor B is packed.
thrinfo_t* thread_chl = bli_thrinfo_create
(
rntm, // rntm
NULL, // ocomm
child_comm_id, // ocomm_id
child_n_way, // n_way
child_work_id, // work_id
TRUE, // free_comm
*bszid_chl, // bszid
NULL // sub_node
);
return thread_chl;
}
else
{
// If we are packing at least one of A or B, then we use the general
// approach that employs broadcasts and barriers.
thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ];
thrcomm_t** new_comms = NULL;
const dim_t parent_nt_in = bli_thread_num_threads( thread_par );
const dim_t parent_n_way = bli_thread_n_way( thread_par );
const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par );
const dim_t parent_work_id = bli_thread_work_id( thread_par );
// Sanity check: make sure the number of threads in the parent's
// communicator is divisible by the number of new sub-groups.
if ( parent_nt_in % parent_n_way != 0 )
{
printf( "Assertion failed: parent_nt_in <mod> parent_n_way != 0\n" );
bli_abort();
}
// Compute:
// - the number of threads inside the new child comm,
// - the current thread's id within the new communicator,
// - the current thread's work id, given the ways of parallelism
// to be obtained within the next loop.
const dim_t child_nt_in = bli_rntm_calc_num_threads_in( bszid_chl, rntm );
const dim_t child_n_way = bli_rntm_ways_for( *bszid_chl, rntm );
const dim_t child_comm_id = parent_comm_id % child_nt_in;
const dim_t child_work_id = child_comm_id / ( child_nt_in / child_n_way );
//printf( "thread %d: child_n_way = %d child_nt_in = %d parent_n_way = %d (bszid = %d->%d)\n", (int)child_comm_id, (int)child_nt_in, (int)child_n_way, (int)parent_n_way, (int)bli_cntl_bszid( cntl_par ), (int)bszid_chl );
// The parent's chief thread creates a temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
err_t r_val;
// The parent's chief thread creates a temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
err_t r_val;
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ), &r_val );
else
new_comms = static_comms;
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ), &r_val );
else
new_comms = static_comms;
}
// Broadcast the temporary array to all threads in the parent's
// communicator.
new_comms = bli_thread_broadcast( thread_par, new_comms );
// Chiefs in the child communicator allocate the communicator
// object and store it in the array element corresponding to the
// parent's work id.
if ( child_comm_id == 0 )
new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in );
bli_thread_barrier( thread_par );
// All threads create a new thrinfo_t node using the communicator
// that was created by their chief, as identified by parent_work_id.
thrinfo_t* thread_chl = bli_thrinfo_create
(
rntm, // rntm
new_comms[ parent_work_id ], // ocomm
child_comm_id, // ocomm_id
child_n_way, // n_way
child_work_id, // work_id
TRUE, // free_comm
*bszid_chl, // bszid
NULL // sub_node
);
bli_thread_barrier( thread_par );
// The parent's chief thread frees the temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
bli_free_intl( new_comms );
}
return thread_chl;
}
// Broadcast the temporary array to all threads in the parent's
// communicator.
new_comms = bli_thread_broadcast( thread_par, new_comms );
// Chiefs in the child communicator allocate the communicator
// object and store it in the array element corresponding to the
// parent's work id.
if ( child_comm_id == 0 )
new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in );
bli_thread_barrier( thread_par );
// All threads create a new thrinfo_t node using the communicator
// that was created by their chief, as identified by parent_work_id.
thrinfo_t* thread_chl = bli_thrinfo_create
(
rntm, // rntm
new_comms[ parent_work_id ], // ocomm
child_comm_id, // ocomm_id
child_n_way, // n_way
child_work_id, // work_id
TRUE, // free_comm
*bszid_chl, // bszid
NULL // sub_node
);
bli_thread_barrier( thread_par );
// The parent's chief thread frees the temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
bli_free_intl( new_comms );
}
return thread_chl;
}