Merge pull request #94 from flame/distcomm

Implemented distributed thrinfo_t management.
This commit is contained in:
Field G. Van Zee
2016-10-04 15:53:46 -05:00
committed by GitHub
31 changed files with 887 additions and 364 deletions

View File

@@ -34,12 +34,11 @@
#include "blis.h"
#if 0
thrinfo_t* bli_packm_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
@@ -51,7 +50,6 @@ thrinfo_t* bli_packm_thrinfo_create
(
thread,
ocomm, ocomm_id,
icomm, icomm_id,
n_way,
work_id,
FALSE,
@@ -60,14 +58,13 @@ thrinfo_t* bli_packm_thrinfo_create
return thread;
}
#endif
void bli_packm_thrinfo_init
(
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
@@ -77,7 +74,6 @@ void bli_packm_thrinfo_init
(
thread,
ocomm, ocomm_id,
icomm, icomm_id,
n_way, work_id,
FALSE,
sub_node
@@ -93,13 +89,13 @@ void bli_packm_thrinfo_init_single
(
thread,
&BLIS_SINGLE_COMM, 0,
&BLIS_SINGLE_COMM, 0,
1,
0,
NULL
);
}
#if 0
void bli_packm_thrinfo_free
(
thrinfo_t* thread
@@ -109,4 +105,4 @@ void bli_packm_thrinfo_free
thread != &BLIS_PACKM_SINGLE_THREADED )
bli_free_intl( thread );
}
#endif

View File

@@ -42,24 +42,22 @@
// thrinfo_t APIs specific to packm.
//
#if 0
thrinfo_t* bli_packm_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
);
#endif
void bli_packm_thrinfo_init
(
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
@@ -70,8 +68,10 @@ void bli_packm_thrinfo_init_single
thrinfo_t* thread
);
#if 0
void bli_packm_thrinfo_free
(
thrinfo_t* thread
);
#endif

View File

@@ -35,12 +35,11 @@
#include "blis.h"
#include "assert.h"
#if 0
thrinfo_t* bli_l3_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
@@ -49,21 +48,19 @@ thrinfo_t* bli_l3_thrinfo_create
return bli_thrinfo_create
(
ocomm, ocomm_id,
icomm, icomm_id,
n_way,
work_id,
TRUE,
sub_node
);
}
#endif
void bli_l3_thrinfo_init
(
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
@@ -73,7 +70,6 @@ void bli_l3_thrinfo_init
(
thread,
ocomm, ocomm_id,
icomm, icomm_id,
n_way,
work_id,
TRUE,
@@ -105,14 +101,12 @@ void bli_l3_thrinfo_free
// is marked as needing them to be freed. The most common example of
// thrinfo_t nodes NOT marked as needing their comms freed are those
// associated with packm thrinfo_t nodes.
if ( bli_thrinfo_needs_free_comms( thread ) )
if ( bli_thrinfo_needs_free_comm( thread ) )
{
// The ochief always frees his communicator, and the ichief free its
// communicator if we are at the leaf node.
if ( bli_thread_am_ochief( thread ) )
bli_thrcomm_free( bli_thrinfo_ocomm( thread ) );
if ( thrinfo_sub_node == NULL && bli_thread_am_ichief( thread ) )
bli_thrcomm_free( bli_thrinfo_icomm( thread ) );
}
// Free all children of the current thrinfo_t.
@@ -124,117 +118,208 @@ void bli_l3_thrinfo_free
// -----------------------------------------------------------------------------
//#define PRINT_THRINFO
thrinfo_t** bli_l3_thrinfo_create_paths
void bli_l3_thrinfo_create_root
(
opid_t l3_op,
side_t side
dim_t id,
thrcomm_t* gl_comm,
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t** thread
)
{
dim_t jc_in, jc_way;
dim_t kc_in, kc_way;
dim_t ic_in, ic_way;
dim_t jr_in, jr_way;
dim_t ir_in, ir_way;
// Query the global communicator for the total number of threads to use.
dim_t n_threads = bli_thrcomm_num_threads( gl_comm );
#ifdef BLIS_ENABLE_MULTITHREADING
jc_in = bli_env_read_nway( "BLIS_JC_NT" );
//kc_way = bli_env_read_nway( "BLIS_KC_NT" );
kc_in = 1;
ic_in = bli_env_read_nway( "BLIS_IC_NT" );
jr_in = bli_env_read_nway( "BLIS_JR_NT" );
ir_in = bli_env_read_nway( "BLIS_IR_NT" );
#else
jc_in = 1;
kc_in = 1;
ic_in = 1;
jr_in = 1;
ir_in = 1;
#endif
// Use the thread id passed in as the global communicator id.
dim_t gl_comm_id = id;
if ( l3_op == BLIS_TRMM )
{
// We reconfigure the parallelism for trmm_r due to a dependency in
// the jc loop. (NOTE: This dependency does not exist for trmm3.)
if ( bli_is_right( side ) )
{
jc_way = 1;
kc_way = kc_in;
ic_way = ic_in;
jr_way = jr_in * jc_in;
ir_way = ir_in;
}
else // if ( bli_is_left( side ) )
{
jc_way = jc_in;
kc_way = kc_in;
ic_way = ic_in;
jr_way = jr_in;
ir_way = ir_in;
}
}
else if ( l3_op == BLIS_TRSM )
{
if ( bli_is_right( side ) )
{
// Use the blocksize id of the current (root) control tree node to
// query the top-most ways of parallelism to obtain.
bszid_t bszid = bli_cntl_bszid( cntl );
dim_t xx_way = bli_cntx_way_for_bszid( bszid, cntx );
jc_way = 1;
kc_way = 1;
ic_way = jc_in * ic_in * jr_in;
jr_way = 1;
ir_way = 1;
}
else // if ( bli_is_left( side ) )
{
jc_way = 1;
kc_way = 1;
ic_way = 1;
jr_way = ic_in * jr_in * ir_in;
ir_way = 1;
}
}
else // all other level-3 operations
// Determine the work id for this thrinfo_t node.
dim_t work_id = gl_comm_id / ( n_threads / xx_way );
// Create the root thrinfo_t node.
*thread = bli_thrinfo_create
(
gl_comm,
gl_comm_id,
xx_way,
work_id,
TRUE,
NULL
);
}
// -----------------------------------------------------------------------------
void bli_l3_thrinfo_print_paths
(
thrinfo_t** threads
)
{
dim_t n_threads = bli_thread_num_threads( threads[0] );
dim_t gl_comm_id;
thrinfo_t* jc_info = threads[0];
thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info );
thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info );
thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info );
thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info );
thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info );
thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info );
dim_t jc_way = bli_thread_n_way( jc_info );
dim_t pc_way = bli_thread_n_way( pc_info );
dim_t pb_way = bli_thread_n_way( pb_info );
dim_t ic_way = bli_thread_n_way( ic_info );
dim_t pa_way = bli_thread_n_way( pa_info );
dim_t jr_way = bli_thread_n_way( jr_info );
dim_t ir_way = bli_thread_n_way( ir_info );
dim_t gl_nt = bli_thread_num_threads( jc_info );
dim_t jc_nt = bli_thread_num_threads( pc_info );
dim_t pc_nt = bli_thread_num_threads( pb_info );
dim_t pb_nt = bli_thread_num_threads( ic_info );
dim_t ic_nt = bli_thread_num_threads( pa_info );
dim_t pa_nt = bli_thread_num_threads( jr_info );
dim_t jr_nt = bli_thread_num_threads( ir_info );
printf( " gl jc kc pb ic pa jr ir\n" );
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
gl_nt, jc_nt, pc_nt, pb_nt, ic_nt, pa_nt, jr_nt, (dim_t)1 );
printf( "\n" );
printf( " jc kc pb ic pa jr ir\n" );
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
jc_way, pc_way, pb_way, ic_way, pa_way, jr_way, ir_way );
printf( "=================================================\n" );
for ( gl_comm_id = 0; gl_comm_id < n_threads; ++gl_comm_id )
{
jc_way = jc_in;
kc_way = kc_in;
ic_way = ic_in;
jr_way = jr_in;
ir_way = ir_in;
jc_info = threads[gl_comm_id];
pc_info = bli_thrinfo_sub_node( jc_info );
pb_info = bli_thrinfo_sub_node( pc_info );
ic_info = bli_thrinfo_sub_node( pb_info );
pa_info = bli_thrinfo_sub_node( ic_info );
jr_info = bli_thrinfo_sub_node( pa_info );
ir_info = bli_thrinfo_sub_node( jr_info );
dim_t gl_comm_id = bli_thread_ocomm_id( jc_info );
dim_t jc_comm_id = bli_thread_ocomm_id( pc_info );
dim_t pc_comm_id = bli_thread_ocomm_id( pb_info );
dim_t pb_comm_id = bli_thread_ocomm_id( ic_info );
dim_t ic_comm_id = bli_thread_ocomm_id( pa_info );
dim_t pa_comm_id = bli_thread_ocomm_id( jr_info );
dim_t jr_comm_id = bli_thread_ocomm_id( ir_info );
dim_t jc_work_id = bli_thread_work_id( jc_info );
dim_t pc_work_id = bli_thread_work_id( pc_info );
dim_t pb_work_id = bli_thread_work_id( pb_info );
dim_t ic_work_id = bli_thread_work_id( ic_info );
dim_t pa_work_id = bli_thread_work_id( pa_info );
dim_t jr_work_id = bli_thread_work_id( jr_info );
dim_t ir_work_id = bli_thread_work_id( ir_info );
printf( " gl jc pb kc pa ic jr \n" );
printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
gl_comm_id, jc_comm_id, pc_comm_id, pb_comm_id, ic_comm_id, pa_comm_id, jr_comm_id );
printf( "work ids: %4ld %4ld %4lu %4lu %4ld %4ld %4ld\n",
jc_work_id, pc_work_id, pb_work_id, ic_work_id, pa_work_id, jr_work_id, ir_work_id );
printf( "---------------------------------------\n" );
}
}
dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way;
assert( global_num_threads != 0 );
// -----------------------------------------------------------------------------
dim_t jc_nt = kc_way * ic_way * jr_way * ir_way;
dim_t kc_nt = ic_way * jr_way * ir_way;
#if 0
thrinfo_t** bli_l3_thrinfo_create_roots
(
cntx_t* cntx,
cntl_t* cntl
)
{
// Query the context for the total number of threads to use.
dim_t n_threads = bli_cntx_get_num_threads( cntx );
// Create a global thread communicator for all the threads.
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
// Allocate an array of thrinfo_t pointers, one for each thread.
thrinfo_t** paths = bli_malloc_intl( n_threads * sizeof( thrinfo_t* ) );
// Use the blocksize id of the current (root) control tree node to
// query the top-most ways of parallelism to obtain.
bszid_t bszid = bli_cntl_bszid( cntl );
dim_t xx_way = bli_cntx_way_for_bszid( bszid, cntx );
dim_t gl_comm_id;
// Create one thrinfo_t node for each thread in the (global) communicator.
for ( gl_comm_id = 0; gl_comm_id < n_threads; ++gl_comm_id )
{
dim_t work_id = gl_comm_id / ( n_threads / xx_way );
paths[ gl_comm_id ] = bli_thrinfo_create
(
gl_comm,
gl_comm_id,
xx_way,
work_id,
TRUE,
NULL
);
}
return paths;
}
//#define PRINT_THRINFO
thrinfo_t** bli_l3_thrinfo_create_full_paths
(
cntx_t* cntx
)
{
dim_t jc_way = bli_cntx_jc_way( cntx );
dim_t pc_way = bli_cntx_pc_way( cntx );
dim_t ic_way = bli_cntx_ic_way( cntx );
dim_t jr_way = bli_cntx_jr_way( cntx );
dim_t ir_way = bli_cntx_ir_way( cntx );
dim_t gl_nt = jc_way * pc_way * ic_way * jr_way * ir_way;
dim_t jc_nt = pc_way * ic_way * jr_way * ir_way;
dim_t pc_nt = ic_way * jr_way * ir_way;
dim_t ic_nt = jr_way * ir_way;
dim_t jr_nt = ir_way;
dim_t ir_nt = 1;
assert( gl_nt != 0 );
#ifdef PRINT_THRINFO
printf( " jc kc ic jr ir\n" );
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu\n",
jc_way, kc_way, ic_way, jr_way, ir_way );
printf( " gl jc kc pb ic pa jr ir\n" );
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
gl_nt, jc_nt, pc_nt, pc_nt, ic_nt, ic_nt, jr_nt, ir_nt );
printf( "\n" );
printf( " gl jc kc ic jr ir\n" );
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu\n",
global_num_threads, jc_nt, kc_nt, ic_nt, jr_nt, ir_nt );
printf( "=======================================\n" );
printf( " jc kc pb ic pa jr ir\n" );
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
jc_way, pc_way, (dim_t)0, ic_way, (dim_t)0, jr_way, ir_way );
printf( "=================================================\n" );
#endif
thrinfo_t** paths = bli_malloc_intl( global_num_threads * sizeof( thrinfo_t* ) );
thrinfo_t** paths = bli_malloc_intl( gl_nt * sizeof( thrinfo_t* ) );
thrcomm_t* global_comm = bli_thrcomm_create( global_num_threads );
thrcomm_t* gl_comm = bli_thrcomm_create( gl_nt );
for( int a = 0; a < jc_way; a++ )
{
thrcomm_t* jc_comm = bli_thrcomm_create( jc_nt );
for( int b = 0; b < kc_way; b++ )
for( int b = 0; b < pc_way; b++ )
{
thrcomm_t* kc_comm = bli_thrcomm_create( kc_nt );
thrcomm_t* pc_comm = bli_thrcomm_create( pc_nt );
for( int c = 0; c < ic_way; c++ )
{
@@ -246,73 +331,83 @@ printf( "=======================================\n" );
for( int e = 0; e < ir_way; e++ )
{
thrcomm_t* ir_comm = bli_thrcomm_create( ir_nt );
dim_t ir_comm_id = 0;
dim_t jr_comm_id = e*ir_nt + ir_comm_id;
dim_t ic_comm_id = d*jr_nt + jr_comm_id;
dim_t kc_comm_id = c*ic_nt + ic_comm_id;
dim_t jc_comm_id = b*kc_nt + kc_comm_id;
dim_t global_comm_id = a*jc_nt + jc_comm_id;
//thrcomm_t* ir_comm = bli_thrcomm_create( ir_nt );
dim_t ir_comm_id = 0;
dim_t jr_comm_id = e*ir_nt + ir_comm_id;
dim_t ic_comm_id = d*jr_nt + jr_comm_id;
dim_t pc_comm_id = c*ic_nt + ic_comm_id;
dim_t jc_comm_id = b*pc_nt + pc_comm_id;
dim_t gl_comm_id = a*jc_nt + jc_comm_id;
// macro-kernel loops
thrinfo_t* ir_info
=
bli_l3_thrinfo_create( jr_comm, jr_comm_id,
ir_comm, ir_comm_id,
ir_way, e,
NULL );
thrinfo_t* jr_info
=
bli_l3_thrinfo_create( ic_comm, ic_comm_id,
jr_comm, jr_comm_id,
jr_way, d,
ir_info );
// packa
thrinfo_t* pack_ic_in
thrinfo_t* pa_info
=
bli_packm_thrinfo_create( ic_comm, ic_comm_id,
jr_comm, jr_comm_id,
ic_nt, ic_comm_id,
jr_info );
// blk_var1
thrinfo_t* ic_info
=
bli_l3_thrinfo_create( kc_comm, kc_comm_id,
ic_comm, ic_comm_id,
bli_l3_thrinfo_create( pc_comm, pc_comm_id,
ic_way, c,
pack_ic_in );
pa_info );
// packb
thrinfo_t* pack_kc_in
thrinfo_t* pb_info
=
bli_packm_thrinfo_create( kc_comm, kc_comm_id,
ic_comm, ic_comm_id,
kc_nt, kc_comm_id,
bli_packm_thrinfo_create( pc_comm, pc_comm_id,
pc_nt, pc_comm_id,
ic_info );
// blk_var3
thrinfo_t* kc_info
thrinfo_t* pc_info
=
bli_l3_thrinfo_create( jc_comm, jc_comm_id,
kc_comm, kc_comm_id,
kc_way, b,
pack_kc_in );
pc_way, b,
pb_info );
// blk_var2
thrinfo_t* jc_info
=
bli_l3_thrinfo_create( global_comm, global_comm_id,
jc_comm, jc_comm_id,
bli_l3_thrinfo_create( gl_comm, gl_comm_id,
jc_way, a,
kc_info );
pc_info );
paths[global_comm_id] = jc_info;
paths[gl_comm_id] = jc_info;
#ifdef PRINT_THRINFO
printf( " gl jc kc ic jr ir\n" );
printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu\n",
global_comm_id, jc_comm_id, kc_comm_id, ic_comm_id, jr_comm_id, ir_comm_id );
//printf( " a b c d e\n" );
printf( "work ids: %4ld %4ld %4ld %4ld %4ld\n", (long int)a, (long int)b, (long int)c, (long int)d, (long int)e );
printf( "---------------------------------------\n" );
{
dim_t gl_comm_id = bli_thread_ocomm_id( jc_info );
dim_t jc_comm_id = bli_thread_ocomm_id( pc_info );
dim_t pc_comm_id = bli_thread_ocomm_id( pb_info );
dim_t pb_comm_id = bli_thread_ocomm_id( ic_info );
dim_t ic_comm_id = bli_thread_ocomm_id( pa_info );
dim_t pa_comm_id = bli_thread_ocomm_id( jr_info );
dim_t jr_comm_id = bli_thread_ocomm_id( ir_info );
dim_t jc_work_id = bli_thread_work_id( jc_info );
dim_t pc_work_id = bli_thread_work_id( pc_info );
dim_t pb_work_id = bli_thread_work_id( pb_info );
dim_t ic_work_id = bli_thread_work_id( ic_info );
dim_t pa_work_id = bli_thread_work_id( pa_info );
dim_t jr_work_id = bli_thread_work_id( jr_info );
dim_t ir_work_id = bli_thread_work_id( ir_info );
printf( " gl jc pb kc pa ic jr \n" );
printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
gl_comm_id, jc_comm_id, pc_comm_id, pb_comm_id, ic_comm_id, pa_comm_id, jr_comm_id );
printf( "work ids: %4ld %4ld %4lu %4lu %4ld %4ld %4ld\n",
jc_work_id, pc_work_id, pb_work_id, ic_work_id, pa_work_id, jr_work_id, ir_work_id );
printf( "-------------------------------------------------\n" );
}
#endif
}
@@ -330,15 +425,16 @@ exit(1);
void bli_l3_thrinfo_free_paths
(
thrinfo_t** threads,
dim_t num
thrinfo_t** threads
)
{
dim_t n_threads = bli_thread_num_threads( threads[0] );
dim_t i;
for ( i = 0; i < num; ++i )
for ( i = 0; i < n_threads; ++i )
bli_l3_thrinfo_free( threads[i] );
bli_free_intl( threads );
}
#endif

View File

@@ -61,24 +61,22 @@
// thrinfo_t APIs specific to level-3 operations.
//
#if 0
thrinfo_t* bli_l3_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
);
#endif
void bli_l3_thrinfo_init
(
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
@@ -96,15 +94,37 @@ void bli_l3_thrinfo_free
// -----------------------------------------------------------------------------
thrinfo_t** bli_l3_thrinfo_create_paths
void bli_l3_thrinfo_create_root
(
opid_t l3_op,
side_t side
dim_t id,
thrcomm_t* gl_comm,
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t** thread
);
void bli_l3_thrinfo_print_paths
(
thrinfo_t** threads
);
// -----------------------------------------------------------------------------
#if 0
thrinfo_t** bli_l3_thrinfo_create_roots
(
cntx_t* cntx,
cntl_t* cntl
);
thrinfo_t** bli_l3_thrinfo_create_full_paths
(
cntx_t* cntx
);
void bli_l3_thrinfo_free_paths
(
thrinfo_t** threads,
dim_t num
thrinfo_t** threads
);
#endif

View File

@@ -84,10 +84,10 @@ void bli_gemm_blk_var3
c,
cntx,
bli_cntl_sub_node( cntl ),
bli_thrinfo_sub_node( thread)
bli_thrinfo_sub_node( thread )
);
bli_thread_ibarrier( thread );
bli_thread_obarrier( bli_thrinfo_sub_node( thread ) );
// This variant executes multiple rank-k updates. Therefore, if the
// internal beta scalar on matrix C is non-zero, we must use it

View File

@@ -46,14 +46,21 @@ cntl_t* bli_gemm_cntl_create
if ( family == BLIS_HERK ) macro_kernel_p = bli_herk_x_ker_var2;
else if ( family == BLIS_TRMM ) macro_kernel_p = bli_trmm_xx_ker_var2;
// Create a node for the macro-kernel.
cntl_t* gemm_cntl_bp_ke = bli_gemm_cntl_obj_create
// Create two nodes for the macro-kernel.
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_obj_create
(
BLIS_NR, // bszid not used by macro-kernel.
macro_kernel_p,
BLIS_MR, // needed for bli_thrinfo_rgrow()
NULL, // variant function pointer not used
NULL // no sub-node; this is the leaf of the tree.
);
cntl_t* gemm_cntl_bp_bu = bli_gemm_cntl_obj_create
(
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
macro_kernel_p,
gemm_cntl_bu_ke
);
// Create a node for packing matrix A.
cntl_t* gemm_cntl_packa = bli_packm_cntl_obj_create
(
@@ -66,7 +73,7 @@ cntl_t* bli_gemm_cntl_create
FALSE, // reverse iteration if lower?
BLIS_PACKED_ROW_PANELS,
BLIS_BUFFER_FOR_A_BLOCK,
gemm_cntl_bp_ke
gemm_cntl_bp_bu
);
// Create a node for partitioning the m dimension by MC.

View File

@@ -85,13 +85,19 @@ void bli_gemm_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_GEMM, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_GEMM, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_GEMM, BLIS_LEFT, cntx );
// Invoke the internal back-end.
// Create the first node in the thrinfo_t tree for each thread.
//thrinfo_t** infos = bli_l3_thrinfo_create_full_paths( cntx );
//bli_l3_thrinfo_print_paths( infos );
//exit(1);
//cntl = bli_gemm_cntl_create( BLIS_GEMM );
//thrinfo_t** infos = bli_l3_thrinfo_create_roots( cntx, cntl );
// Invoke the internal back-end via the thread handler.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -99,10 +105,12 @@ void bli_gemm_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
//bli_l3_thrinfo_print_paths( infos );
//exit(1);
bli_l3_thrinfo_free_paths( infos, n_threads );
// Free the thrinfo_t structures.
//bli_l3_thrinfo_free_paths( infos );
}

View File

@@ -50,7 +50,6 @@ void bli_gemm_int
obj_t b_local;
obj_t c_local;
gemm_voft f;
ind_t im;
// Check parameters.
if ( bli_error_checking_is_enabled() )
@@ -102,17 +101,22 @@ void bli_gemm_int
bli_obj_scalar_apply_scalar( beta, &c_local );
}
// Create the next node in the thrinfo_t structure.
bli_thrinfo_grow( cntx, cntl, thread );
// Extract the function pointer from the current control tree node.
f = bli_cntl_var_func( cntl );
// Somewhat hackish support for 3m3, 3m2, and 4m1b method implementations.
im = bli_cntx_get_ind_method( cntx );
if ( im != BLIS_NAT )
{
if ( im == BLIS_3M3 && f == bli_gemm_packa ) f = bli_gemm3m3_packa;
else if ( im == BLIS_3M2 && f == bli_gemm_ker_var2 ) f = bli_gemm3m2_ker_var2;
else if ( im == BLIS_4M1B && f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2;
ind_t im = bli_cntx_get_ind_method( cntx );
if ( im != BLIS_NAT )
{
if ( im == BLIS_3M3 && f == bli_gemm_packa ) f = bli_gemm3m3_packa;
else if ( im == BLIS_3M2 && f == bli_gemm_ker_var2 ) f = bli_gemm3m2_ker_var2;
else if ( im == BLIS_4M1B && f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2;
}
}
// Invoke the variant.

View File

@@ -92,13 +92,12 @@ void bli_hemm_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_GEMM, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HEMM, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_HEMM, BLIS_LEFT, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -106,10 +105,7 @@ void bli_hemm_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -110,14 +110,14 @@ void bli_her2k_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_HERK, cntx );
// Invoke herk twice, using beta only the first time.
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HER2K, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_HER2K, BLIS_LEFT, cntx );
// Invoke the internal back-end.
// Invoke herk twice, using beta only the first time.
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -125,13 +125,11 @@ void bli_her2k_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
&alpha_conj,
&b_local,
@@ -139,12 +137,9 @@ void bli_her2k_front
&BLIS_ONE,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
// The Hermitian rank-2k product was computed as A*B'+B*A', even for
// the diagonal elements. Mathematically, the imaginary components of
// diagonal elements of a Hermitian rank-2k product should always be

View File

@@ -90,13 +90,12 @@ void bli_herk_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_HERK, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HERK, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_HERK, BLIS_LEFT, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -104,12 +103,9 @@ void bli_herk_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
// The Hermitian rank-k product was computed as A*A', even for the
// diagonal elements. Mathematically, the imaginary components of
// diagonal elements of a Hermitian rank-k product should always be

View File

@@ -91,13 +91,12 @@ void bli_symm_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_GEMM, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYMM, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_SYMM, BLIS_LEFT, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -105,10 +104,7 @@ void bli_symm_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -91,14 +91,14 @@ void bli_syr2k_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_HERK, cntx );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_SYR2K, BLIS_LEFT, cntx );
// Invoke herk twice, using beta only the first time.
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYR2K, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -106,13 +106,11 @@ void bli_syr2k_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&b_local,
@@ -120,10 +118,7 @@ void bli_syr2k_front
&BLIS_ONE,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -84,13 +84,12 @@ void bli_syrk_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_HERK, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYRK, BLIS_LEFT );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_SYRK, BLIS_LEFT, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -98,10 +97,7 @@ void bli_syrk_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -134,13 +134,12 @@ void bli_trmm_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_TRMM, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRMM, side );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_TRMM, side, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -148,10 +147,7 @@ void bli_trmm_front
&BLIS_ZERO,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -133,13 +133,12 @@ void bli_trmm3_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_TRMM, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRMM3, side );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_TRMM3, side, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_gemm_int,
alpha,
&a_local,
@@ -147,10 +146,7 @@ void bli_trmm3_front
beta,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -87,7 +87,8 @@ void bli_trsm_blk_var3
bli_thrinfo_sub_node( thread )
);
bli_thread_ibarrier( thread );
//bli_thread_ibarrier( thread );
bli_thread_obarrier( bli_thrinfo_sub_node( thread ) );
// This variant executes multiple rank-k updates. Therefore, if the
// internal alpha scalars on A/B and C are non-zero, we must ensure

View File

@@ -50,14 +50,21 @@ cntl_t* bli_trsm_l_cntl_create
{
void* macro_kernel_p = bli_trsm_xx_ker_var2;
// Create a node for the macro-kernel.
cntl_t* trsm_cntl_bp_ke = bli_trsm_cntl_obj_create
// Create two nodes for the macro-kernel.
cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_obj_create
(
BLIS_NR, // bszid not used by macro-kernel.
macro_kernel_p,
BLIS_MR, // needed for bli_thrinfo_rgrow()
NULL, // variant function pointer not used
NULL // no sub-node; this is the leaf of the tree.
);
cntl_t* trsm_cntl_bp_bu = bli_trsm_cntl_obj_create
(
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
macro_kernel_p,
trsm_cntl_bu_ke
);
// Create a node for packing matrix A.
cntl_t* trsm_cntl_packa = bli_packm_cntl_obj_create
(
@@ -70,7 +77,7 @@ cntl_t* bli_trsm_l_cntl_create
FALSE, // reverse iteration if lower?
BLIS_PACKED_ROW_PANELS,
BLIS_BUFFER_FOR_A_BLOCK,
trsm_cntl_bp_ke
trsm_cntl_bp_bu
);
// Create a node for partitioning the m dimension by MC.
@@ -122,14 +129,21 @@ cntl_t* bli_trsm_r_cntl_create
{
void* macro_kernel_p = bli_trsm_xx_ker_var2;
// Create a node for the macro-kernel.
cntl_t* trsm_cntl_bp_ke = bli_trsm_cntl_obj_create
// Create two nodes for the macro-kernel.
cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_obj_create
(
BLIS_NR, // bszid not used by macro-kernel.
macro_kernel_p,
BLIS_MR, // needed for bli_thrinfo_rgrow()
NULL, // variant function pointer not used
NULL // no sub-node; this is the leaf of the tree.
);
cntl_t* trsm_cntl_bp_bu = bli_trsm_cntl_obj_create
(
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
macro_kernel_p,
trsm_cntl_bu_ke
);
// Create a node for packing matrix A.
cntl_t* trsm_cntl_packa = bli_packm_cntl_obj_create
(
@@ -142,7 +156,7 @@ cntl_t* bli_trsm_r_cntl_create
FALSE, // reverse iteration if lower?
BLIS_PACKED_ROW_PANELS,
BLIS_BUFFER_FOR_A_BLOCK,
trsm_cntl_bp_ke
trsm_cntl_bp_bu
);
// Create a node for partitioning the m dimension by MC.

View File

@@ -119,13 +119,12 @@ void bli_trsm_front
// Set the operation family id in the context.
bli_cntx_set_family( BLIS_TRSM, cntx );
thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRSM, side );
dim_t n_threads = bli_thread_num_threads( infos[0] );
// Record the threading for each level within the context.
bli_cntx_set_thrloop_from_env( BLIS_TRSM, side, cntx );
// Invoke the internal back-end.
bli_l3_thread_decorator
(
n_threads,
bli_trsm_int,
alpha,
&a_local,
@@ -133,10 +132,7 @@ void bli_trsm_front
alpha,
&c_local,
cntx,
cntl,
infos
cntl
);
bli_l3_thrinfo_free_paths( infos, n_threads );
}

View File

@@ -117,6 +117,9 @@ void bli_trsm_int
// FGVZ->TMS: Is this barrier still needed?
bli_thread_obarrier( thread );
// Create the next node in the thrinfo_t structure.
bli_thrinfo_grow( cntx, cntl, thread );
// Extract the function pointer from the current control tree node.
f = bli_cntl_var_func( cntl );

View File

@@ -341,6 +341,37 @@ pack_t bli_cntx_get_pack_schema_b( cntx_t* cntx )
}
#endif
dim_t bli_cntx_get_num_threads( cntx_t* cntx )
{
return bli_cntx_jc_way( cntx ) *
bli_cntx_pc_way( cntx ) *
bli_cntx_ic_way( cntx ) *
bli_cntx_jr_way( cntx ) *
bli_cntx_ir_way( cntx );
}
dim_t bli_cntx_get_num_threads_in( cntx_t* cntx, cntl_t* cntl )
{
dim_t n_threads_in = 1;
for ( ; cntl != NULL; cntl = bli_cntl_sub_node( cntl ) )
{
bszid_t bszid = bli_cntl_bszid( cntl );
dim_t cur_way;
// We assume bszid is in {KR,MR,NR,MC,KC,NR} if it is not
// BLIS_NO_PART.
if ( bszid != BLIS_NO_PART )
cur_way = bli_cntx_way_for_bszid( bszid, cntx );
else
cur_way = 1;
n_threads_in *= cur_way;
}
return n_threads_in;
}
// -----------------------------------------------------------------------------
#if 1
@@ -663,6 +694,96 @@ void bli_cntx_set_pack_schema_c( pack_t schema_c,
bli_cntx_set_schema_c( schema_c, cntx );
}
void bli_cntx_set_thrloop_from_env( opid_t l3_op, side_t side, cntx_t* cntx )
{
dim_t jc, pc, ic, jr, ir;
#ifdef BLIS_ENABLE_MULTITHREADING
jc = bli_env_read_nway( "BLIS_JC_NT" );
//pc = bli_env_read_nway( "BLIS_KC_NT" );
pc = 1;
ic = bli_env_read_nway( "BLIS_IC_NT" );
jr = bli_env_read_nway( "BLIS_JR_NT" );
ir = bli_env_read_nway( "BLIS_IR_NT" );
#else
jc = 1;
pc = 1;
ic = 1;
jr = 1;
ir = 1;
#endif
if ( l3_op == BLIS_TRMM )
{
// We reconfigure the paralelism from trmm_r due to a dependency in
// the jc loop. (NOTE: This dependency does not exist for trmm3 )
if ( bli_is_right( side ) )
{
bli_cntx_set_thrloop
(
1,
pc,
ic,
jr * jc,
ir,
cntx
);
}
else // if ( bli_is_left( side ) )
{
bli_cntx_set_thrloop
(
jc,
pc,
ic,
jr,
ir,
cntx
);
}
}
else if ( l3_op == BLIS_TRSM )
{
if ( bli_is_right( side ) )
{
bli_cntx_set_thrloop
(
1,
1,
jc * ic * jr,
1,
1,
cntx
);
}
else // if ( bli_is_left( side ) )
{
bli_cntx_set_thrloop
(
1,
1,
1,
ic * jr * ir,
1,
cntx
);
}
}
else // if ( l3_op == BLIS_TRSM )
{
bli_cntx_set_thrloop
(
jc,
pc,
ic,
jr,
ir,
cntx
);
}
}
// -----------------------------------------------------------------------------
bool_t bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt,

View File

@@ -59,6 +59,8 @@ typedef struct cntx_s
pack_t schema_b;
pack_t schema_c;
dim_t* thrloop;
membrk_t* membrk;
} cntx_t;
*/
@@ -127,6 +129,36 @@ typedef struct cntx_s
\
( (cntx)->membrk )
#define bli_cntx_thrloop( cntx ) \
\
( (cntx)->thrloop )
#if 1
#define bli_cntx_jc_way( cntx ) \
\
( (cntx)->thrloop[ BLIS_NC ] )
#define bli_cntx_pc_way( cntx ) \
\
( (cntx)->thrloop[ BLIS_KC ] )
#define bli_cntx_ic_way( cntx ) \
\
( (cntx)->thrloop[ BLIS_MC ] )
#define bli_cntx_jr_way( cntx ) \
\
( (cntx)->thrloop[ BLIS_NR ] )
#define bli_cntx_ir_way( cntx ) \
\
( (cntx)->thrloop[ BLIS_MR ] )
#endif
#define bli_cntx_way_for_bszid( bszid, cntx ) \
\
( (cntx)->thrloop[ bszid ] )
// cntx_t modification (fields only)
#define bli_cntx_set_blkszs_buf( _blkszs, cntx_p ) \
@@ -199,6 +231,16 @@ typedef struct cntx_s
(cntx_p)->membrk = _membrk; \
}
#define bli_cntx_set_thrloop( jc_, pc_, ic_, jr_, ir_, cntx_p ) \
{ \
(cntx_p)->thrloop[ BLIS_NC ] = jc_; \
(cntx_p)->thrloop[ BLIS_KC ] = pc_; \
(cntx_p)->thrloop[ BLIS_MC ] = ic_; \
(cntx_p)->thrloop[ BLIS_NR ] = jr_; \
(cntx_p)->thrloop[ BLIS_MR ] = ir_; \
(cntx_p)->thrloop[ BLIS_KR ] = 1; \
}
// cntx_t query (complex)
#define bli_cntx_get_blksz_def_dt( dt, bs_id, cntx ) \
@@ -356,6 +398,8 @@ func_t* bli_cntx_get_packm_ukr( cntx_t* cntx );
//pack_t bli_cntx_get_pack_schema_a( cntx_t* cntx );
//pack_t bli_cntx_get_pack_schema_b( cntx_t* cntx );
//pack_t bli_cntx_get_pack_schema_c( cntx_t* cntx );
dim_t bli_cntx_get_num_threads( cntx_t* cntx );
dim_t bli_cntx_get_num_threads_in( cntx_t* cntx, cntl_t* cntl );
// set functions
@@ -390,6 +434,9 @@ void bli_cntx_set_pack_schema_b( pack_t schema_b,
cntx_t* cntx );
void bli_cntx_set_pack_schema_c( pack_t schema_c,
cntx_t* cntx );
void bli_cntx_set_thrloop_from_env( opid_t l3_op,
side_t side,
cntx_t* cntx );
// other query functions

View File

@@ -638,6 +638,21 @@ typedef enum
#define BLIS_NUM_UKR_IMPL_TYPES 4
#if 0
typedef enum
{
BLIS_JC_IDX = 0,
BLIS_PC_IDX,
BLIS_IC_IDX,
BLIS_JR_IDX,
BLIS_IR_IDX,
BLIS_PR_IDX,
} thridx_t;
#endif
#define BLIS_NUM_LOOPS 6
// -- Operation ID type --
typedef enum
@@ -949,6 +964,8 @@ typedef struct cntx_s
pack_t schema_b;
pack_t schema_c;
dim_t thrloop[ BLIS_NUM_LOOPS ];
membrk_t* membrk;
} cntx_t;

View File

@@ -41,6 +41,12 @@
#include "bli_thrcomm_openmp.h"
#include "bli_thrcomm_pthreads.h"
// thrcomm_t query (field only)
#define bli_thrcomm_num_threads( comm ) ( (comm)->n_threads )
// Thread communicator prototypes.
thrcomm_t* bli_thrcomm_create( dim_t n_threads );
void bli_thrcomm_free( thrcomm_t* communicator );

View File

@@ -201,7 +201,6 @@ void bli_thrcomm_tree_barrier( barrier_t* barack )
void bli_l3_thread_decorator
(
dim_t n_threads,
l3int_t func,
obj_t* alpha,
obj_t* a,
@@ -209,20 +208,28 @@ void bli_l3_thread_decorator
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t** thread
cntl_t* cntl
)
{
// Query the total number of threads from the context.
dim_t n_threads = bli_cntx_get_num_threads( cntx );
// Allcoate a global communicator for the root thrinfo_t structures.
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
_Pragma( "omp parallel num_threads(n_threads)" )
{
dim_t omp_id = omp_get_thread_num();
thrinfo_t* thread_i = thread[omp_id];
dim_t id = omp_get_thread_num();
cntl_t* cntl_use;
thrinfo_t* thread;
// Create a default control tree for the operation, if needed.
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
// Create the root node of the current thread's thrinfo_t structure.
bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread );
func
(
alpha,
@@ -232,12 +239,19 @@ void bli_l3_thread_decorator
c,
cntx,
cntl_use,
thread[omp_id]
thread
);
// Free the control tree, if one was created locally.
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i );
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread );
// Free the current thread's thrinfo_t structure.
bli_l3_thrinfo_free( thread );
}
// We shouldn't free the global communicator since it was already freed
// by the global communicator's chief thread in bli_l3_thrinfo_free()
// (called above).
}
#endif

View File

@@ -136,7 +136,8 @@ typedef struct thread_data
obj_t* c;
cntx_t* cntx;
cntl_t* cntl;
thrinfo_t* thread;
dim_t id;
thrcomm_t* gl_comm;
} thread_data_t;
// Entry point for additional threads
@@ -151,13 +152,18 @@ void* bli_l3_thread_entry( void* data_void )
obj_t* c = data->c;
cntx_t* cntx = data->cntx;
cntl_t* cntl = data->cntl;
thrinfo_t* thread_i = data->thread;
dim_t id = data->id;
thrcomm_t* gl_comm = data->gl_comm;
cntl_t* cntl_use;
thrinfo_t* thread;
// Create a default control tree for the operation, if needed.
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
// Create the root node of the current thread's thrinfo_t structure.
bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread );
data->func
(
alpha,
@@ -171,14 +177,16 @@ void* bli_l3_thread_entry( void* data_void )
);
// Free the control tree, if one was created locally.
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i );
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread );
// Free the current thread's thrinfo_t structure.
bli_l3_thrinfo_free( thread );
return NULL;
}
void bli_l3_thread_decorator
(
dim_t n_threads,
l3int_t func,
obj_t* alpha,
obj_t* a,
@@ -186,50 +194,51 @@ void bli_l3_thread_decorator
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t** thread
cntl_t* cntl
)
{
pthread_t* pthreads = bli_malloc_intl( sizeof( pthread_t ) * n_threads );
thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads );
// Query the total number of threads from the context.
dim_t n_threads = bli_cntx_get_num_threads( cntx );
for ( int i = 1; i < n_threads; i++ )
// Allocate an array of pthread objects and auxiliary data structs to pass
// to the thread entry functions.
pthread_t* pthreads = bli_malloc_intl( sizeof( pthread_t ) * n_threads );
thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads );
// Allocate a global communicator for the root thrinfo_t structures.
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
// NOTE: We must iterate backwards so that the chief thread (thread id 0)
// can spawn all other threads before proceeding with its own computation.
for ( dim_t id = n_threads - 1; 0 <= id; id-- )
{
// Set up thread data for additional threads (beyond thread 0).
datas[i].func = func;
datas[i].alpha = alpha;
datas[i].a = a;
datas[i].b = b;
datas[i].beta = beta;
datas[i].c = c;
datas[i].cntx = cntx;
datas[i].cntl = cntl;
datas[i].thread = thread[i];
datas[id].func = func;
datas[id].alpha = alpha;
datas[id].a = a;
datas[id].b = b;
datas[id].beta = beta;
datas[id].c = c;
datas[id].cntx = cntx;
datas[id].cntl = cntl;
datas[id].id = id;
datas[id].gl_comm = gl_comm;
// Spawn additional threads.
pthread_create( &pthreads[i], NULL, &bli_l3_thread_entry, &datas[i] );
}
// The main thread executes this.
{
cntl_t* cntl_use;
// Create a default control tree for the operation, if needed.
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
// Thread 0 simply executes func.
func( alpha, a, b, beta, c, cntx, cntl, thread[0] );
// Free the control tree, if one was created locally.
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread[0] );
// Spawn additional threads for ids greater than 1.
if ( id != 0 )
pthread_create( &pthreads[id], NULL, &bli_l3_thread_entry, &datas[id] );
else
bli_l3_thread_entry( ( void* )(&datas[0]) );
}
// We shouldn't free the global communicator since it was already freed
// by the global communicator's chief thread in bli_l3_thrinfo_free()
// (called from the thread entry function).
// Thread 0 waits for additional threads to finish.
for ( int i = 1; i < n_threads; i++)
for ( dim_t id = 1; id < n_threads; id++ )
{
pthread_join( pthreads[i], NULL );
pthread_join( pthreads[id], NULL );
}
bli_free_intl( pthreads );

View File

@@ -73,7 +73,6 @@ void bli_thrcomm_barrier( thrcomm_t* communicator, dim_t t_id )
void bli_l3_thread_decorator
(
dim_t n_threads,
l3int_t func,
obj_t* alpha,
obj_t* a,
@@ -81,17 +80,25 @@ void bli_l3_thread_decorator
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t** thread
cntl_t* cntl
)
{
thrinfo_t* thread_i = thread[0];
// For sequential execution, we use only one thread.
dim_t n_threads = 1;
dim_t id = 0;
// Allcoate a global communicator for the root thrinfo_t structures.
thrcomm_t* gl_comm = bli_thrcomm_create( n_threads );
cntl_t* cntl_use;
thrinfo_t* thread;
// Create a default control tree for the operation, if needed.
bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use );
// Create the root node of the thread's thrinfo_t structure.
bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread );
func
(
alpha,
@@ -101,11 +108,18 @@ void bli_l3_thread_decorator
c,
cntx,
cntl_use,
thread[0]
thread
);
// Free the control tree, if one was created locally.
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i );
bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread );
// Free the current thread's thrinfo_t structure.
bli_l3_thrinfo_free( thread );
// We shouldn't free the global communicator since it was already freed
// by the global communicator's chief thread in bli_l3_thrinfo_free()
// (called above).
}

View File

@@ -78,8 +78,8 @@ void bli_thread_get_range_sub
dim_t* end
)
{
dim_t n_way = thread->n_way;
dim_t work_id = thread->work_id;
dim_t n_way = bli_thread_n_way( thread );
dim_t work_id = bli_thread_work_id( thread );
dim_t all_start = 0;
dim_t all_end = n;
@@ -511,8 +511,8 @@ siz_t bli_thread_get_range_weighted_sub
dim_t* j_end_thr
)
{
dim_t n_way = thread->n_way;
dim_t my_id = thread->work_id;
dim_t n_way = bli_thread_n_way( thread );
dim_t my_id = bli_thread_work_id( thread );
dim_t bf_left = n % bf;

View File

@@ -173,16 +173,14 @@ typedef void (*l3int_t)
// Level-3 thread decorator prototype
void bli_l3_thread_decorator
(
dim_t n_threads,
l3int_t func,
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t** thread
l3int_t func,
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
// Miscellaneous prototypes

View File

@@ -38,11 +38,9 @@ thrinfo_t* bli_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
bool_t free_comms,
bool_t free_comm,
thrinfo_t* sub_node
)
{
@@ -52,9 +50,8 @@ thrinfo_t* bli_thrinfo_create
(
thread,
ocomm, ocomm_id,
icomm, icomm_id,
n_way, work_id,
free_comms,
free_comm,
sub_node
);
@@ -66,23 +63,19 @@ void bli_thrinfo_init
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
bool_t free_comms,
bool_t free_comm,
thrinfo_t* sub_node
)
{
thread->ocomm = ocomm;
thread->ocomm_id = ocomm_id;
thread->icomm = icomm;
thread->icomm_id = icomm_id;
thread->n_way = n_way;
thread->work_id = work_id;
thread->free_comms = free_comms;
thread->ocomm = ocomm;
thread->ocomm_id = ocomm_id;
thread->n_way = n_way;
thread->work_id = work_id;
thread->free_comm = free_comm;
thread->sub_node = sub_node;
thread->sub_node = sub_node;
}
void bli_thrinfo_init_single
@@ -94,7 +87,6 @@ void bli_thrinfo_init_single
(
thread,
&BLIS_SINGLE_COMM, 0,
&BLIS_SINGLE_COMM, 0,
1,
0,
FALSE,
@@ -102,3 +94,178 @@ void bli_thrinfo_init_single
);
}
// -----------------------------------------------------------------------------
#include "assert.h"
#define BLIS_NUM_STATIC_COMMS 18
thrinfo_t* bli_thrinfo_create_for_cntl
(
cntx_t* cntx,
cntl_t* cntl_par,
cntl_t* cntl_chl,
thrinfo_t* thread_par
)
{
thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ];
thrcomm_t** new_comms = NULL;
thrinfo_t* thread_chl;
bszid_t bszid_chl = bli_cntl_bszid( cntl_chl );
dim_t parent_nt_in = bli_thread_num_threads( thread_par );
dim_t parent_n_way = bli_thread_n_way( thread_par );
dim_t parent_comm_id = bli_thread_ocomm_id( thread_par );
dim_t parent_work_id = bli_thread_work_id( thread_par );
dim_t child_nt_in;
dim_t child_comm_id;
dim_t child_n_way;
dim_t child_work_id;
// Sanity check: make sure the number of threads in the parent's
// communicator is divisible by the number of new sub-groups.
assert( parent_nt_in % parent_n_way == 0 );
// Compute:
// - the number of threads inside the new child comm,
// - the current thread's id within the new communicator,
// - the current thread's work id, given the ways of parallelism
// to be obtained within the next loop.
child_nt_in = bli_cntx_get_num_threads_in( cntx, cntl_chl );
child_n_way = bli_cntx_way_for_bszid( bszid_chl, cntx );
child_comm_id = parent_comm_id % child_nt_in;
child_work_id = child_comm_id / ( child_nt_in / child_n_way );
// The parent's chief thread creates a temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) );
else
new_comms = static_comms;
}
// Broadcast the temporary array to all threads in the parent's
// communicator.
new_comms = bli_thread_obroadcast( thread_par, new_comms );
// Chiefs in the child communicator allocate the communicator
// object and store it in the array element corresponding to the
// parent's work id.
if ( child_comm_id == 0 )
new_comms[ parent_work_id ] = bli_thrcomm_create( child_nt_in );
bli_thread_obarrier( thread_par );
// All threads create a new thrinfo_t node using the communicator
// that was created by their chief, as identified by parent_work_id.
thread_chl = bli_thrinfo_create
(
new_comms[ parent_work_id ],
child_comm_id,
child_n_way,
child_work_id,
TRUE,
NULL
);
bli_thread_obarrier( thread_par );
// The parent's chief thread frees the temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
bli_free_intl( new_comms );
}
return thread_chl;
}
void bli_thrinfo_grow
(
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t* thread
)
{
// If the sub-node of the thrinfo_t object is non-NULL, we don't
// need to create it, and will just use the existing sub-node as-is.
if ( bli_thrinfo_sub_node( thread ) != NULL ) return;
// Create a new node (or, if needed, multiple nodes) and return the
// pointer to the (eldest) child.
thrinfo_t* thread_child = bli_thrinfo_rgrow
(
cntx,
cntl,
bli_cntl_sub_node( cntl ),
thread
);
// Attach the child thrinfo_t node to its parent structure.
bli_thrinfo_set_sub_node( thread_child, thread );
}
thrinfo_t* bli_thrinfo_rgrow
(
cntx_t* cntx,
cntl_t* cntl_par,
cntl_t* cntl_cur,
thrinfo_t* thread_par
)
{
thrinfo_t* thread_cur;
// We must handle two cases: those where the next node in the
// control tree is a partitioning node, and those where it is
// a non-partitioning (ie: packing) node.
if ( bli_cntl_bszid( cntl_cur ) != BLIS_NO_PART )
{
// Create the child thrinfo_t node corresponding to cntl_cur,
// with cntl_par being the parent.
thread_cur = bli_thrinfo_create_for_cntl
(
cntx,
cntl_par,
cntl_cur,
thread_par
);
}
else // if ( bli_cntl_bszid( cntl_cur ) == BLIS_NO_PART )
{
// Recursively grow the thread structure and return the top-most
// thrinfo_t node of that segment.
thrinfo_t* thread_seg = bli_thrinfo_rgrow
(
cntx,
cntl_par,
bli_cntl_sub_node( cntl_cur ),
thread_par
);
// Create a thrinfo_t node corresponding to cntl_cur. Notice that
// the free_comm field is set to FALSE, since cntl_cur is a
// non-partitioning node. The communicator used here will be
// freed when thread_seg, or one of its descendents, is freed.
thread_cur = bli_thrinfo_create
(
bli_thrinfo_ocomm( thread_seg ),
bli_thread_ocomm_id( thread_seg ),
bli_cntx_get_num_threads_in( cntx, cntl_cur ),
bli_thread_ocomm_id( thread_seg ),
FALSE,
thread_seg
);
// Attach the child thrinfo_t node to its parent structure.
bli_thrinfo_set_sub_node( thread_cur, thread_par );
}
return thread_cur;
}

View File

@@ -45,13 +45,6 @@ struct thrinfo_s
// Our thread id within the ocomm thread communicator.
dim_t ocomm_id;
// The thread communicator for the other threads sharing the same work
// at this level.
thrcomm_t* icomm;
// Our thread id within the icomm thread communicator.
dim_t icomm_id;
// The number of distinct threads used to parallelize the loop.
dim_t n_way;
@@ -62,7 +55,7 @@ struct thrinfo_s
// this is field is true, but when nodes are created that share the same
// communicators as other nodes (such as with packm nodes), this is set
// to false.
bool_t free_comms;
bool_t free_comm;
struct thrinfo_s* sub_node;
};
@@ -71,30 +64,40 @@ typedef struct thrinfo_s thrinfo_t;
//
// thrinfo_t macros
// NOTE: The naming of these should be made consistent at some point.
// (ie: bli_thrinfo_ vs. bli_thread_)
//
#define bli_thread_num_threads( t ) ( (t)->ocomm->n_threads )
// thrinfo_t query (field only)
#define bli_thread_n_way( t ) ( (t)->n_way )
#define bli_thread_work_id( t ) ( (t)->work_id )
#define bli_thread_num_threads( t ) ( (t)->ocomm->n_threads )
#define bli_thread_am_ochief( t ) ( (t)->ocomm_id == 0 )
#define bli_thread_am_ichief( t ) ( (t)->icomm_id == 0 )
#define bli_thread_n_way( t ) ( (t)->n_way )
#define bli_thread_work_id( t ) ( (t)->work_id )
#define bli_thread_ocomm_id( t ) ( (t)->ocomm_id )
#define bli_thrinfo_ocomm( t ) ( (t)->ocomm )
#define bli_thrinfo_needs_free_comm( t ) ( (t)->free_comm )
#define bli_thrinfo_sub_node( t ) ( (t)->sub_node )
// thrinfo_t query (complex)
#define bli_thread_am_ochief( t ) ( (t)->ocomm_id == 0 )
// thrinfo_t modification
#define bli_thrinfo_set_sub_node( _sub_node, thread ) \
{ \
(thread)->sub_node = _sub_node; \
}
// other thrinfo_t-related macros
#define bli_thread_obroadcast( t, p ) bli_thrcomm_bcast( (t)->ocomm, \
(t)->ocomm_id, p )
#define bli_thread_ibroadcast( t, p ) bli_thrcomm_bcast( (t)->icomm, \
(t)->icomm_id, p )
#define bli_thread_obarrier( t ) bli_thrcomm_barrier( (t)->ocomm, \
(t)->ocomm_id )
#define bli_thread_ibarrier( t ) bli_thrcomm_barrier( (t)->icomm, \
(t)->icomm_id )
#define bli_thrinfo_ocomm( t ) ( (t)->ocomm )
#define bli_thrinfo_icomm( t ) ( (t)->icomm )
#define bli_thrinfo_needs_free_comms( t ) ( (t)->free_comms )
#define bli_thrinfo_sub_node( t ) ( (t)->sub_node )
//
// Prototypes for level-3 thrinfo functions not specific to any operation.
@@ -104,11 +107,9 @@ thrinfo_t* bli_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
bool_t free_comms,
bool_t free_comm,
thrinfo_t* sub_node
);
@@ -117,11 +118,9 @@ void bli_thrinfo_init
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
thrcomm_t* icomm,
dim_t icomm_id,
dim_t n_way,
dim_t work_id,
bool_t free_comms,
bool_t free_comm,
thrinfo_t* sub_node
);
@@ -130,9 +129,29 @@ void bli_thrinfo_init_single
thrinfo_t* thread
);
void bli_thrinfo_free
// -----------------------------------------------------------------------------
thrinfo_t* bli_thrinfo_create_for_cntl
(
cntx_t* cntx,
cntl_t* cntl_par,
cntl_t* cntl_chl,
thrinfo_t* thread_par
);
void bli_thrinfo_grow
(
cntx_t* cntx,
cntl_t* cntl,
thrinfo_t* thread
);
thrinfo_t* bli_thrinfo_rgrow
(
cntx_t* cntx,
cntl_t* cntl_par,
cntl_t* cntl_cur,
thrinfo_t* thread_par
);
#endif