mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Merge pull request #94 from flame/distcomm
Implemented distributed thrinfo_t management.
This commit is contained in:
@@ -34,12 +34,11 @@
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
#if 0
|
||||
thrinfo_t* bli_packm_thrinfo_create
|
||||
(
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
@@ -51,7 +50,6 @@ thrinfo_t* bli_packm_thrinfo_create
|
||||
(
|
||||
thread,
|
||||
ocomm, ocomm_id,
|
||||
icomm, icomm_id,
|
||||
n_way,
|
||||
work_id,
|
||||
FALSE,
|
||||
@@ -60,14 +58,13 @@ thrinfo_t* bli_packm_thrinfo_create
|
||||
|
||||
return thread;
|
||||
}
|
||||
#endif
|
||||
|
||||
void bli_packm_thrinfo_init
|
||||
(
|
||||
thrinfo_t* thread,
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
@@ -77,7 +74,6 @@ void bli_packm_thrinfo_init
|
||||
(
|
||||
thread,
|
||||
ocomm, ocomm_id,
|
||||
icomm, icomm_id,
|
||||
n_way, work_id,
|
||||
FALSE,
|
||||
sub_node
|
||||
@@ -93,13 +89,13 @@ void bli_packm_thrinfo_init_single
|
||||
(
|
||||
thread,
|
||||
&BLIS_SINGLE_COMM, 0,
|
||||
&BLIS_SINGLE_COMM, 0,
|
||||
1,
|
||||
0,
|
||||
NULL
|
||||
);
|
||||
}
|
||||
|
||||
#if 0
|
||||
void bli_packm_thrinfo_free
|
||||
(
|
||||
thrinfo_t* thread
|
||||
@@ -109,4 +105,4 @@ void bli_packm_thrinfo_free
|
||||
thread != &BLIS_PACKM_SINGLE_THREADED )
|
||||
bli_free_intl( thread );
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -42,24 +42,22 @@
|
||||
// thrinfo_t APIs specific to packm.
|
||||
//
|
||||
|
||||
#if 0
|
||||
thrinfo_t* bli_packm_thrinfo_create
|
||||
(
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
);
|
||||
#endif
|
||||
|
||||
void bli_packm_thrinfo_init
|
||||
(
|
||||
thrinfo_t* thread,
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
@@ -70,8 +68,10 @@ void bli_packm_thrinfo_init_single
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
#if 0
|
||||
void bli_packm_thrinfo_free
|
||||
(
|
||||
thrinfo_t* thread
|
||||
);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -35,12 +35,11 @@
|
||||
#include "blis.h"
|
||||
#include "assert.h"
|
||||
|
||||
#if 0
|
||||
thrinfo_t* bli_l3_thrinfo_create
|
||||
(
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
@@ -49,21 +48,19 @@ thrinfo_t* bli_l3_thrinfo_create
|
||||
return bli_thrinfo_create
|
||||
(
|
||||
ocomm, ocomm_id,
|
||||
icomm, icomm_id,
|
||||
n_way,
|
||||
work_id,
|
||||
TRUE,
|
||||
sub_node
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
void bli_l3_thrinfo_init
|
||||
(
|
||||
thrinfo_t* thread,
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
@@ -73,7 +70,6 @@ void bli_l3_thrinfo_init
|
||||
(
|
||||
thread,
|
||||
ocomm, ocomm_id,
|
||||
icomm, icomm_id,
|
||||
n_way,
|
||||
work_id,
|
||||
TRUE,
|
||||
@@ -105,14 +101,12 @@ void bli_l3_thrinfo_free
|
||||
// is marked as needing them to be freed. The most common example of
|
||||
// thrinfo_t nodes NOT marked as needing their comms freed are those
|
||||
// associated with packm thrinfo_t nodes.
|
||||
if ( bli_thrinfo_needs_free_comms( thread ) )
|
||||
if ( bli_thrinfo_needs_free_comm( thread ) )
|
||||
{
|
||||
// The ochief always frees his communicator, and the ichief free its
|
||||
// communicator if we are at the leaf node.
|
||||
if ( bli_thread_am_ochief( thread ) )
|
||||
bli_thrcomm_free( bli_thrinfo_ocomm( thread ) );
|
||||
if ( thrinfo_sub_node == NULL && bli_thread_am_ichief( thread ) )
|
||||
bli_thrcomm_free( bli_thrinfo_icomm( thread ) );
|
||||
}
|
||||
|
||||
// Free all children of the current thrinfo_t.
|
||||
@@ -124,117 +118,208 @@ void bli_l3_thrinfo_free
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
//#define PRINT_THRINFO
|
||||
|
||||
thrinfo_t** bli_l3_thrinfo_create_paths
|
||||
void bli_l3_thrinfo_create_root
|
||||
(
|
||||
opid_t l3_op,
|
||||
side_t side
|
||||
dim_t id,
|
||||
thrcomm_t* gl_comm,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t** thread
|
||||
)
|
||||
{
|
||||
dim_t jc_in, jc_way;
|
||||
dim_t kc_in, kc_way;
|
||||
dim_t ic_in, ic_way;
|
||||
dim_t jr_in, jr_way;
|
||||
dim_t ir_in, ir_way;
|
||||
// Query the global communicator for the total number of threads to use.
|
||||
dim_t n_threads = bli_thrcomm_num_threads( gl_comm );
|
||||
|
||||
#ifdef BLIS_ENABLE_MULTITHREADING
|
||||
jc_in = bli_env_read_nway( "BLIS_JC_NT" );
|
||||
//kc_way = bli_env_read_nway( "BLIS_KC_NT" );
|
||||
kc_in = 1;
|
||||
ic_in = bli_env_read_nway( "BLIS_IC_NT" );
|
||||
jr_in = bli_env_read_nway( "BLIS_JR_NT" );
|
||||
ir_in = bli_env_read_nway( "BLIS_IR_NT" );
|
||||
#else
|
||||
jc_in = 1;
|
||||
kc_in = 1;
|
||||
ic_in = 1;
|
||||
jr_in = 1;
|
||||
ir_in = 1;
|
||||
#endif
|
||||
// Use the thread id passed in as the global communicator id.
|
||||
dim_t gl_comm_id = id;
|
||||
|
||||
if ( l3_op == BLIS_TRMM )
|
||||
{
|
||||
// We reconfigure the parallelism for trmm_r due to a dependency in
|
||||
// the jc loop. (NOTE: This dependency does not exist for trmm3.)
|
||||
if ( bli_is_right( side ) )
|
||||
{
|
||||
jc_way = 1;
|
||||
kc_way = kc_in;
|
||||
ic_way = ic_in;
|
||||
jr_way = jr_in * jc_in;
|
||||
ir_way = ir_in;
|
||||
}
|
||||
else // if ( bli_is_left( side ) )
|
||||
{
|
||||
jc_way = jc_in;
|
||||
kc_way = kc_in;
|
||||
ic_way = ic_in;
|
||||
jr_way = jr_in;
|
||||
ir_way = ir_in;
|
||||
}
|
||||
}
|
||||
else if ( l3_op == BLIS_TRSM )
|
||||
{
|
||||
if ( bli_is_right( side ) )
|
||||
{
|
||||
// Use the blocksize id of the current (root) control tree node to
|
||||
// query the top-most ways of parallelism to obtain.
|
||||
bszid_t bszid = bli_cntl_bszid( cntl );
|
||||
dim_t xx_way = bli_cntx_way_for_bszid( bszid, cntx );
|
||||
|
||||
jc_way = 1;
|
||||
kc_way = 1;
|
||||
ic_way = jc_in * ic_in * jr_in;
|
||||
jr_way = 1;
|
||||
ir_way = 1;
|
||||
}
|
||||
else // if ( bli_is_left( side ) )
|
||||
{
|
||||
jc_way = 1;
|
||||
kc_way = 1;
|
||||
ic_way = 1;
|
||||
jr_way = ic_in * jr_in * ir_in;
|
||||
ir_way = 1;
|
||||
}
|
||||
}
|
||||
else // all other level-3 operations
|
||||
// Determine the work id for this thrinfo_t node.
|
||||
dim_t work_id = gl_comm_id / ( n_threads / xx_way );
|
||||
|
||||
// Create the root thrinfo_t node.
|
||||
*thread = bli_thrinfo_create
|
||||
(
|
||||
gl_comm,
|
||||
gl_comm_id,
|
||||
xx_way,
|
||||
work_id,
|
||||
TRUE,
|
||||
NULL
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
void bli_l3_thrinfo_print_paths
|
||||
(
|
||||
thrinfo_t** threads
|
||||
)
|
||||
{
|
||||
dim_t n_threads = bli_thread_num_threads( threads[0] );
|
||||
dim_t gl_comm_id;
|
||||
|
||||
thrinfo_t* jc_info = threads[0];
|
||||
thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info );
|
||||
thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info );
|
||||
thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info );
|
||||
thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info );
|
||||
thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info );
|
||||
thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info );
|
||||
|
||||
dim_t jc_way = bli_thread_n_way( jc_info );
|
||||
dim_t pc_way = bli_thread_n_way( pc_info );
|
||||
dim_t pb_way = bli_thread_n_way( pb_info );
|
||||
dim_t ic_way = bli_thread_n_way( ic_info );
|
||||
dim_t pa_way = bli_thread_n_way( pa_info );
|
||||
dim_t jr_way = bli_thread_n_way( jr_info );
|
||||
dim_t ir_way = bli_thread_n_way( ir_info );
|
||||
|
||||
dim_t gl_nt = bli_thread_num_threads( jc_info );
|
||||
dim_t jc_nt = bli_thread_num_threads( pc_info );
|
||||
dim_t pc_nt = bli_thread_num_threads( pb_info );
|
||||
dim_t pb_nt = bli_thread_num_threads( ic_info );
|
||||
dim_t ic_nt = bli_thread_num_threads( pa_info );
|
||||
dim_t pa_nt = bli_thread_num_threads( jr_info );
|
||||
dim_t jr_nt = bli_thread_num_threads( ir_info );
|
||||
|
||||
printf( " gl jc kc pb ic pa jr ir\n" );
|
||||
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
gl_nt, jc_nt, pc_nt, pb_nt, ic_nt, pa_nt, jr_nt, (dim_t)1 );
|
||||
printf( "\n" );
|
||||
printf( " jc kc pb ic pa jr ir\n" );
|
||||
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
jc_way, pc_way, pb_way, ic_way, pa_way, jr_way, ir_way );
|
||||
printf( "=================================================\n" );
|
||||
|
||||
for ( gl_comm_id = 0; gl_comm_id < n_threads; ++gl_comm_id )
|
||||
{
|
||||
jc_way = jc_in;
|
||||
kc_way = kc_in;
|
||||
ic_way = ic_in;
|
||||
jr_way = jr_in;
|
||||
ir_way = ir_in;
|
||||
jc_info = threads[gl_comm_id];
|
||||
pc_info = bli_thrinfo_sub_node( jc_info );
|
||||
pb_info = bli_thrinfo_sub_node( pc_info );
|
||||
ic_info = bli_thrinfo_sub_node( pb_info );
|
||||
pa_info = bli_thrinfo_sub_node( ic_info );
|
||||
jr_info = bli_thrinfo_sub_node( pa_info );
|
||||
ir_info = bli_thrinfo_sub_node( jr_info );
|
||||
|
||||
dim_t gl_comm_id = bli_thread_ocomm_id( jc_info );
|
||||
dim_t jc_comm_id = bli_thread_ocomm_id( pc_info );
|
||||
dim_t pc_comm_id = bli_thread_ocomm_id( pb_info );
|
||||
dim_t pb_comm_id = bli_thread_ocomm_id( ic_info );
|
||||
dim_t ic_comm_id = bli_thread_ocomm_id( pa_info );
|
||||
dim_t pa_comm_id = bli_thread_ocomm_id( jr_info );
|
||||
dim_t jr_comm_id = bli_thread_ocomm_id( ir_info );
|
||||
|
||||
dim_t jc_work_id = bli_thread_work_id( jc_info );
|
||||
dim_t pc_work_id = bli_thread_work_id( pc_info );
|
||||
dim_t pb_work_id = bli_thread_work_id( pb_info );
|
||||
dim_t ic_work_id = bli_thread_work_id( ic_info );
|
||||
dim_t pa_work_id = bli_thread_work_id( pa_info );
|
||||
dim_t jr_work_id = bli_thread_work_id( jr_info );
|
||||
dim_t ir_work_id = bli_thread_work_id( ir_info );
|
||||
|
||||
printf( " gl jc pb kc pa ic jr \n" );
|
||||
printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
gl_comm_id, jc_comm_id, pc_comm_id, pb_comm_id, ic_comm_id, pa_comm_id, jr_comm_id );
|
||||
printf( "work ids: %4ld %4ld %4lu %4lu %4ld %4ld %4ld\n",
|
||||
jc_work_id, pc_work_id, pb_work_id, ic_work_id, pa_work_id, jr_work_id, ir_work_id );
|
||||
printf( "---------------------------------------\n" );
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way;
|
||||
assert( global_num_threads != 0 );
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
dim_t jc_nt = kc_way * ic_way * jr_way * ir_way;
|
||||
dim_t kc_nt = ic_way * jr_way * ir_way;
|
||||
#if 0
|
||||
thrinfo_t** bli_l3_thrinfo_create_roots
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
)
|
||||
{
|
||||
// Query the context for the total number of threads to use.
|
||||
dim_t n_threads = bli_cntx_get_num_threads( cntx );
|
||||
|
||||
// Create a global thread communicator for all the threads.
|
||||
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
|
||||
|
||||
// Allocate an array of thrinfo_t pointers, one for each thread.
|
||||
thrinfo_t** paths = bli_malloc_intl( n_threads * sizeof( thrinfo_t* ) );
|
||||
|
||||
// Use the blocksize id of the current (root) control tree node to
|
||||
// query the top-most ways of parallelism to obtain.
|
||||
bszid_t bszid = bli_cntl_bszid( cntl );
|
||||
dim_t xx_way = bli_cntx_way_for_bszid( bszid, cntx );
|
||||
|
||||
dim_t gl_comm_id;
|
||||
|
||||
// Create one thrinfo_t node for each thread in the (global) communicator.
|
||||
for ( gl_comm_id = 0; gl_comm_id < n_threads; ++gl_comm_id )
|
||||
{
|
||||
dim_t work_id = gl_comm_id / ( n_threads / xx_way );
|
||||
|
||||
paths[ gl_comm_id ] = bli_thrinfo_create
|
||||
(
|
||||
gl_comm,
|
||||
gl_comm_id,
|
||||
xx_way,
|
||||
work_id,
|
||||
TRUE,
|
||||
NULL
|
||||
);
|
||||
}
|
||||
|
||||
return paths;
|
||||
}
|
||||
|
||||
//#define PRINT_THRINFO
|
||||
|
||||
thrinfo_t** bli_l3_thrinfo_create_full_paths
|
||||
(
|
||||
cntx_t* cntx
|
||||
)
|
||||
{
|
||||
dim_t jc_way = bli_cntx_jc_way( cntx );
|
||||
dim_t pc_way = bli_cntx_pc_way( cntx );
|
||||
dim_t ic_way = bli_cntx_ic_way( cntx );
|
||||
dim_t jr_way = bli_cntx_jr_way( cntx );
|
||||
dim_t ir_way = bli_cntx_ir_way( cntx );
|
||||
|
||||
dim_t gl_nt = jc_way * pc_way * ic_way * jr_way * ir_way;
|
||||
dim_t jc_nt = pc_way * ic_way * jr_way * ir_way;
|
||||
dim_t pc_nt = ic_way * jr_way * ir_way;
|
||||
dim_t ic_nt = jr_way * ir_way;
|
||||
dim_t jr_nt = ir_way;
|
||||
dim_t ir_nt = 1;
|
||||
|
||||
assert( gl_nt != 0 );
|
||||
|
||||
#ifdef PRINT_THRINFO
|
||||
printf( " jc kc ic jr ir\n" );
|
||||
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu\n",
|
||||
jc_way, kc_way, ic_way, jr_way, ir_way );
|
||||
printf( " gl jc kc pb ic pa jr ir\n" );
|
||||
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
gl_nt, jc_nt, pc_nt, pc_nt, ic_nt, ic_nt, jr_nt, ir_nt );
|
||||
printf( "\n" );
|
||||
printf( " gl jc kc ic jr ir\n" );
|
||||
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
global_num_threads, jc_nt, kc_nt, ic_nt, jr_nt, ir_nt );
|
||||
printf( "=======================================\n" );
|
||||
printf( " jc kc pb ic pa jr ir\n" );
|
||||
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
jc_way, pc_way, (dim_t)0, ic_way, (dim_t)0, jr_way, ir_way );
|
||||
printf( "=================================================\n" );
|
||||
#endif
|
||||
|
||||
thrinfo_t** paths = bli_malloc_intl( global_num_threads * sizeof( thrinfo_t* ) );
|
||||
thrinfo_t** paths = bli_malloc_intl( gl_nt * sizeof( thrinfo_t* ) );
|
||||
|
||||
thrcomm_t* global_comm = bli_thrcomm_create( global_num_threads );
|
||||
thrcomm_t* gl_comm = bli_thrcomm_create( gl_nt );
|
||||
|
||||
for( int a = 0; a < jc_way; a++ )
|
||||
{
|
||||
thrcomm_t* jc_comm = bli_thrcomm_create( jc_nt );
|
||||
|
||||
for( int b = 0; b < kc_way; b++ )
|
||||
for( int b = 0; b < pc_way; b++ )
|
||||
{
|
||||
thrcomm_t* kc_comm = bli_thrcomm_create( kc_nt );
|
||||
thrcomm_t* pc_comm = bli_thrcomm_create( pc_nt );
|
||||
|
||||
for( int c = 0; c < ic_way; c++ )
|
||||
{
|
||||
@@ -246,73 +331,83 @@ printf( "=======================================\n" );
|
||||
|
||||
for( int e = 0; e < ir_way; e++ )
|
||||
{
|
||||
thrcomm_t* ir_comm = bli_thrcomm_create( ir_nt );
|
||||
|
||||
dim_t ir_comm_id = 0;
|
||||
dim_t jr_comm_id = e*ir_nt + ir_comm_id;
|
||||
dim_t ic_comm_id = d*jr_nt + jr_comm_id;
|
||||
dim_t kc_comm_id = c*ic_nt + ic_comm_id;
|
||||
dim_t jc_comm_id = b*kc_nt + kc_comm_id;
|
||||
dim_t global_comm_id = a*jc_nt + jc_comm_id;
|
||||
//thrcomm_t* ir_comm = bli_thrcomm_create( ir_nt );
|
||||
dim_t ir_comm_id = 0;
|
||||
dim_t jr_comm_id = e*ir_nt + ir_comm_id;
|
||||
dim_t ic_comm_id = d*jr_nt + jr_comm_id;
|
||||
dim_t pc_comm_id = c*ic_nt + ic_comm_id;
|
||||
dim_t jc_comm_id = b*pc_nt + pc_comm_id;
|
||||
dim_t gl_comm_id = a*jc_nt + jc_comm_id;
|
||||
|
||||
// macro-kernel loops
|
||||
thrinfo_t* ir_info
|
||||
=
|
||||
bli_l3_thrinfo_create( jr_comm, jr_comm_id,
|
||||
ir_comm, ir_comm_id,
|
||||
ir_way, e,
|
||||
NULL );
|
||||
thrinfo_t* jr_info
|
||||
=
|
||||
bli_l3_thrinfo_create( ic_comm, ic_comm_id,
|
||||
jr_comm, jr_comm_id,
|
||||
jr_way, d,
|
||||
ir_info );
|
||||
// packa
|
||||
thrinfo_t* pack_ic_in
|
||||
thrinfo_t* pa_info
|
||||
=
|
||||
bli_packm_thrinfo_create( ic_comm, ic_comm_id,
|
||||
jr_comm, jr_comm_id,
|
||||
ic_nt, ic_comm_id,
|
||||
jr_info );
|
||||
// blk_var1
|
||||
thrinfo_t* ic_info
|
||||
=
|
||||
bli_l3_thrinfo_create( kc_comm, kc_comm_id,
|
||||
ic_comm, ic_comm_id,
|
||||
bli_l3_thrinfo_create( pc_comm, pc_comm_id,
|
||||
ic_way, c,
|
||||
pack_ic_in );
|
||||
pa_info );
|
||||
// packb
|
||||
thrinfo_t* pack_kc_in
|
||||
thrinfo_t* pb_info
|
||||
=
|
||||
bli_packm_thrinfo_create( kc_comm, kc_comm_id,
|
||||
ic_comm, ic_comm_id,
|
||||
kc_nt, kc_comm_id,
|
||||
bli_packm_thrinfo_create( pc_comm, pc_comm_id,
|
||||
pc_nt, pc_comm_id,
|
||||
ic_info );
|
||||
// blk_var3
|
||||
thrinfo_t* kc_info
|
||||
thrinfo_t* pc_info
|
||||
=
|
||||
bli_l3_thrinfo_create( jc_comm, jc_comm_id,
|
||||
kc_comm, kc_comm_id,
|
||||
kc_way, b,
|
||||
pack_kc_in );
|
||||
pc_way, b,
|
||||
pb_info );
|
||||
// blk_var2
|
||||
thrinfo_t* jc_info
|
||||
=
|
||||
bli_l3_thrinfo_create( global_comm, global_comm_id,
|
||||
jc_comm, jc_comm_id,
|
||||
bli_l3_thrinfo_create( gl_comm, gl_comm_id,
|
||||
jc_way, a,
|
||||
kc_info );
|
||||
pc_info );
|
||||
|
||||
paths[global_comm_id] = jc_info;
|
||||
paths[gl_comm_id] = jc_info;
|
||||
|
||||
#ifdef PRINT_THRINFO
|
||||
printf( " gl jc kc ic jr ir\n" );
|
||||
printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
global_comm_id, jc_comm_id, kc_comm_id, ic_comm_id, jr_comm_id, ir_comm_id );
|
||||
//printf( " a b c d e\n" );
|
||||
printf( "work ids: %4ld %4ld %4ld %4ld %4ld\n", (long int)a, (long int)b, (long int)c, (long int)d, (long int)e );
|
||||
printf( "---------------------------------------\n" );
|
||||
{
|
||||
dim_t gl_comm_id = bli_thread_ocomm_id( jc_info );
|
||||
dim_t jc_comm_id = bli_thread_ocomm_id( pc_info );
|
||||
dim_t pc_comm_id = bli_thread_ocomm_id( pb_info );
|
||||
dim_t pb_comm_id = bli_thread_ocomm_id( ic_info );
|
||||
dim_t ic_comm_id = bli_thread_ocomm_id( pa_info );
|
||||
dim_t pa_comm_id = bli_thread_ocomm_id( jr_info );
|
||||
dim_t jr_comm_id = bli_thread_ocomm_id( ir_info );
|
||||
|
||||
dim_t jc_work_id = bli_thread_work_id( jc_info );
|
||||
dim_t pc_work_id = bli_thread_work_id( pc_info );
|
||||
dim_t pb_work_id = bli_thread_work_id( pb_info );
|
||||
dim_t ic_work_id = bli_thread_work_id( ic_info );
|
||||
dim_t pa_work_id = bli_thread_work_id( pa_info );
|
||||
dim_t jr_work_id = bli_thread_work_id( jr_info );
|
||||
dim_t ir_work_id = bli_thread_work_id( ir_info );
|
||||
|
||||
printf( " gl jc pb kc pa ic jr \n" );
|
||||
printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
|
||||
gl_comm_id, jc_comm_id, pc_comm_id, pb_comm_id, ic_comm_id, pa_comm_id, jr_comm_id );
|
||||
printf( "work ids: %4ld %4ld %4lu %4lu %4ld %4ld %4ld\n",
|
||||
jc_work_id, pc_work_id, pb_work_id, ic_work_id, pa_work_id, jr_work_id, ir_work_id );
|
||||
printf( "-------------------------------------------------\n" );
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
@@ -330,15 +425,16 @@ exit(1);
|
||||
|
||||
void bli_l3_thrinfo_free_paths
|
||||
(
|
||||
thrinfo_t** threads,
|
||||
dim_t num
|
||||
thrinfo_t** threads
|
||||
)
|
||||
{
|
||||
dim_t n_threads = bli_thread_num_threads( threads[0] );
|
||||
dim_t i;
|
||||
|
||||
for ( i = 0; i < num; ++i )
|
||||
for ( i = 0; i < n_threads; ++i )
|
||||
bli_l3_thrinfo_free( threads[i] );
|
||||
|
||||
bli_free_intl( threads );
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -61,24 +61,22 @@
|
||||
// thrinfo_t APIs specific to level-3 operations.
|
||||
//
|
||||
|
||||
#if 0
|
||||
thrinfo_t* bli_l3_thrinfo_create
|
||||
(
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
);
|
||||
#endif
|
||||
|
||||
void bli_l3_thrinfo_init
|
||||
(
|
||||
thrinfo_t* thread,
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
thrinfo_t* sub_node
|
||||
@@ -96,15 +94,37 @@ void bli_l3_thrinfo_free
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
thrinfo_t** bli_l3_thrinfo_create_paths
|
||||
void bli_l3_thrinfo_create_root
|
||||
(
|
||||
opid_t l3_op,
|
||||
side_t side
|
||||
dim_t id,
|
||||
thrcomm_t* gl_comm,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t** thread
|
||||
);
|
||||
|
||||
void bli_l3_thrinfo_print_paths
|
||||
(
|
||||
thrinfo_t** threads
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
#if 0
|
||||
thrinfo_t** bli_l3_thrinfo_create_roots
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
);
|
||||
|
||||
thrinfo_t** bli_l3_thrinfo_create_full_paths
|
||||
(
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
void bli_l3_thrinfo_free_paths
|
||||
(
|
||||
thrinfo_t** threads,
|
||||
dim_t num
|
||||
thrinfo_t** threads
|
||||
);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -84,10 +84,10 @@ void bli_gemm_blk_var3
|
||||
c,
|
||||
cntx,
|
||||
bli_cntl_sub_node( cntl ),
|
||||
bli_thrinfo_sub_node( thread)
|
||||
bli_thrinfo_sub_node( thread )
|
||||
);
|
||||
|
||||
bli_thread_ibarrier( thread );
|
||||
bli_thread_obarrier( bli_thrinfo_sub_node( thread ) );
|
||||
|
||||
// This variant executes multiple rank-k updates. Therefore, if the
|
||||
// internal beta scalar on matrix C is non-zero, we must use it
|
||||
|
||||
@@ -46,14 +46,21 @@ cntl_t* bli_gemm_cntl_create
|
||||
if ( family == BLIS_HERK ) macro_kernel_p = bli_herk_x_ker_var2;
|
||||
else if ( family == BLIS_TRMM ) macro_kernel_p = bli_trmm_xx_ker_var2;
|
||||
|
||||
// Create a node for the macro-kernel.
|
||||
cntl_t* gemm_cntl_bp_ke = bli_gemm_cntl_obj_create
|
||||
// Create two nodes for the macro-kernel.
|
||||
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_obj_create
|
||||
(
|
||||
BLIS_NR, // bszid not used by macro-kernel.
|
||||
macro_kernel_p,
|
||||
BLIS_MR, // needed for bli_thrinfo_rgrow()
|
||||
NULL, // variant function pointer not used
|
||||
NULL // no sub-node; this is the leaf of the tree.
|
||||
);
|
||||
|
||||
cntl_t* gemm_cntl_bp_bu = bli_gemm_cntl_obj_create
|
||||
(
|
||||
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
|
||||
macro_kernel_p,
|
||||
gemm_cntl_bu_ke
|
||||
);
|
||||
|
||||
// Create a node for packing matrix A.
|
||||
cntl_t* gemm_cntl_packa = bli_packm_cntl_obj_create
|
||||
(
|
||||
@@ -66,7 +73,7 @@ cntl_t* bli_gemm_cntl_create
|
||||
FALSE, // reverse iteration if lower?
|
||||
BLIS_PACKED_ROW_PANELS,
|
||||
BLIS_BUFFER_FOR_A_BLOCK,
|
||||
gemm_cntl_bp_ke
|
||||
gemm_cntl_bp_bu
|
||||
);
|
||||
|
||||
// Create a node for partitioning the m dimension by MC.
|
||||
|
||||
@@ -85,13 +85,19 @@ void bli_gemm_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_GEMM, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_GEMM, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_GEMM, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
// Create the first node in the thrinfo_t tree for each thread.
|
||||
//thrinfo_t** infos = bli_l3_thrinfo_create_full_paths( cntx );
|
||||
//bli_l3_thrinfo_print_paths( infos );
|
||||
//exit(1);
|
||||
//cntl = bli_gemm_cntl_create( BLIS_GEMM );
|
||||
//thrinfo_t** infos = bli_l3_thrinfo_create_roots( cntx, cntl );
|
||||
|
||||
// Invoke the internal back-end via the thread handler.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -99,10 +105,12 @@ void bli_gemm_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
//bli_l3_thrinfo_print_paths( infos );
|
||||
//exit(1);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
// Free the thrinfo_t structures.
|
||||
//bli_l3_thrinfo_free_paths( infos );
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ void bli_gemm_int
|
||||
obj_t b_local;
|
||||
obj_t c_local;
|
||||
gemm_voft f;
|
||||
ind_t im;
|
||||
|
||||
// Check parameters.
|
||||
if ( bli_error_checking_is_enabled() )
|
||||
@@ -102,17 +101,22 @@ void bli_gemm_int
|
||||
bli_obj_scalar_apply_scalar( beta, &c_local );
|
||||
}
|
||||
|
||||
// Create the next node in the thrinfo_t structure.
|
||||
bli_thrinfo_grow( cntx, cntl, thread );
|
||||
|
||||
// Extract the function pointer from the current control tree node.
|
||||
f = bli_cntl_var_func( cntl );
|
||||
|
||||
// Somewhat hackish support for 3m3, 3m2, and 4m1b method implementations.
|
||||
im = bli_cntx_get_ind_method( cntx );
|
||||
|
||||
if ( im != BLIS_NAT )
|
||||
{
|
||||
if ( im == BLIS_3M3 && f == bli_gemm_packa ) f = bli_gemm3m3_packa;
|
||||
else if ( im == BLIS_3M2 && f == bli_gemm_ker_var2 ) f = bli_gemm3m2_ker_var2;
|
||||
else if ( im == BLIS_4M1B && f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2;
|
||||
ind_t im = bli_cntx_get_ind_method( cntx );
|
||||
|
||||
if ( im != BLIS_NAT )
|
||||
{
|
||||
if ( im == BLIS_3M3 && f == bli_gemm_packa ) f = bli_gemm3m3_packa;
|
||||
else if ( im == BLIS_3M2 && f == bli_gemm_ker_var2 ) f = bli_gemm3m2_ker_var2;
|
||||
else if ( im == BLIS_4M1B && f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2;
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the variant.
|
||||
|
||||
@@ -92,13 +92,12 @@ void bli_hemm_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_GEMM, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HEMM, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_HEMM, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -106,10 +105,7 @@ void bli_hemm_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -110,14 +110,14 @@ void bli_her2k_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_HERK, cntx );
|
||||
|
||||
// Invoke herk twice, using beta only the first time.
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HER2K, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_HER2K, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
// Invoke herk twice, using beta only the first time.
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -125,13 +125,11 @@ void bli_her2k_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
&alpha_conj,
|
||||
&b_local,
|
||||
@@ -139,12 +137,9 @@ void bli_her2k_front
|
||||
&BLIS_ONE,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
|
||||
// The Hermitian rank-2k product was computed as A*B'+B*A', even for
|
||||
// the diagonal elements. Mathematically, the imaginary components of
|
||||
// diagonal elements of a Hermitian rank-2k product should always be
|
||||
|
||||
@@ -90,13 +90,12 @@ void bli_herk_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_HERK, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HERK, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_HERK, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -104,12 +103,9 @@ void bli_herk_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
|
||||
// The Hermitian rank-k product was computed as A*A', even for the
|
||||
// diagonal elements. Mathematically, the imaginary components of
|
||||
// diagonal elements of a Hermitian rank-k product should always be
|
||||
|
||||
@@ -91,13 +91,12 @@ void bli_symm_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_GEMM, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYMM, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_SYMM, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -105,10 +104,7 @@ void bli_symm_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -91,14 +91,14 @@ void bli_syr2k_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_HERK, cntx );
|
||||
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_SYR2K, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke herk twice, using beta only the first time.
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYR2K, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -106,13 +106,11 @@ void bli_syr2k_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&b_local,
|
||||
@@ -120,10 +118,7 @@ void bli_syr2k_front
|
||||
&BLIS_ONE,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -84,13 +84,12 @@ void bli_syrk_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_HERK, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYRK, BLIS_LEFT );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_SYRK, BLIS_LEFT, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -98,10 +97,7 @@ void bli_syrk_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -134,13 +134,12 @@ void bli_trmm_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_TRMM, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRMM, side );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_TRMM, side, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -148,10 +147,7 @@ void bli_trmm_front
|
||||
&BLIS_ZERO,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -133,13 +133,12 @@ void bli_trmm3_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_TRMM, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRMM3, side );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_TRMM3, side, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -147,10 +146,7 @@ void bli_trmm3_front
|
||||
beta,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,8 @@ void bli_trsm_blk_var3
|
||||
bli_thrinfo_sub_node( thread )
|
||||
);
|
||||
|
||||
bli_thread_ibarrier( thread );
|
||||
//bli_thread_ibarrier( thread );
|
||||
bli_thread_obarrier( bli_thrinfo_sub_node( thread ) );
|
||||
|
||||
// This variant executes multiple rank-k updates. Therefore, if the
|
||||
// internal alpha scalars on A/B and C are non-zero, we must ensure
|
||||
|
||||
@@ -50,14 +50,21 @@ cntl_t* bli_trsm_l_cntl_create
|
||||
{
|
||||
void* macro_kernel_p = bli_trsm_xx_ker_var2;
|
||||
|
||||
// Create a node for the macro-kernel.
|
||||
cntl_t* trsm_cntl_bp_ke = bli_trsm_cntl_obj_create
|
||||
// Create two nodes for the macro-kernel.
|
||||
cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_obj_create
|
||||
(
|
||||
BLIS_NR, // bszid not used by macro-kernel.
|
||||
macro_kernel_p,
|
||||
BLIS_MR, // needed for bli_thrinfo_rgrow()
|
||||
NULL, // variant function pointer not used
|
||||
NULL // no sub-node; this is the leaf of the tree.
|
||||
);
|
||||
|
||||
cntl_t* trsm_cntl_bp_bu = bli_trsm_cntl_obj_create
|
||||
(
|
||||
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
|
||||
macro_kernel_p,
|
||||
trsm_cntl_bu_ke
|
||||
);
|
||||
|
||||
// Create a node for packing matrix A.
|
||||
cntl_t* trsm_cntl_packa = bli_packm_cntl_obj_create
|
||||
(
|
||||
@@ -70,7 +77,7 @@ cntl_t* bli_trsm_l_cntl_create
|
||||
FALSE, // reverse iteration if lower?
|
||||
BLIS_PACKED_ROW_PANELS,
|
||||
BLIS_BUFFER_FOR_A_BLOCK,
|
||||
trsm_cntl_bp_ke
|
||||
trsm_cntl_bp_bu
|
||||
);
|
||||
|
||||
// Create a node for partitioning the m dimension by MC.
|
||||
@@ -122,14 +129,21 @@ cntl_t* bli_trsm_r_cntl_create
|
||||
{
|
||||
void* macro_kernel_p = bli_trsm_xx_ker_var2;
|
||||
|
||||
// Create a node for the macro-kernel.
|
||||
cntl_t* trsm_cntl_bp_ke = bli_trsm_cntl_obj_create
|
||||
// Create two nodes for the macro-kernel.
|
||||
cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_obj_create
|
||||
(
|
||||
BLIS_NR, // bszid not used by macro-kernel.
|
||||
macro_kernel_p,
|
||||
BLIS_MR, // needed for bli_thrinfo_rgrow()
|
||||
NULL, // variant function pointer not used
|
||||
NULL // no sub-node; this is the leaf of the tree.
|
||||
);
|
||||
|
||||
cntl_t* trsm_cntl_bp_bu = bli_trsm_cntl_obj_create
|
||||
(
|
||||
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
|
||||
macro_kernel_p,
|
||||
trsm_cntl_bu_ke
|
||||
);
|
||||
|
||||
// Create a node for packing matrix A.
|
||||
cntl_t* trsm_cntl_packa = bli_packm_cntl_obj_create
|
||||
(
|
||||
@@ -142,7 +156,7 @@ cntl_t* bli_trsm_r_cntl_create
|
||||
FALSE, // reverse iteration if lower?
|
||||
BLIS_PACKED_ROW_PANELS,
|
||||
BLIS_BUFFER_FOR_A_BLOCK,
|
||||
trsm_cntl_bp_ke
|
||||
trsm_cntl_bp_bu
|
||||
);
|
||||
|
||||
// Create a node for partitioning the m dimension by MC.
|
||||
|
||||
@@ -119,13 +119,12 @@ void bli_trsm_front
|
||||
// Set the operation family id in the context.
|
||||
bli_cntx_set_family( BLIS_TRSM, cntx );
|
||||
|
||||
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRSM, side );
|
||||
dim_t n_threads = bli_thread_num_threads( infos[0] );
|
||||
// Record the threading for each level within the context.
|
||||
bli_cntx_set_thrloop_from_env( BLIS_TRSM, side, cntx );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_l3_thread_decorator
|
||||
(
|
||||
n_threads,
|
||||
bli_trsm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
@@ -133,10 +132,7 @@ void bli_trsm_front
|
||||
alpha,
|
||||
&c_local,
|
||||
cntx,
|
||||
cntl,
|
||||
infos
|
||||
cntl
|
||||
);
|
||||
|
||||
bli_l3_thrinfo_free_paths( infos, n_threads );
|
||||
}
|
||||
|
||||
|
||||
@@ -117,6 +117,9 @@ void bli_trsm_int
|
||||
// FGVZ->TMS: Is this barrier still needed?
|
||||
bli_thread_obarrier( thread );
|
||||
|
||||
// Create the next node in the thrinfo_t structure.
|
||||
bli_thrinfo_grow( cntx, cntl, thread );
|
||||
|
||||
// Extract the function pointer from the current control tree node.
|
||||
f = bli_cntl_var_func( cntl );
|
||||
|
||||
|
||||
@@ -341,6 +341,37 @@ pack_t bli_cntx_get_pack_schema_b( cntx_t* cntx )
|
||||
}
|
||||
#endif
|
||||
|
||||
dim_t bli_cntx_get_num_threads( cntx_t* cntx )
|
||||
{
|
||||
return bli_cntx_jc_way( cntx ) *
|
||||
bli_cntx_pc_way( cntx ) *
|
||||
bli_cntx_ic_way( cntx ) *
|
||||
bli_cntx_jr_way( cntx ) *
|
||||
bli_cntx_ir_way( cntx );
|
||||
}
|
||||
|
||||
dim_t bli_cntx_get_num_threads_in( cntx_t* cntx, cntl_t* cntl )
|
||||
{
|
||||
dim_t n_threads_in = 1;
|
||||
|
||||
for ( ; cntl != NULL; cntl = bli_cntl_sub_node( cntl ) )
|
||||
{
|
||||
bszid_t bszid = bli_cntl_bszid( cntl );
|
||||
dim_t cur_way;
|
||||
|
||||
// We assume bszid is in {KR,MR,NR,MC,KC,NR} if it is not
|
||||
// BLIS_NO_PART.
|
||||
if ( bszid != BLIS_NO_PART )
|
||||
cur_way = bli_cntx_way_for_bszid( bszid, cntx );
|
||||
else
|
||||
cur_way = 1;
|
||||
|
||||
n_threads_in *= cur_way;
|
||||
}
|
||||
|
||||
return n_threads_in;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
#if 1
|
||||
@@ -663,6 +694,96 @@ void bli_cntx_set_pack_schema_c( pack_t schema_c,
|
||||
bli_cntx_set_schema_c( schema_c, cntx );
|
||||
}
|
||||
|
||||
void bli_cntx_set_thrloop_from_env( opid_t l3_op, side_t side, cntx_t* cntx )
|
||||
{
|
||||
dim_t jc, pc, ic, jr, ir;
|
||||
|
||||
#ifdef BLIS_ENABLE_MULTITHREADING
|
||||
jc = bli_env_read_nway( "BLIS_JC_NT" );
|
||||
//pc = bli_env_read_nway( "BLIS_KC_NT" );
|
||||
pc = 1;
|
||||
ic = bli_env_read_nway( "BLIS_IC_NT" );
|
||||
jr = bli_env_read_nway( "BLIS_JR_NT" );
|
||||
ir = bli_env_read_nway( "BLIS_IR_NT" );
|
||||
#else
|
||||
jc = 1;
|
||||
pc = 1;
|
||||
ic = 1;
|
||||
jr = 1;
|
||||
ir = 1;
|
||||
#endif
|
||||
|
||||
if ( l3_op == BLIS_TRMM )
|
||||
{
|
||||
// We reconfigure the paralelism from trmm_r due to a dependency in
|
||||
// the jc loop. (NOTE: This dependency does not exist for trmm3 )
|
||||
if ( bli_is_right( side ) )
|
||||
{
|
||||
bli_cntx_set_thrloop
|
||||
(
|
||||
1,
|
||||
pc,
|
||||
ic,
|
||||
jr * jc,
|
||||
ir,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
else // if ( bli_is_left( side ) )
|
||||
{
|
||||
bli_cntx_set_thrloop
|
||||
(
|
||||
jc,
|
||||
pc,
|
||||
ic,
|
||||
jr,
|
||||
ir,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
}
|
||||
else if ( l3_op == BLIS_TRSM )
|
||||
{
|
||||
if ( bli_is_right( side ) )
|
||||
{
|
||||
bli_cntx_set_thrloop
|
||||
(
|
||||
1,
|
||||
1,
|
||||
jc * ic * jr,
|
||||
1,
|
||||
1,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
else // if ( bli_is_left( side ) )
|
||||
{
|
||||
bli_cntx_set_thrloop
|
||||
(
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
ic * jr * ir,
|
||||
1,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
}
|
||||
else // if ( l3_op == BLIS_TRSM )
|
||||
{
|
||||
bli_cntx_set_thrloop
|
||||
(
|
||||
jc,
|
||||
pc,
|
||||
ic,
|
||||
jr,
|
||||
ir,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
bool_t bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt,
|
||||
|
||||
@@ -59,6 +59,8 @@ typedef struct cntx_s
|
||||
pack_t schema_b;
|
||||
pack_t schema_c;
|
||||
|
||||
dim_t* thrloop;
|
||||
|
||||
membrk_t* membrk;
|
||||
} cntx_t;
|
||||
*/
|
||||
@@ -127,6 +129,36 @@ typedef struct cntx_s
|
||||
\
|
||||
( (cntx)->membrk )
|
||||
|
||||
#define bli_cntx_thrloop( cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop )
|
||||
|
||||
#if 1
|
||||
#define bli_cntx_jc_way( cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop[ BLIS_NC ] )
|
||||
|
||||
#define bli_cntx_pc_way( cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop[ BLIS_KC ] )
|
||||
|
||||
#define bli_cntx_ic_way( cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop[ BLIS_MC ] )
|
||||
|
||||
#define bli_cntx_jr_way( cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop[ BLIS_NR ] )
|
||||
|
||||
#define bli_cntx_ir_way( cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop[ BLIS_MR ] )
|
||||
#endif
|
||||
|
||||
#define bli_cntx_way_for_bszid( bszid, cntx ) \
|
||||
\
|
||||
( (cntx)->thrloop[ bszid ] )
|
||||
|
||||
// cntx_t modification (fields only)
|
||||
|
||||
#define bli_cntx_set_blkszs_buf( _blkszs, cntx_p ) \
|
||||
@@ -199,6 +231,16 @@ typedef struct cntx_s
|
||||
(cntx_p)->membrk = _membrk; \
|
||||
}
|
||||
|
||||
#define bli_cntx_set_thrloop( jc_, pc_, ic_, jr_, ir_, cntx_p ) \
|
||||
{ \
|
||||
(cntx_p)->thrloop[ BLIS_NC ] = jc_; \
|
||||
(cntx_p)->thrloop[ BLIS_KC ] = pc_; \
|
||||
(cntx_p)->thrloop[ BLIS_MC ] = ic_; \
|
||||
(cntx_p)->thrloop[ BLIS_NR ] = jr_; \
|
||||
(cntx_p)->thrloop[ BLIS_MR ] = ir_; \
|
||||
(cntx_p)->thrloop[ BLIS_KR ] = 1; \
|
||||
}
|
||||
|
||||
// cntx_t query (complex)
|
||||
|
||||
#define bli_cntx_get_blksz_def_dt( dt, bs_id, cntx ) \
|
||||
@@ -356,6 +398,8 @@ func_t* bli_cntx_get_packm_ukr( cntx_t* cntx );
|
||||
//pack_t bli_cntx_get_pack_schema_a( cntx_t* cntx );
|
||||
//pack_t bli_cntx_get_pack_schema_b( cntx_t* cntx );
|
||||
//pack_t bli_cntx_get_pack_schema_c( cntx_t* cntx );
|
||||
dim_t bli_cntx_get_num_threads( cntx_t* cntx );
|
||||
dim_t bli_cntx_get_num_threads_in( cntx_t* cntx, cntl_t* cntl );
|
||||
|
||||
// set functions
|
||||
|
||||
@@ -390,6 +434,9 @@ void bli_cntx_set_pack_schema_b( pack_t schema_b,
|
||||
cntx_t* cntx );
|
||||
void bli_cntx_set_pack_schema_c( pack_t schema_c,
|
||||
cntx_t* cntx );
|
||||
void bli_cntx_set_thrloop_from_env( opid_t l3_op,
|
||||
side_t side,
|
||||
cntx_t* cntx );
|
||||
|
||||
// other query functions
|
||||
|
||||
|
||||
@@ -638,6 +638,21 @@ typedef enum
|
||||
#define BLIS_NUM_UKR_IMPL_TYPES 4
|
||||
|
||||
|
||||
#if 0
|
||||
typedef enum
|
||||
{
|
||||
BLIS_JC_IDX = 0,
|
||||
BLIS_PC_IDX,
|
||||
BLIS_IC_IDX,
|
||||
BLIS_JR_IDX,
|
||||
BLIS_IR_IDX,
|
||||
BLIS_PR_IDX,
|
||||
} thridx_t;
|
||||
#endif
|
||||
|
||||
#define BLIS_NUM_LOOPS 6
|
||||
|
||||
|
||||
// -- Operation ID type --
|
||||
|
||||
typedef enum
|
||||
@@ -949,6 +964,8 @@ typedef struct cntx_s
|
||||
pack_t schema_b;
|
||||
pack_t schema_c;
|
||||
|
||||
dim_t thrloop[ BLIS_NUM_LOOPS ];
|
||||
|
||||
membrk_t* membrk;
|
||||
} cntx_t;
|
||||
|
||||
|
||||
@@ -41,6 +41,12 @@
|
||||
#include "bli_thrcomm_openmp.h"
|
||||
#include "bli_thrcomm_pthreads.h"
|
||||
|
||||
|
||||
// thrcomm_t query (field only)
|
||||
|
||||
#define bli_thrcomm_num_threads( comm ) ( (comm)->n_threads )
|
||||
|
||||
|
||||
// Thread communicator prototypes.
|
||||
thrcomm_t* bli_thrcomm_create( dim_t n_threads );
|
||||
void bli_thrcomm_free( thrcomm_t* communicator );
|
||||
|
||||
@@ -201,7 +201,6 @@ void bli_thrcomm_tree_barrier( barrier_t* barack )
|
||||
|
||||
void bli_l3_thread_decorator
|
||||
(
|
||||
dim_t n_threads,
|
||||
l3int_t func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
@@ -209,20 +208,28 @@ void bli_l3_thread_decorator
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t** thread
|
||||
cntl_t* cntl
|
||||
)
|
||||
{
|
||||
// Query the total number of threads from the context.
|
||||
dim_t n_threads = bli_cntx_get_num_threads( cntx );
|
||||
|
||||
// Allcoate a global communicator for the root thrinfo_t structures.
|
||||
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
|
||||
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
dim_t omp_id = omp_get_thread_num();
|
||||
thrinfo_t* thread_i = thread[omp_id];
|
||||
dim_t id = omp_get_thread_num();
|
||||
|
||||
cntl_t* cntl_use;
|
||||
thrinfo_t* thread;
|
||||
|
||||
// Create a default control tree for the operation, if needed.
|
||||
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
|
||||
|
||||
// Create the root node of the current thread's thrinfo_t structure.
|
||||
bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread );
|
||||
|
||||
func
|
||||
(
|
||||
alpha,
|
||||
@@ -232,12 +239,19 @@ void bli_l3_thread_decorator
|
||||
c,
|
||||
cntx,
|
||||
cntl_use,
|
||||
thread[omp_id]
|
||||
thread
|
||||
);
|
||||
|
||||
// Free the control tree, if one was created locally.
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i );
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread );
|
||||
|
||||
// Free the current thread's thrinfo_t structure.
|
||||
bli_l3_thrinfo_free( thread );
|
||||
}
|
||||
|
||||
// We shouldn't free the global communicator since it was already freed
|
||||
// by the global communicator's chief thread in bli_l3_thrinfo_free()
|
||||
// (called above).
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -136,7 +136,8 @@ typedef struct thread_data
|
||||
obj_t* c;
|
||||
cntx_t* cntx;
|
||||
cntl_t* cntl;
|
||||
thrinfo_t* thread;
|
||||
dim_t id;
|
||||
thrcomm_t* gl_comm;
|
||||
} thread_data_t;
|
||||
|
||||
// Entry point for additional threads
|
||||
@@ -151,13 +152,18 @@ void* bli_l3_thread_entry( void* data_void )
|
||||
obj_t* c = data->c;
|
||||
cntx_t* cntx = data->cntx;
|
||||
cntl_t* cntl = data->cntl;
|
||||
thrinfo_t* thread_i = data->thread;
|
||||
dim_t id = data->id;
|
||||
thrcomm_t* gl_comm = data->gl_comm;
|
||||
|
||||
cntl_t* cntl_use;
|
||||
thrinfo_t* thread;
|
||||
|
||||
// Create a default control tree for the operation, if needed.
|
||||
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
|
||||
|
||||
// Create the root node of the current thread's thrinfo_t structure.
|
||||
bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread );
|
||||
|
||||
data->func
|
||||
(
|
||||
alpha,
|
||||
@@ -171,14 +177,16 @@ void* bli_l3_thread_entry( void* data_void )
|
||||
);
|
||||
|
||||
// Free the control tree, if one was created locally.
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i );
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread );
|
||||
|
||||
// Free the current thread's thrinfo_t structure.
|
||||
bli_l3_thrinfo_free( thread );
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void bli_l3_thread_decorator
|
||||
(
|
||||
dim_t n_threads,
|
||||
l3int_t func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
@@ -186,50 +194,51 @@ void bli_l3_thread_decorator
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t** thread
|
||||
cntl_t* cntl
|
||||
)
|
||||
{
|
||||
pthread_t* pthreads = bli_malloc_intl( sizeof( pthread_t ) * n_threads );
|
||||
thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads );
|
||||
// Query the total number of threads from the context.
|
||||
dim_t n_threads = bli_cntx_get_num_threads( cntx );
|
||||
|
||||
for ( int i = 1; i < n_threads; i++ )
|
||||
// Allocate an array of pthread objects and auxiliary data structs to pass
|
||||
// to the thread entry functions.
|
||||
pthread_t* pthreads = bli_malloc_intl( sizeof( pthread_t ) * n_threads );
|
||||
thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads );
|
||||
|
||||
// Allocate a global communicator for the root thrinfo_t structures.
|
||||
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
|
||||
|
||||
// NOTE: We must iterate backwards so that the chief thread (thread id 0)
|
||||
// can spawn all other threads before proceeding with its own computation.
|
||||
for ( dim_t id = n_threads - 1; 0 <= id; id-- )
|
||||
{
|
||||
// Set up thread data for additional threads (beyond thread 0).
|
||||
datas[i].func = func;
|
||||
datas[i].alpha = alpha;
|
||||
datas[i].a = a;
|
||||
datas[i].b = b;
|
||||
datas[i].beta = beta;
|
||||
datas[i].c = c;
|
||||
datas[i].cntx = cntx;
|
||||
datas[i].cntl = cntl;
|
||||
datas[i].thread = thread[i];
|
||||
datas[id].func = func;
|
||||
datas[id].alpha = alpha;
|
||||
datas[id].a = a;
|
||||
datas[id].b = b;
|
||||
datas[id].beta = beta;
|
||||
datas[id].c = c;
|
||||
datas[id].cntx = cntx;
|
||||
datas[id].cntl = cntl;
|
||||
datas[id].id = id;
|
||||
datas[id].gl_comm = gl_comm;
|
||||
|
||||
// Spawn additional threads.
|
||||
pthread_create( &pthreads[i], NULL, &bli_l3_thread_entry, &datas[i] );
|
||||
}
|
||||
|
||||
|
||||
// The main thread executes this.
|
||||
{
|
||||
cntl_t* cntl_use;
|
||||
|
||||
// Create a default control tree for the operation, if needed.
|
||||
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
|
||||
|
||||
// Thread 0 simply executes func.
|
||||
func( alpha, a, b, beta, c, cntx, cntl, thread[0] );
|
||||
|
||||
// Free the control tree, if one was created locally.
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread[0] );
|
||||
// Spawn additional threads for ids greater than 1.
|
||||
if ( id != 0 )
|
||||
pthread_create( &pthreads[id], NULL, &bli_l3_thread_entry, &datas[id] );
|
||||
else
|
||||
bli_l3_thread_entry( ( void* )(&datas[0]) );
|
||||
}
|
||||
|
||||
// We shouldn't free the global communicator since it was already freed
|
||||
// by the global communicator's chief thread in bli_l3_thrinfo_free()
|
||||
// (called from the thread entry function).
|
||||
|
||||
// Thread 0 waits for additional threads to finish.
|
||||
for ( int i = 1; i < n_threads; i++)
|
||||
for ( dim_t id = 1; id < n_threads; id++ )
|
||||
{
|
||||
pthread_join( pthreads[i], NULL );
|
||||
pthread_join( pthreads[id], NULL );
|
||||
}
|
||||
|
||||
bli_free_intl( pthreads );
|
||||
|
||||
@@ -73,7 +73,6 @@ void bli_thrcomm_barrier( thrcomm_t* communicator, dim_t t_id )
|
||||
|
||||
void bli_l3_thread_decorator
|
||||
(
|
||||
dim_t n_threads,
|
||||
l3int_t func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
@@ -81,17 +80,25 @@ void bli_l3_thread_decorator
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t** thread
|
||||
cntl_t* cntl
|
||||
)
|
||||
{
|
||||
thrinfo_t* thread_i = thread[0];
|
||||
// For sequential execution, we use only one thread.
|
||||
dim_t n_threads = 1;
|
||||
dim_t id = 0;
|
||||
|
||||
// Allcoate a global communicator for the root thrinfo_t structures.
|
||||
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
|
||||
|
||||
cntl_t* cntl_use;
|
||||
thrinfo_t* thread;
|
||||
|
||||
// Create a default control tree for the operation, if needed.
|
||||
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
|
||||
|
||||
// Create the root node of the thread's thrinfo_t structure.
|
||||
bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread );
|
||||
|
||||
func
|
||||
(
|
||||
alpha,
|
||||
@@ -101,11 +108,18 @@ void bli_l3_thread_decorator
|
||||
c,
|
||||
cntx,
|
||||
cntl_use,
|
||||
thread[0]
|
||||
thread
|
||||
);
|
||||
|
||||
// Free the control tree, if one was created locally.
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i );
|
||||
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread );
|
||||
|
||||
// Free the current thread's thrinfo_t structure.
|
||||
bli_l3_thrinfo_free( thread );
|
||||
|
||||
// We shouldn't free the global communicator since it was already freed
|
||||
// by the global communicator's chief thread in bli_l3_thrinfo_free()
|
||||
// (called above).
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -78,8 +78,8 @@ void bli_thread_get_range_sub
|
||||
dim_t* end
|
||||
)
|
||||
{
|
||||
dim_t n_way = thread->n_way;
|
||||
dim_t work_id = thread->work_id;
|
||||
dim_t n_way = bli_thread_n_way( thread );
|
||||
dim_t work_id = bli_thread_work_id( thread );
|
||||
|
||||
dim_t all_start = 0;
|
||||
dim_t all_end = n;
|
||||
@@ -511,8 +511,8 @@ siz_t bli_thread_get_range_weighted_sub
|
||||
dim_t* j_end_thr
|
||||
)
|
||||
{
|
||||
dim_t n_way = thread->n_way;
|
||||
dim_t my_id = thread->work_id;
|
||||
dim_t n_way = bli_thread_n_way( thread );
|
||||
dim_t my_id = bli_thread_work_id( thread );
|
||||
|
||||
dim_t bf_left = n % bf;
|
||||
|
||||
|
||||
@@ -173,16 +173,14 @@ typedef void (*l3int_t)
|
||||
// Level-3 thread decorator prototype
|
||||
void bli_l3_thread_decorator
|
||||
(
|
||||
dim_t n_threads,
|
||||
l3int_t func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t** thread
|
||||
l3int_t func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl
|
||||
);
|
||||
|
||||
// Miscellaneous prototypes
|
||||
|
||||
@@ -38,11 +38,9 @@ thrinfo_t* bli_thrinfo_create
|
||||
(
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
bool_t free_comms,
|
||||
bool_t free_comm,
|
||||
thrinfo_t* sub_node
|
||||
)
|
||||
{
|
||||
@@ -52,9 +50,8 @@ thrinfo_t* bli_thrinfo_create
|
||||
(
|
||||
thread,
|
||||
ocomm, ocomm_id,
|
||||
icomm, icomm_id,
|
||||
n_way, work_id,
|
||||
free_comms,
|
||||
free_comm,
|
||||
sub_node
|
||||
);
|
||||
|
||||
@@ -66,23 +63,19 @@ void bli_thrinfo_init
|
||||
thrinfo_t* thread,
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
bool_t free_comms,
|
||||
bool_t free_comm,
|
||||
thrinfo_t* sub_node
|
||||
)
|
||||
{
|
||||
thread->ocomm = ocomm;
|
||||
thread->ocomm_id = ocomm_id;
|
||||
thread->icomm = icomm;
|
||||
thread->icomm_id = icomm_id;
|
||||
thread->n_way = n_way;
|
||||
thread->work_id = work_id;
|
||||
thread->free_comms = free_comms;
|
||||
thread->ocomm = ocomm;
|
||||
thread->ocomm_id = ocomm_id;
|
||||
thread->n_way = n_way;
|
||||
thread->work_id = work_id;
|
||||
thread->free_comm = free_comm;
|
||||
|
||||
thread->sub_node = sub_node;
|
||||
thread->sub_node = sub_node;
|
||||
}
|
||||
|
||||
void bli_thrinfo_init_single
|
||||
@@ -94,7 +87,6 @@ void bli_thrinfo_init_single
|
||||
(
|
||||
thread,
|
||||
&BLIS_SINGLE_COMM, 0,
|
||||
&BLIS_SINGLE_COMM, 0,
|
||||
1,
|
||||
0,
|
||||
FALSE,
|
||||
@@ -102,3 +94,178 @@ void bli_thrinfo_init_single
|
||||
);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
#include "assert.h"
|
||||
|
||||
#define BLIS_NUM_STATIC_COMMS 18
|
||||
|
||||
thrinfo_t* bli_thrinfo_create_for_cntl
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl_par,
|
||||
cntl_t* cntl_chl,
|
||||
thrinfo_t* thread_par
|
||||
)
|
||||
{
|
||||
thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ];
|
||||
thrcomm_t** new_comms = NULL;
|
||||
|
||||
thrinfo_t* thread_chl;
|
||||
|
||||
bszid_t bszid_chl = bli_cntl_bszid( cntl_chl );
|
||||
|
||||
dim_t parent_nt_in = bli_thread_num_threads( thread_par );
|
||||
dim_t parent_n_way = bli_thread_n_way( thread_par );
|
||||
dim_t parent_comm_id = bli_thread_ocomm_id( thread_par );
|
||||
dim_t parent_work_id = bli_thread_work_id( thread_par );
|
||||
|
||||
dim_t child_nt_in;
|
||||
dim_t child_comm_id;
|
||||
dim_t child_n_way;
|
||||
dim_t child_work_id;
|
||||
|
||||
// Sanity check: make sure the number of threads in the parent's
|
||||
// communicator is divisible by the number of new sub-groups.
|
||||
assert( parent_nt_in % parent_n_way == 0 );
|
||||
|
||||
// Compute:
|
||||
// - the number of threads inside the new child comm,
|
||||
// - the current thread's id within the new communicator,
|
||||
// - the current thread's work id, given the ways of parallelism
|
||||
// to be obtained within the next loop.
|
||||
child_nt_in = bli_cntx_get_num_threads_in( cntx, cntl_chl );
|
||||
child_n_way = bli_cntx_way_for_bszid( bszid_chl, cntx );
|
||||
child_comm_id = parent_comm_id % child_nt_in;
|
||||
child_work_id = child_comm_id / ( child_nt_in / child_n_way );
|
||||
|
||||
// The parent's chief thread creates a temporary array of thrcomm_t
|
||||
// pointers.
|
||||
if ( bli_thread_am_ochief( thread_par ) )
|
||||
{
|
||||
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
|
||||
new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) );
|
||||
else
|
||||
new_comms = static_comms;
|
||||
}
|
||||
|
||||
// Broadcast the temporary array to all threads in the parent's
|
||||
// communicator.
|
||||
new_comms = bli_thread_obroadcast( thread_par, new_comms );
|
||||
|
||||
// Chiefs in the child communicator allocate the communicator
|
||||
// object and store it in the array element corresponding to the
|
||||
// parent's work id.
|
||||
if ( child_comm_id == 0 )
|
||||
new_comms[ parent_work_id ] = bli_thrcomm_create( child_nt_in );
|
||||
|
||||
bli_thread_obarrier( thread_par );
|
||||
|
||||
// All threads create a new thrinfo_t node using the communicator
|
||||
// that was created by their chief, as identified by parent_work_id.
|
||||
thread_chl = bli_thrinfo_create
|
||||
(
|
||||
new_comms[ parent_work_id ],
|
||||
child_comm_id,
|
||||
child_n_way,
|
||||
child_work_id,
|
||||
TRUE,
|
||||
NULL
|
||||
);
|
||||
|
||||
bli_thread_obarrier( thread_par );
|
||||
|
||||
// The parent's chief thread frees the temporary array of thrcomm_t
|
||||
// pointers.
|
||||
if ( bli_thread_am_ochief( thread_par ) )
|
||||
{
|
||||
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
|
||||
bli_free_intl( new_comms );
|
||||
}
|
||||
|
||||
return thread_chl;
|
||||
}
|
||||
|
||||
void bli_thrinfo_grow
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
)
|
||||
{
|
||||
// If the sub-node of the thrinfo_t object is non-NULL, we don't
|
||||
// need to create it, and will just use the existing sub-node as-is.
|
||||
if ( bli_thrinfo_sub_node( thread ) != NULL ) return;
|
||||
|
||||
// Create a new node (or, if needed, multiple nodes) and return the
|
||||
// pointer to the (eldest) child.
|
||||
thrinfo_t* thread_child = bli_thrinfo_rgrow
|
||||
(
|
||||
cntx,
|
||||
cntl,
|
||||
bli_cntl_sub_node( cntl ),
|
||||
thread
|
||||
);
|
||||
|
||||
// Attach the child thrinfo_t node to its parent structure.
|
||||
bli_thrinfo_set_sub_node( thread_child, thread );
|
||||
}
|
||||
|
||||
thrinfo_t* bli_thrinfo_rgrow
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl_par,
|
||||
cntl_t* cntl_cur,
|
||||
thrinfo_t* thread_par
|
||||
)
|
||||
{
|
||||
thrinfo_t* thread_cur;
|
||||
|
||||
// We must handle two cases: those where the next node in the
|
||||
// control tree is a partitioning node, and those where it is
|
||||
// a non-partitioning (ie: packing) node.
|
||||
if ( bli_cntl_bszid( cntl_cur ) != BLIS_NO_PART )
|
||||
{
|
||||
// Create the child thrinfo_t node corresponding to cntl_cur,
|
||||
// with cntl_par being the parent.
|
||||
thread_cur = bli_thrinfo_create_for_cntl
|
||||
(
|
||||
cntx,
|
||||
cntl_par,
|
||||
cntl_cur,
|
||||
thread_par
|
||||
);
|
||||
}
|
||||
else // if ( bli_cntl_bszid( cntl_cur ) == BLIS_NO_PART )
|
||||
{
|
||||
// Recursively grow the thread structure and return the top-most
|
||||
// thrinfo_t node of that segment.
|
||||
thrinfo_t* thread_seg = bli_thrinfo_rgrow
|
||||
(
|
||||
cntx,
|
||||
cntl_par,
|
||||
bli_cntl_sub_node( cntl_cur ),
|
||||
thread_par
|
||||
);
|
||||
|
||||
// Create a thrinfo_t node corresponding to cntl_cur. Notice that
|
||||
// the free_comm field is set to FALSE, since cntl_cur is a
|
||||
// non-partitioning node. The communicator used here will be
|
||||
// freed when thread_seg, or one of its descendents, is freed.
|
||||
thread_cur = bli_thrinfo_create
|
||||
(
|
||||
bli_thrinfo_ocomm( thread_seg ),
|
||||
bli_thread_ocomm_id( thread_seg ),
|
||||
bli_cntx_get_num_threads_in( cntx, cntl_cur ),
|
||||
bli_thread_ocomm_id( thread_seg ),
|
||||
FALSE,
|
||||
thread_seg
|
||||
);
|
||||
|
||||
// Attach the child thrinfo_t node to its parent structure.
|
||||
bli_thrinfo_set_sub_node( thread_cur, thread_par );
|
||||
}
|
||||
|
||||
return thread_cur;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,13 +45,6 @@ struct thrinfo_s
|
||||
// Our thread id within the ocomm thread communicator.
|
||||
dim_t ocomm_id;
|
||||
|
||||
// The thread communicator for the other threads sharing the same work
|
||||
// at this level.
|
||||
thrcomm_t* icomm;
|
||||
|
||||
// Our thread id within the icomm thread communicator.
|
||||
dim_t icomm_id;
|
||||
|
||||
// The number of distinct threads used to parallelize the loop.
|
||||
dim_t n_way;
|
||||
|
||||
@@ -62,7 +55,7 @@ struct thrinfo_s
|
||||
// this is field is true, but when nodes are created that share the same
|
||||
// communicators as other nodes (such as with packm nodes), this is set
|
||||
// to false.
|
||||
bool_t free_comms;
|
||||
bool_t free_comm;
|
||||
|
||||
struct thrinfo_s* sub_node;
|
||||
};
|
||||
@@ -71,30 +64,40 @@ typedef struct thrinfo_s thrinfo_t;
|
||||
//
|
||||
// thrinfo_t macros
|
||||
// NOTE: The naming of these should be made consistent at some point.
|
||||
// (ie: bli_thrinfo_ vs. bli_thread_)
|
||||
//
|
||||
|
||||
#define bli_thread_num_threads( t ) ( (t)->ocomm->n_threads )
|
||||
// thrinfo_t query (field only)
|
||||
|
||||
#define bli_thread_n_way( t ) ( (t)->n_way )
|
||||
#define bli_thread_work_id( t ) ( (t)->work_id )
|
||||
#define bli_thread_num_threads( t ) ( (t)->ocomm->n_threads )
|
||||
|
||||
#define bli_thread_am_ochief( t ) ( (t)->ocomm_id == 0 )
|
||||
#define bli_thread_am_ichief( t ) ( (t)->icomm_id == 0 )
|
||||
#define bli_thread_n_way( t ) ( (t)->n_way )
|
||||
#define bli_thread_work_id( t ) ( (t)->work_id )
|
||||
#define bli_thread_ocomm_id( t ) ( (t)->ocomm_id )
|
||||
|
||||
#define bli_thrinfo_ocomm( t ) ( (t)->ocomm )
|
||||
#define bli_thrinfo_needs_free_comm( t ) ( (t)->free_comm )
|
||||
|
||||
#define bli_thrinfo_sub_node( t ) ( (t)->sub_node )
|
||||
|
||||
// thrinfo_t query (complex)
|
||||
|
||||
#define bli_thread_am_ochief( t ) ( (t)->ocomm_id == 0 )
|
||||
|
||||
// thrinfo_t modification
|
||||
|
||||
#define bli_thrinfo_set_sub_node( _sub_node, thread ) \
|
||||
{ \
|
||||
(thread)->sub_node = _sub_node; \
|
||||
}
|
||||
|
||||
// other thrinfo_t-related macros
|
||||
|
||||
#define bli_thread_obroadcast( t, p ) bli_thrcomm_bcast( (t)->ocomm, \
|
||||
(t)->ocomm_id, p )
|
||||
#define bli_thread_ibroadcast( t, p ) bli_thrcomm_bcast( (t)->icomm, \
|
||||
(t)->icomm_id, p )
|
||||
#define bli_thread_obarrier( t ) bli_thrcomm_barrier( (t)->ocomm, \
|
||||
(t)->ocomm_id )
|
||||
#define bli_thread_ibarrier( t ) bli_thrcomm_barrier( (t)->icomm, \
|
||||
(t)->icomm_id )
|
||||
|
||||
#define bli_thrinfo_ocomm( t ) ( (t)->ocomm )
|
||||
#define bli_thrinfo_icomm( t ) ( (t)->icomm )
|
||||
#define bli_thrinfo_needs_free_comms( t ) ( (t)->free_comms )
|
||||
|
||||
#define bli_thrinfo_sub_node( t ) ( (t)->sub_node )
|
||||
|
||||
//
|
||||
// Prototypes for level-3 thrinfo functions not specific to any operation.
|
||||
@@ -104,11 +107,9 @@ thrinfo_t* bli_thrinfo_create
|
||||
(
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
bool_t free_comms,
|
||||
bool_t free_comm,
|
||||
thrinfo_t* sub_node
|
||||
);
|
||||
|
||||
@@ -117,11 +118,9 @@ void bli_thrinfo_init
|
||||
thrinfo_t* thread,
|
||||
thrcomm_t* ocomm,
|
||||
dim_t ocomm_id,
|
||||
thrcomm_t* icomm,
|
||||
dim_t icomm_id,
|
||||
dim_t n_way,
|
||||
dim_t work_id,
|
||||
bool_t free_comms,
|
||||
bool_t free_comm,
|
||||
thrinfo_t* sub_node
|
||||
);
|
||||
|
||||
@@ -130,9 +129,29 @@ void bli_thrinfo_init_single
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
void bli_thrinfo_free
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
thrinfo_t* bli_thrinfo_create_for_cntl
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl_par,
|
||||
cntl_t* cntl_chl,
|
||||
thrinfo_t* thread_par
|
||||
);
|
||||
|
||||
void bli_thrinfo_grow
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
thrinfo_t* bli_thrinfo_rgrow
|
||||
(
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl_par,
|
||||
cntl_t* cntl_cur,
|
||||
thrinfo_t* thread_par
|
||||
);
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user