From 8d55033c966feed99fcca2a58017c3ab5b1646dc Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Tue, 27 Sep 2016 15:20:58 -0500 Subject: [PATCH] Implemented distributed thrinfo_t management. Details: - Implemented Ricardo Magana's distributed thread info/communicator management. Rather than fully construct the thrinfo_t structures, from root to leaf, prior to spawning threads, the threads individually construct their thrinfo_t trees (or, chains), and do so incrementally, as needed, reusing the same structure nodes during subsequent blocked variant iterations. This required moving the initial creation of the thrinfo_t structure (now, the root nodes) from the _front() functions to the bli_l3_thread_decorator(). The incremental "growing" of the tree is performed in the internal back-end (i.e., _int()) function, and so is mostly invisible. Also, the incremental growth of the thrinfo_t tree is done as a function of the current and parent control tree nodes (as well as the parent thrinfo_t node), further reinforcing the parallel relationship between the two data structures. - Removed the "inner" communicator from thrinfo_t structure definition, as well as its id. Changed all APIs accordingly. Renamed bli_thrinfo_needs_free_comms() to bli_thrinfo_needs_free_comm(). - Defined bli_l3_thrinfo_print_paths(), which prints the information in an array of thrinfo_t* structure pointers. (Used only as a debugging/verification tool.) - Deprecated the following thrinfo_t creation functions: bli_packm_thrinfo_create() bli_l3_thrinfo_create() because they are no longer used. bli_thrinfo_create() is now called directly when creating thrinfo_t nodes. 
--- frame/1m/packm/bli_packm_thrinfo.c | 12 +- frame/1m/packm/bli_packm_thrinfo.h | 8 +- frame/3/bli_l3_thrinfo.c | 356 ++++++++++++++++++---------- frame/3/bli_l3_thrinfo.h | 38 ++- frame/3/gemm/bli_gemm_blk_var3.c | 4 +- frame/3/gemm/bli_gemm_cntl.c | 17 +- frame/3/gemm/bli_gemm_front.c | 22 +- frame/3/gemm/bli_gemm_int.c | 18 +- frame/3/hemm/bli_hemm_front.c | 10 +- frame/3/her2k/bli_her2k_front.c | 19 +- frame/3/herk/bli_herk_front.c | 10 +- frame/3/symm/bli_symm_front.c | 10 +- frame/3/syr2k/bli_syr2k_front.c | 15 +- frame/3/syrk/bli_syrk_front.c | 10 +- frame/3/trmm/bli_trmm_front.c | 10 +- frame/3/trmm3/bli_trmm3_front.c | 10 +- frame/3/trsm/bli_trsm_blk_var3.c | 3 +- frame/3/trsm/bli_trsm_cntl.c | 34 ++- frame/3/trsm/bli_trsm_front.c | 10 +- frame/3/trsm/bli_trsm_int.c | 3 + frame/base/bli_cntx.c | 121 ++++++++++ frame/base/bli_cntx.h | 47 ++++ frame/include/bli_type_defs.h | 17 ++ frame/thread/bli_thrcomm.h | 6 + frame/thread/bli_thrcomm_openmp.c | 28 ++- frame/thread/bli_thrcomm_pthreads.c | 83 ++++--- frame/thread/bli_thrcomm_single.c | 26 +- frame/thread/bli_thread.c | 8 +- frame/thread/bli_thread.h | 18 +- frame/thread/bli_thrinfo.c | 201 ++++++++++++++-- frame/thread/bli_thrinfo.h | 77 +++--- 31 files changed, 887 insertions(+), 364 deletions(-) diff --git a/frame/1m/packm/bli_packm_thrinfo.c b/frame/1m/packm/bli_packm_thrinfo.c index 1c1265661..2287a7222 100644 --- a/frame/1m/packm/bli_packm_thrinfo.c +++ b/frame/1m/packm/bli_packm_thrinfo.c @@ -34,12 +34,11 @@ #include "blis.h" +#if 0 thrinfo_t* bli_packm_thrinfo_create ( thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node @@ -51,7 +50,6 @@ thrinfo_t* bli_packm_thrinfo_create ( thread, ocomm, ocomm_id, - icomm, icomm_id, n_way, work_id, FALSE, @@ -60,14 +58,13 @@ thrinfo_t* bli_packm_thrinfo_create return thread; } +#endif void bli_packm_thrinfo_init ( thrinfo_t* thread, thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t 
icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node @@ -77,7 +74,6 @@ void bli_packm_thrinfo_init ( thread, ocomm, ocomm_id, - icomm, icomm_id, n_way, work_id, FALSE, sub_node @@ -93,13 +89,13 @@ void bli_packm_thrinfo_init_single ( thread, &BLIS_SINGLE_COMM, 0, - &BLIS_SINGLE_COMM, 0, 1, 0, NULL ); } +#if 0 void bli_packm_thrinfo_free ( thrinfo_t* thread @@ -109,4 +105,4 @@ void bli_packm_thrinfo_free thread != &BLIS_PACKM_SINGLE_THREADED ) bli_free_intl( thread ); } - +#endif diff --git a/frame/1m/packm/bli_packm_thrinfo.h b/frame/1m/packm/bli_packm_thrinfo.h index 7b6d7ae4d..5da496f96 100644 --- a/frame/1m/packm/bli_packm_thrinfo.h +++ b/frame/1m/packm/bli_packm_thrinfo.h @@ -42,24 +42,22 @@ // thrinfo_t APIs specific to packm. // +#if 0 thrinfo_t* bli_packm_thrinfo_create ( thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node ); +#endif void bli_packm_thrinfo_init ( thrinfo_t* thread, thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node @@ -70,8 +68,10 @@ void bli_packm_thrinfo_init_single thrinfo_t* thread ); +#if 0 void bli_packm_thrinfo_free ( thrinfo_t* thread ); +#endif diff --git a/frame/3/bli_l3_thrinfo.c b/frame/3/bli_l3_thrinfo.c index 36b65b52b..78b2b775c 100644 --- a/frame/3/bli_l3_thrinfo.c +++ b/frame/3/bli_l3_thrinfo.c @@ -35,12 +35,11 @@ #include "blis.h" #include "assert.h" +#if 0 thrinfo_t* bli_l3_thrinfo_create ( thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node @@ -49,21 +48,19 @@ thrinfo_t* bli_l3_thrinfo_create return bli_thrinfo_create ( ocomm, ocomm_id, - icomm, icomm_id, n_way, work_id, TRUE, sub_node ); } +#endif void bli_l3_thrinfo_init ( thrinfo_t* thread, thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node @@ -73,7 +70,6 @@ void bli_l3_thrinfo_init ( thread, 
ocomm, ocomm_id, - icomm, icomm_id, n_way, work_id, TRUE, @@ -105,14 +101,12 @@ void bli_l3_thrinfo_free // is marked as needing them to be freed. The most common example of // thrinfo_t nodes NOT marked as needing their comms freed are those // associated with packm thrinfo_t nodes. - if ( bli_thrinfo_needs_free_comms( thread ) ) + if ( bli_thrinfo_needs_free_comm( thread ) ) { // The ochief always frees his communicator, and the ichief free its // communicator if we are at the leaf node. if ( bli_thread_am_ochief( thread ) ) bli_thrcomm_free( bli_thrinfo_ocomm( thread ) ); - if ( thrinfo_sub_node == NULL && bli_thread_am_ichief( thread ) ) - bli_thrcomm_free( bli_thrinfo_icomm( thread ) ); } // Free all children of the current thrinfo_t. @@ -124,117 +118,208 @@ void bli_l3_thrinfo_free // ----------------------------------------------------------------------------- -//#define PRINT_THRINFO - -thrinfo_t** bli_l3_thrinfo_create_paths +void bli_l3_thrinfo_create_root ( - opid_t l3_op, - side_t side + dim_t id, + thrcomm_t* gl_comm, + cntx_t* cntx, + cntl_t* cntl, + thrinfo_t** thread ) { - dim_t jc_in, jc_way; - dim_t kc_in, kc_way; - dim_t ic_in, ic_way; - dim_t jr_in, jr_way; - dim_t ir_in, ir_way; + // Query the global communicator for the total number of threads to use. + dim_t n_threads = bli_thrcomm_num_threads( gl_comm ); -#ifdef BLIS_ENABLE_MULTITHREADING - jc_in = bli_env_read_nway( "BLIS_JC_NT" ); - //kc_way = bli_env_read_nway( "BLIS_KC_NT" ); - kc_in = 1; - ic_in = bli_env_read_nway( "BLIS_IC_NT" ); - jr_in = bli_env_read_nway( "BLIS_JR_NT" ); - ir_in = bli_env_read_nway( "BLIS_IR_NT" ); -#else - jc_in = 1; - kc_in = 1; - ic_in = 1; - jr_in = 1; - ir_in = 1; -#endif + // Use the thread id passed in as the global communicator id. + dim_t gl_comm_id = id; - if ( l3_op == BLIS_TRMM ) - { - // We reconfigure the parallelism for trmm_r due to a dependency in - // the jc loop. (NOTE: This dependency does not exist for trmm3.) 
- if ( bli_is_right( side ) ) - { - jc_way = 1; - kc_way = kc_in; - ic_way = ic_in; - jr_way = jr_in * jc_in; - ir_way = ir_in; - } - else // if ( bli_is_left( side ) ) - { - jc_way = jc_in; - kc_way = kc_in; - ic_way = ic_in; - jr_way = jr_in; - ir_way = ir_in; - } - } - else if ( l3_op == BLIS_TRSM ) - { - if ( bli_is_right( side ) ) - { + // Use the blocksize id of the current (root) control tree node to + // query the top-most ways of parallelism to obtain. + bszid_t bszid = bli_cntl_bszid( cntl ); + dim_t xx_way = bli_cntx_way_for_bszid( bszid, cntx ); - jc_way = 1; - kc_way = 1; - ic_way = jc_in * ic_in * jr_in; - jr_way = 1; - ir_way = 1; - } - else // if ( bli_is_left( side ) ) - { - jc_way = 1; - kc_way = 1; - ic_way = 1; - jr_way = ic_in * jr_in * ir_in; - ir_way = 1; - } - } - else // all other level-3 operations + // Determine the work id for this thrinfo_t node. + dim_t work_id = gl_comm_id / ( n_threads / xx_way ); + + // Create the root thrinfo_t node. + *thread = bli_thrinfo_create + ( + gl_comm, + gl_comm_id, + xx_way, + work_id, + TRUE, + NULL + ); +} + +// ----------------------------------------------------------------------------- + +void bli_l3_thrinfo_print_paths + ( + thrinfo_t** threads + ) +{ + dim_t n_threads = bli_thread_num_threads( threads[0] ); + dim_t gl_comm_id; + + thrinfo_t* jc_info = threads[0]; + thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info ); + thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info ); + thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info ); + thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info ); + thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info ); + thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info ); + + dim_t jc_way = bli_thread_n_way( jc_info ); + dim_t pc_way = bli_thread_n_way( pc_info ); + dim_t pb_way = bli_thread_n_way( pb_info ); + dim_t ic_way = bli_thread_n_way( ic_info ); + dim_t pa_way = bli_thread_n_way( pa_info ); + dim_t jr_way = bli_thread_n_way( jr_info ); + dim_t ir_way = 
bli_thread_n_way( ir_info ); + + dim_t gl_nt = bli_thread_num_threads( jc_info ); + dim_t jc_nt = bli_thread_num_threads( pc_info ); + dim_t pc_nt = bli_thread_num_threads( pb_info ); + dim_t pb_nt = bli_thread_num_threads( ic_info ); + dim_t ic_nt = bli_thread_num_threads( pa_info ); + dim_t pa_nt = bli_thread_num_threads( jr_info ); + dim_t jr_nt = bli_thread_num_threads( ir_info ); + + printf( " gl jc kc pb ic pa jr ir\n" ); + printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", + gl_nt, jc_nt, pc_nt, pb_nt, ic_nt, pa_nt, jr_nt, (dim_t)1 ); + printf( "\n" ); + printf( " jc kc pb ic pa jr ir\n" ); + printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", + jc_way, pc_way, pb_way, ic_way, pa_way, jr_way, ir_way ); + printf( "=================================================\n" ); + + for ( gl_comm_id = 0; gl_comm_id < n_threads; ++gl_comm_id ) { - jc_way = jc_in; - kc_way = kc_in; - ic_way = ic_in; - jr_way = jr_in; - ir_way = ir_in; + jc_info = threads[gl_comm_id]; + pc_info = bli_thrinfo_sub_node( jc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + dim_t gl_comm_id = bli_thread_ocomm_id( jc_info ); + dim_t jc_comm_id = bli_thread_ocomm_id( pc_info ); + dim_t pc_comm_id = bli_thread_ocomm_id( pb_info ); + dim_t pb_comm_id = bli_thread_ocomm_id( ic_info ); + dim_t ic_comm_id = bli_thread_ocomm_id( pa_info ); + dim_t pa_comm_id = bli_thread_ocomm_id( jr_info ); + dim_t jr_comm_id = bli_thread_ocomm_id( ir_info ); + + dim_t jc_work_id = bli_thread_work_id( jc_info ); + dim_t pc_work_id = bli_thread_work_id( pc_info ); + dim_t pb_work_id = bli_thread_work_id( pb_info ); + dim_t ic_work_id = bli_thread_work_id( ic_info ); + dim_t pa_work_id = bli_thread_work_id( pa_info ); + dim_t jr_work_id = bli_thread_work_id( jr_info ); + dim_t ir_work_id = bli_thread_work_id( ir_info ); + 
+printf( " gl jc pb kc pa ic jr \n" ); +printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", +gl_comm_id, jc_comm_id, pc_comm_id, pb_comm_id, ic_comm_id, pa_comm_id, jr_comm_id ); +printf( "work ids: %4ld %4ld %4lu %4lu %4ld %4ld %4ld\n", +jc_work_id, pc_work_id, pb_work_id, ic_work_id, pa_work_id, jr_work_id, ir_work_id ); +printf( "---------------------------------------\n" ); } +} - dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way; - assert( global_num_threads != 0 ); +// ----------------------------------------------------------------------------- - dim_t jc_nt = kc_way * ic_way * jr_way * ir_way; - dim_t kc_nt = ic_way * jr_way * ir_way; +#if 0 +thrinfo_t** bli_l3_thrinfo_create_roots + ( + cntx_t* cntx, + cntl_t* cntl + ) +{ + // Query the context for the total number of threads to use. + dim_t n_threads = bli_cntx_get_num_threads( cntx ); + + // Create a global thread communicator for all the threads. + thrcomm_t* gl_comm = bli_thrcomm_create( n_threads ); + + // Allocate an array of thrinfo_t pointers, one for each thread. + thrinfo_t** paths = bli_malloc_intl( n_threads * sizeof( thrinfo_t* ) ); + + // Use the blocksize id of the current (root) control tree node to + // query the top-most ways of parallelism to obtain. + bszid_t bszid = bli_cntl_bszid( cntl ); + dim_t xx_way = bli_cntx_way_for_bszid( bszid, cntx ); + + dim_t gl_comm_id; + + // Create one thrinfo_t node for each thread in the (global) communicator. 
+ for ( gl_comm_id = 0; gl_comm_id < n_threads; ++gl_comm_id ) + { + dim_t work_id = gl_comm_id / ( n_threads / xx_way ); + + paths[ gl_comm_id ] = bli_thrinfo_create + ( + gl_comm, + gl_comm_id, + xx_way, + work_id, + TRUE, + NULL + ); + } + + return paths; +} + +//#define PRINT_THRINFO + +thrinfo_t** bli_l3_thrinfo_create_full_paths + ( + cntx_t* cntx + ) +{ + dim_t jc_way = bli_cntx_jc_way( cntx ); + dim_t pc_way = bli_cntx_pc_way( cntx ); + dim_t ic_way = bli_cntx_ic_way( cntx ); + dim_t jr_way = bli_cntx_jr_way( cntx ); + dim_t ir_way = bli_cntx_ir_way( cntx ); + + dim_t gl_nt = jc_way * pc_way * ic_way * jr_way * ir_way; + dim_t jc_nt = pc_way * ic_way * jr_way * ir_way; + dim_t pc_nt = ic_way * jr_way * ir_way; dim_t ic_nt = jr_way * ir_way; dim_t jr_nt = ir_way; dim_t ir_nt = 1; + assert( gl_nt != 0 ); + #ifdef PRINT_THRINFO -printf( " jc kc ic jr ir\n" ); -printf( "xx_way: %4lu %4lu %4lu %4lu %4lu\n", - jc_way, kc_way, ic_way, jr_way, ir_way ); +printf( " gl jc kc pb ic pa jr ir\n" ); +printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", +gl_nt, jc_nt, pc_nt, pc_nt, ic_nt, ic_nt, jr_nt, ir_nt ); printf( "\n" ); -printf( " gl jc kc ic jr ir\n" ); -printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu\n", -global_num_threads, jc_nt, kc_nt, ic_nt, jr_nt, ir_nt ); -printf( "=======================================\n" ); +printf( " jc kc pb ic pa jr ir\n" ); +printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", +jc_way, pc_way, (dim_t)0, ic_way, (dim_t)0, jr_way, ir_way ); +printf( "=================================================\n" ); #endif - thrinfo_t** paths = bli_malloc_intl( global_num_threads * sizeof( thrinfo_t* ) ); + thrinfo_t** paths = bli_malloc_intl( gl_nt * sizeof( thrinfo_t* ) ); - thrcomm_t* global_comm = bli_thrcomm_create( global_num_threads ); + thrcomm_t* gl_comm = bli_thrcomm_create( gl_nt ); for( int a = 0; a < jc_way; a++ ) { thrcomm_t* jc_comm = bli_thrcomm_create( jc_nt ); - for( int b = 0; b < kc_way; b++ ) + for( int b = 0; b < 
pc_way; b++ ) { - thrcomm_t* kc_comm = bli_thrcomm_create( kc_nt ); + thrcomm_t* pc_comm = bli_thrcomm_create( pc_nt ); for( int c = 0; c < ic_way; c++ ) { @@ -246,73 +331,83 @@ printf( "=======================================\n" ); for( int e = 0; e < ir_way; e++ ) { - thrcomm_t* ir_comm = bli_thrcomm_create( ir_nt ); - - dim_t ir_comm_id = 0; - dim_t jr_comm_id = e*ir_nt + ir_comm_id; - dim_t ic_comm_id = d*jr_nt + jr_comm_id; - dim_t kc_comm_id = c*ic_nt + ic_comm_id; - dim_t jc_comm_id = b*kc_nt + kc_comm_id; - dim_t global_comm_id = a*jc_nt + jc_comm_id; + //thrcomm_t* ir_comm = bli_thrcomm_create( ir_nt ); + dim_t ir_comm_id = 0; + dim_t jr_comm_id = e*ir_nt + ir_comm_id; + dim_t ic_comm_id = d*jr_nt + jr_comm_id; + dim_t pc_comm_id = c*ic_nt + ic_comm_id; + dim_t jc_comm_id = b*pc_nt + pc_comm_id; + dim_t gl_comm_id = a*jc_nt + jc_comm_id; // macro-kernel loops thrinfo_t* ir_info = bli_l3_thrinfo_create( jr_comm, jr_comm_id, - ir_comm, ir_comm_id, ir_way, e, NULL ); thrinfo_t* jr_info = bli_l3_thrinfo_create( ic_comm, ic_comm_id, - jr_comm, jr_comm_id, jr_way, d, ir_info ); // packa - thrinfo_t* pack_ic_in + thrinfo_t* pa_info = bli_packm_thrinfo_create( ic_comm, ic_comm_id, - jr_comm, jr_comm_id, ic_nt, ic_comm_id, jr_info ); // blk_var1 thrinfo_t* ic_info = - bli_l3_thrinfo_create( kc_comm, kc_comm_id, - ic_comm, ic_comm_id, + bli_l3_thrinfo_create( pc_comm, pc_comm_id, ic_way, c, - pack_ic_in ); + pa_info ); // packb - thrinfo_t* pack_kc_in + thrinfo_t* pb_info = - bli_packm_thrinfo_create( kc_comm, kc_comm_id, - ic_comm, ic_comm_id, - kc_nt, kc_comm_id, + bli_packm_thrinfo_create( pc_comm, pc_comm_id, + pc_nt, pc_comm_id, ic_info ); // blk_var3 - thrinfo_t* kc_info + thrinfo_t* pc_info = bli_l3_thrinfo_create( jc_comm, jc_comm_id, - kc_comm, kc_comm_id, - kc_way, b, - pack_kc_in ); + pc_way, b, + pb_info ); // blk_var2 thrinfo_t* jc_info = - bli_l3_thrinfo_create( global_comm, global_comm_id, - jc_comm, jc_comm_id, + bli_l3_thrinfo_create( gl_comm, 
gl_comm_id, jc_way, a, - kc_info ); + pc_info ); - paths[global_comm_id] = jc_info; + paths[gl_comm_id] = jc_info; #ifdef PRINT_THRINFO -printf( " gl jc kc ic jr ir\n" ); -printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu\n", -global_comm_id, jc_comm_id, kc_comm_id, ic_comm_id, jr_comm_id, ir_comm_id ); -//printf( " a b c d e\n" ); -printf( "work ids: %4ld %4ld %4ld %4ld %4ld\n", (long int)a, (long int)b, (long int)c, (long int)d, (long int)e ); -printf( "---------------------------------------\n" ); +{ +dim_t gl_comm_id = bli_thread_ocomm_id( jc_info ); +dim_t jc_comm_id = bli_thread_ocomm_id( pc_info ); +dim_t pc_comm_id = bli_thread_ocomm_id( pb_info ); +dim_t pb_comm_id = bli_thread_ocomm_id( ic_info ); +dim_t ic_comm_id = bli_thread_ocomm_id( pa_info ); +dim_t pa_comm_id = bli_thread_ocomm_id( jr_info ); +dim_t jr_comm_id = bli_thread_ocomm_id( ir_info ); + +dim_t jc_work_id = bli_thread_work_id( jc_info ); +dim_t pc_work_id = bli_thread_work_id( pc_info ); +dim_t pb_work_id = bli_thread_work_id( pb_info ); +dim_t ic_work_id = bli_thread_work_id( ic_info ); +dim_t pa_work_id = bli_thread_work_id( pa_info ); +dim_t jr_work_id = bli_thread_work_id( jr_info ); +dim_t ir_work_id = bli_thread_work_id( ir_info ); + +printf( " gl jc pb kc pa ic jr \n" ); +printf( "comm ids: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", +gl_comm_id, jc_comm_id, pc_comm_id, pb_comm_id, ic_comm_id, pa_comm_id, jr_comm_id ); +printf( "work ids: %4ld %4ld %4lu %4lu %4ld %4ld %4ld\n", +jc_work_id, pc_work_id, pb_work_id, ic_work_id, pa_work_id, jr_work_id, ir_work_id ); +printf( "-------------------------------------------------\n" ); +} #endif } @@ -330,15 +425,16 @@ exit(1); void bli_l3_thrinfo_free_paths ( - thrinfo_t** threads, - dim_t num + thrinfo_t** threads ) { + dim_t n_threads = bli_thread_num_threads( threads[0] ); dim_t i; - for ( i = 0; i < num; ++i ) + for ( i = 0; i < n_threads; ++i ) bli_l3_thrinfo_free( threads[i] ); bli_free_intl( threads ); } +#endif diff --git 
a/frame/3/bli_l3_thrinfo.h b/frame/3/bli_l3_thrinfo.h index 7eac72298..71dea7645 100644 --- a/frame/3/bli_l3_thrinfo.h +++ b/frame/3/bli_l3_thrinfo.h @@ -61,24 +61,22 @@ // thrinfo_t APIs specific to level-3 operations. // +#if 0 thrinfo_t* bli_l3_thrinfo_create ( thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node ); +#endif void bli_l3_thrinfo_init ( thrinfo_t* thread, thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, thrinfo_t* sub_node @@ -96,15 +94,37 @@ void bli_l3_thrinfo_free // ----------------------------------------------------------------------------- -thrinfo_t** bli_l3_thrinfo_create_paths +void bli_l3_thrinfo_create_root ( - opid_t l3_op, - side_t side + dim_t id, + thrcomm_t* gl_comm, + cntx_t* cntx, + cntl_t* cntl, + thrinfo_t** thread + ); + +void bli_l3_thrinfo_print_paths + ( + thrinfo_t** threads + ); + +// ----------------------------------------------------------------------------- + +#if 0 +thrinfo_t** bli_l3_thrinfo_create_roots + ( + cntx_t* cntx, + cntl_t* cntl + ); + +thrinfo_t** bli_l3_thrinfo_create_full_paths + ( + cntx_t* cntx ); void bli_l3_thrinfo_free_paths ( - thrinfo_t** threads, - dim_t num + thrinfo_t** threads ); +#endif diff --git a/frame/3/gemm/bli_gemm_blk_var3.c b/frame/3/gemm/bli_gemm_blk_var3.c index 7be9c6a58..0148428df 100644 --- a/frame/3/gemm/bli_gemm_blk_var3.c +++ b/frame/3/gemm/bli_gemm_blk_var3.c @@ -84,10 +84,10 @@ void bli_gemm_blk_var3 c, cntx, bli_cntl_sub_node( cntl ), - bli_thrinfo_sub_node( thread) + bli_thrinfo_sub_node( thread ) ); - bli_thread_ibarrier( thread ); + bli_thread_obarrier( bli_thrinfo_sub_node( thread ) ); // This variant executes multiple rank-k updates. 
Therefore, if the // internal beta scalar on matrix C is non-zero, we must use it diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index 3f3773418..b3494b174 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -46,14 +46,21 @@ cntl_t* bli_gemm_cntl_create if ( family == BLIS_HERK ) macro_kernel_p = bli_herk_x_ker_var2; else if ( family == BLIS_TRMM ) macro_kernel_p = bli_trmm_xx_ker_var2; - // Create a node for the macro-kernel. - cntl_t* gemm_cntl_bp_ke = bli_gemm_cntl_obj_create + // Create two nodes for the macro-kernel. + cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_obj_create ( - BLIS_NR, // bszid not used by macro-kernel. - macro_kernel_p, + BLIS_MR, // needed for bli_thrinfo_rgrow() + NULL, // variant function pointer not used NULL // no sub-node; this is the leaf of the tree. ); + cntl_t* gemm_cntl_bp_bu = bli_gemm_cntl_obj_create + ( + BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow() + macro_kernel_p, + gemm_cntl_bu_ke + ); + // Create a node for packing matrix A. cntl_t* gemm_cntl_packa = bli_packm_cntl_obj_create ( @@ -66,7 +73,7 @@ cntl_t* bli_gemm_cntl_create FALSE, // reverse iteration if lower? BLIS_PACKED_ROW_PANELS, BLIS_BUFFER_FOR_A_BLOCK, - gemm_cntl_bp_ke + gemm_cntl_bp_bu ); // Create a node for partitioning the m dimension by MC. diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 0782d7272..324655655 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -85,13 +85,19 @@ void bli_gemm_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_GEMM, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_GEMM, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_GEMM, BLIS_LEFT, cntx ); - // Invoke the internal back-end. 
+ // Create the first node in the thrinfo_t tree for each thread. +//thrinfo_t** infos = bli_l3_thrinfo_create_full_paths( cntx ); +//bli_l3_thrinfo_print_paths( infos ); +//exit(1); +//cntl = bli_gemm_cntl_create( BLIS_GEMM ); + //thrinfo_t** infos = bli_l3_thrinfo_create_roots( cntx, cntl ); + + // Invoke the internal back-end via the thread handler. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -99,10 +105,12 @@ void bli_gemm_front beta, &c_local, cntx, - cntl, - infos + cntl ); +//bli_l3_thrinfo_print_paths( infos ); +//exit(1); - bli_l3_thrinfo_free_paths( infos, n_threads ); + // Free the thrinfo_t structures. + //bli_l3_thrinfo_free_paths( infos ); } diff --git a/frame/3/gemm/bli_gemm_int.c b/frame/3/gemm/bli_gemm_int.c index 18e531879..b24f2a25d 100644 --- a/frame/3/gemm/bli_gemm_int.c +++ b/frame/3/gemm/bli_gemm_int.c @@ -50,7 +50,6 @@ void bli_gemm_int obj_t b_local; obj_t c_local; gemm_voft f; - ind_t im; // Check parameters. if ( bli_error_checking_is_enabled() ) @@ -102,17 +101,22 @@ void bli_gemm_int bli_obj_scalar_apply_scalar( beta, &c_local ); } + // Create the next node in the thrinfo_t structure. + bli_thrinfo_grow( cntx, cntl, thread ); + // Extract the function pointer from the current control tree node. f = bli_cntl_var_func( cntl ); // Somewhat hackish support for 3m3, 3m2, and 4m1b method implementations. 
- im = bli_cntx_get_ind_method( cntx ); - - if ( im != BLIS_NAT ) { - if ( im == BLIS_3M3 && f == bli_gemm_packa ) f = bli_gemm3m3_packa; - else if ( im == BLIS_3M2 && f == bli_gemm_ker_var2 ) f = bli_gemm3m2_ker_var2; - else if ( im == BLIS_4M1B && f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2; + ind_t im = bli_cntx_get_ind_method( cntx ); + + if ( im != BLIS_NAT ) + { + if ( im == BLIS_3M3 && f == bli_gemm_packa ) f = bli_gemm3m3_packa; + else if ( im == BLIS_3M2 && f == bli_gemm_ker_var2 ) f = bli_gemm3m2_ker_var2; + else if ( im == BLIS_4M1B && f == bli_gemm_ker_var2 ) f = bli_gemm4mb_ker_var2; + } } // Invoke the variant. diff --git a/frame/3/hemm/bli_hemm_front.c b/frame/3/hemm/bli_hemm_front.c index ed7e03b9c..8bede097b 100644 --- a/frame/3/hemm/bli_hemm_front.c +++ b/frame/3/hemm/bli_hemm_front.c @@ -92,13 +92,12 @@ void bli_hemm_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_GEMM, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HEMM, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_HEMM, BLIS_LEFT, cntx ); // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -106,10 +105,7 @@ void bli_hemm_front beta, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/her2k/bli_her2k_front.c b/frame/3/her2k/bli_her2k_front.c index f72dedf87..7350b5785 100644 --- a/frame/3/her2k/bli_her2k_front.c +++ b/frame/3/her2k/bli_her2k_front.c @@ -110,14 +110,14 @@ void bli_her2k_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_HERK, cntx ); - // Invoke herk twice, using beta only the first time. 
- thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HER2K, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_HER2K, BLIS_LEFT, cntx ); - // Invoke the internal back-end. + // Invoke herk twice, using beta only the first time. + + // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -125,13 +125,11 @@ void bli_her2k_front beta, &c_local, cntx, - cntl, - infos + cntl ); bli_l3_thread_decorator ( - n_threads, bli_gemm_int, &alpha_conj, &b_local, @@ -139,12 +137,9 @@ void bli_her2k_front &BLIS_ONE, &c_local, cntx, - cntl, - infos + cntl ); - bli_l3_thrinfo_free_paths( infos, n_threads ); - // The Hermitian rank-2k product was computed as A*B'+B*A', even for // the diagonal elements. Mathematically, the imaginary components of // diagonal elements of a Hermitian rank-2k product should always be diff --git a/frame/3/herk/bli_herk_front.c b/frame/3/herk/bli_herk_front.c index 3abfa9baf..7fcd2d356 100644 --- a/frame/3/herk/bli_herk_front.c +++ b/frame/3/herk/bli_herk_front.c @@ -90,13 +90,12 @@ void bli_herk_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_HERK, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_HERK, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_HERK, BLIS_LEFT, cntx ); // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -104,12 +103,9 @@ void bli_herk_front beta, &c_local, cntx, - cntl, - infos + cntl ); - bli_l3_thrinfo_free_paths( infos, n_threads ); - // The Hermitian rank-k product was computed as A*A', even for the // diagonal elements. 
Mathematically, the imaginary components of // diagonal elements of a Hermitian rank-k product should always be diff --git a/frame/3/symm/bli_symm_front.c b/frame/3/symm/bli_symm_front.c index b864ce06a..cd2f3a20e 100644 --- a/frame/3/symm/bli_symm_front.c +++ b/frame/3/symm/bli_symm_front.c @@ -91,13 +91,12 @@ void bli_symm_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_GEMM, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYMM, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_SYMM, BLIS_LEFT, cntx ); // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -105,10 +104,7 @@ void bli_symm_front beta, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/syr2k/bli_syr2k_front.c b/frame/3/syr2k/bli_syr2k_front.c index 936c43635..47ce91795 100644 --- a/frame/3/syr2k/bli_syr2k_front.c +++ b/frame/3/syr2k/bli_syr2k_front.c @@ -91,14 +91,14 @@ void bli_syr2k_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_HERK, cntx ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_SYR2K, BLIS_LEFT, cntx ); + // Invoke herk twice, using beta only the first time. - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYR2K, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); // Invoke the internal back-end. 
bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -106,13 +106,11 @@ void bli_syr2k_front beta, &c_local, cntx, - cntl, - infos + cntl ); bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &b_local, @@ -120,10 +118,7 @@ void bli_syr2k_front &BLIS_ONE, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/syrk/bli_syrk_front.c b/frame/3/syrk/bli_syrk_front.c index 8b379ab0e..f037eb1c1 100644 --- a/frame/3/syrk/bli_syrk_front.c +++ b/frame/3/syrk/bli_syrk_front.c @@ -84,13 +84,12 @@ void bli_syrk_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_HERK, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_SYRK, BLIS_LEFT ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_SYRK, BLIS_LEFT, cntx ); // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -98,10 +97,7 @@ void bli_syrk_front beta, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/trmm/bli_trmm_front.c b/frame/3/trmm/bli_trmm_front.c index 689acbb72..c7231c839 100644 --- a/frame/3/trmm/bli_trmm_front.c +++ b/frame/3/trmm/bli_trmm_front.c @@ -134,13 +134,12 @@ void bli_trmm_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_TRMM, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRMM, side ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_TRMM, side, cntx ); // Invoke the internal back-end. 
bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -148,10 +147,7 @@ void bli_trmm_front &BLIS_ZERO, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/trmm3/bli_trmm3_front.c b/frame/3/trmm3/bli_trmm3_front.c index e9e9261f0..cf97bbcf2 100644 --- a/frame/3/trmm3/bli_trmm3_front.c +++ b/frame/3/trmm3/bli_trmm3_front.c @@ -133,13 +133,12 @@ void bli_trmm3_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_TRMM, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRMM3, side ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_TRMM3, side, cntx ); // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_gemm_int, alpha, &a_local, @@ -147,10 +146,7 @@ void bli_trmm3_front beta, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/trsm/bli_trsm_blk_var3.c b/frame/3/trsm/bli_trsm_blk_var3.c index 9d726389f..7b428c8ef 100644 --- a/frame/3/trsm/bli_trsm_blk_var3.c +++ b/frame/3/trsm/bli_trsm_blk_var3.c @@ -87,7 +87,8 @@ void bli_trsm_blk_var3 bli_thrinfo_sub_node( thread ) ); - bli_thread_ibarrier( thread ); + //bli_thread_ibarrier( thread ); + bli_thread_obarrier( bli_thrinfo_sub_node( thread ) ); // This variant executes multiple rank-k updates. Therefore, if the // internal alpha scalars on A/B and C are non-zero, we must ensure diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index b4f7422ba..78bd5eeb9 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -50,14 +50,21 @@ cntl_t* bli_trsm_l_cntl_create { void* macro_kernel_p = bli_trsm_xx_ker_var2; - // Create a node for the macro-kernel. - cntl_t* trsm_cntl_bp_ke = bli_trsm_cntl_obj_create + // Create two nodes for the macro-kernel. 
+ cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_obj_create ( - BLIS_NR, // bszid not used by macro-kernel. - macro_kernel_p, + BLIS_MR, // needed for bli_thrinfo_rgrow() + NULL, // variant function pointer not used NULL // no sub-node; this is the leaf of the tree. ); + cntl_t* trsm_cntl_bp_bu = bli_trsm_cntl_obj_create + ( + BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow() + macro_kernel_p, + trsm_cntl_bu_ke + ); + // Create a node for packing matrix A. cntl_t* trsm_cntl_packa = bli_packm_cntl_obj_create ( @@ -70,7 +77,7 @@ cntl_t* bli_trsm_l_cntl_create FALSE, // reverse iteration if lower? BLIS_PACKED_ROW_PANELS, BLIS_BUFFER_FOR_A_BLOCK, - trsm_cntl_bp_ke + trsm_cntl_bp_bu ); // Create a node for partitioning the m dimension by MC. @@ -122,14 +129,21 @@ cntl_t* bli_trsm_r_cntl_create { void* macro_kernel_p = bli_trsm_xx_ker_var2; - // Create a node for the macro-kernel. - cntl_t* trsm_cntl_bp_ke = bli_trsm_cntl_obj_create + // Create two nodes for the macro-kernel. + cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_obj_create ( - BLIS_NR, // bszid not used by macro-kernel. - macro_kernel_p, + BLIS_MR, // needed for bli_thrinfo_rgrow() + NULL, // variant function pointer not used NULL // no sub-node; this is the leaf of the tree. ); + cntl_t* trsm_cntl_bp_bu = bli_trsm_cntl_obj_create + ( + BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow() + macro_kernel_p, + trsm_cntl_bu_ke + ); + // Create a node for packing matrix A. cntl_t* trsm_cntl_packa = bli_packm_cntl_obj_create ( @@ -142,7 +156,7 @@ cntl_t* bli_trsm_r_cntl_create FALSE, // reverse iteration if lower? BLIS_PACKED_ROW_PANELS, BLIS_BUFFER_FOR_A_BLOCK, - trsm_cntl_bp_ke + trsm_cntl_bp_bu ); // Create a node for partitioning the m dimension by MC. 
diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index 3466d2d18..95c2d6aab 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -119,13 +119,12 @@ void bli_trsm_front // Set the operation family id in the context. bli_cntx_set_family( BLIS_TRSM, cntx ); - thrinfo_t** infos = bli_l3_thrinfo_create_paths( BLIS_TRSM, side ); - dim_t n_threads = bli_thread_num_threads( infos[0] ); + // Record the threading for each level within the context. + bli_cntx_set_thrloop_from_env( BLIS_TRSM, side, cntx ); // Invoke the internal back-end. bli_l3_thread_decorator ( - n_threads, bli_trsm_int, alpha, &a_local, @@ -133,10 +132,7 @@ void bli_trsm_front alpha, &c_local, cntx, - cntl, - infos + cntl ); - - bli_l3_thrinfo_free_paths( infos, n_threads ); } diff --git a/frame/3/trsm/bli_trsm_int.c b/frame/3/trsm/bli_trsm_int.c index e6614cb3f..796af7866 100644 --- a/frame/3/trsm/bli_trsm_int.c +++ b/frame/3/trsm/bli_trsm_int.c @@ -117,6 +117,9 @@ void bli_trsm_int // FGVZ->TMS: Is this barrier still needed? bli_thread_obarrier( thread ); + // Create the next node in the thrinfo_t structure. + bli_thrinfo_grow( cntx, cntl, thread ); + // Extract the function pointer from the current control tree node. 
f = bli_cntl_var_func( cntl ); diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index f2885cca3..31e995e1b 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -341,6 +341,37 @@ pack_t bli_cntx_get_pack_schema_b( cntx_t* cntx ) } #endif +dim_t bli_cntx_get_num_threads( cntx_t* cntx ) +{ + return bli_cntx_jc_way( cntx ) * + bli_cntx_pc_way( cntx ) * + bli_cntx_ic_way( cntx ) * + bli_cntx_jr_way( cntx ) * + bli_cntx_ir_way( cntx ); +} + +dim_t bli_cntx_get_num_threads_in( cntx_t* cntx, cntl_t* cntl ) +{ + dim_t n_threads_in = 1; + + for ( ; cntl != NULL; cntl = bli_cntl_sub_node( cntl ) ) + { + bszid_t bszid = bli_cntl_bszid( cntl ); + dim_t cur_way; + + // We assume bszid is in {KR,MR,NR,MC,KC,NC} if it is not + // BLIS_NO_PART. + if ( bszid != BLIS_NO_PART ) + cur_way = bli_cntx_way_for_bszid( bszid, cntx ); + else + cur_way = 1; + + n_threads_in *= cur_way; + } + + return n_threads_in; +} + // ----------------------------------------------------------------------------- #if 1 @@ -663,6 +694,96 @@ void bli_cntx_set_pack_schema_c( pack_t schema_c, bli_cntx_set_schema_c( schema_c, cntx ); } +void bli_cntx_set_thrloop_from_env( opid_t l3_op, side_t side, cntx_t* cntx ) +{ + dim_t jc, pc, ic, jr, ir; + +#ifdef BLIS_ENABLE_MULTITHREADING + jc = bli_env_read_nway( "BLIS_JC_NT" ); + //pc = bli_env_read_nway( "BLIS_KC_NT" ); + pc = 1; + ic = bli_env_read_nway( "BLIS_IC_NT" ); + jr = bli_env_read_nway( "BLIS_JR_NT" ); + ir = bli_env_read_nway( "BLIS_IR_NT" ); +#else + jc = 1; + pc = 1; + ic = 1; + jr = 1; + ir = 1; +#endif + + if ( l3_op == BLIS_TRMM ) + { + // We reconfigure the parallelism from trmm_r due to a dependency in + // the jc loop.
(NOTE: This dependency does not exist for trmm3 ) + if ( bli_is_right( side ) ) + { + bli_cntx_set_thrloop + ( + 1, + pc, + ic, + jr * jc, + ir, + cntx + ); + } + else // if ( bli_is_left( side ) ) + { + bli_cntx_set_thrloop + ( + jc, + pc, + ic, + jr, + ir, + cntx + ); + } + } + else if ( l3_op == BLIS_TRSM ) + { + if ( bli_is_right( side ) ) + { + bli_cntx_set_thrloop + ( + 1, + 1, + jc * ic * jr, + 1, + 1, + cntx + ); + } + else // if ( bli_is_left( side ) ) + { + bli_cntx_set_thrloop + ( + 1, + 1, + 1, + ic * jr * ir, + 1, + cntx + ); + } + } + else // if ( l3_op is neither BLIS_TRMM nor BLIS_TRSM ) + { + bli_cntx_set_thrloop + ( + jc, + pc, + ic, + jr, + ir, + cntx + ); + } +} + + // ----------------------------------------------------------------------------- bool_t bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt, diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index 21f9c0fe0..6aed68111 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -59,6 +59,8 @@ typedef struct cntx_s pack_t schema_b; pack_t schema_c; + dim_t thrloop[ BLIS_NUM_LOOPS ]; + membrk_t* membrk; } cntx_t; */ @@ -127,6 +129,36 @@ typedef struct cntx_s \ ( (cntx)->membrk ) +#define bli_cntx_thrloop( cntx ) \ +\ + ( (cntx)->thrloop ) + +#if 1 +#define bli_cntx_jc_way( cntx ) \ +\ + ( (cntx)->thrloop[ BLIS_NC ] ) + +#define bli_cntx_pc_way( cntx ) \ +\ + ( (cntx)->thrloop[ BLIS_KC ] ) + +#define bli_cntx_ic_way( cntx ) \ +\ + ( (cntx)->thrloop[ BLIS_MC ] ) + +#define bli_cntx_jr_way( cntx ) \ +\ + ( (cntx)->thrloop[ BLIS_NR ] ) + +#define bli_cntx_ir_way( cntx ) \ +\ + ( (cntx)->thrloop[ BLIS_MR ] ) +#endif + +#define bli_cntx_way_for_bszid( bszid, cntx ) \ +\ + ( (cntx)->thrloop[ bszid ] ) + // cntx_t modification (fields only) #define bli_cntx_set_blkszs_buf( _blkszs, cntx_p ) \ @@ -199,6 +231,16 @@ typedef struct cntx_s (cntx_p)->membrk = _membrk; \ } +#define bli_cntx_set_thrloop( jc_, pc_, ic_, jr_, ir_, cntx_p ) \ +{ \ + (cntx_p)->thrloop[ BLIS_NC ] = jc_; \ + (cntx_p)->thrloop[ BLIS_KC ] = pc_; \ +
(cntx_p)->thrloop[ BLIS_MC ] = ic_; \ + (cntx_p)->thrloop[ BLIS_NR ] = jr_; \ + (cntx_p)->thrloop[ BLIS_MR ] = ir_; \ + (cntx_p)->thrloop[ BLIS_KR ] = 1; \ +} + // cntx_t query (complex) #define bli_cntx_get_blksz_def_dt( dt, bs_id, cntx ) \ @@ -356,6 +398,8 @@ func_t* bli_cntx_get_packm_ukr( cntx_t* cntx ); //pack_t bli_cntx_get_pack_schema_a( cntx_t* cntx ); //pack_t bli_cntx_get_pack_schema_b( cntx_t* cntx ); //pack_t bli_cntx_get_pack_schema_c( cntx_t* cntx ); +dim_t bli_cntx_get_num_threads( cntx_t* cntx ); +dim_t bli_cntx_get_num_threads_in( cntx_t* cntx, cntl_t* cntl ); // set functions @@ -390,6 +434,9 @@ void bli_cntx_set_pack_schema_b( pack_t schema_b, cntx_t* cntx ); void bli_cntx_set_pack_schema_c( pack_t schema_c, cntx_t* cntx ); +void bli_cntx_set_thrloop_from_env( opid_t l3_op, + side_t side, + cntx_t* cntx ); // other query functions diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 086740cfd..726f4a700 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -638,6 +638,21 @@ typedef enum #define BLIS_NUM_UKR_IMPL_TYPES 4 +#if 0 +typedef enum +{ + BLIS_JC_IDX = 0, + BLIS_PC_IDX, + BLIS_IC_IDX, + BLIS_JR_IDX, + BLIS_IR_IDX, + BLIS_PR_IDX, +} thridx_t; +#endif + +#define BLIS_NUM_LOOPS 6 + + // -- Operation ID type -- typedef enum @@ -949,6 +964,8 @@ typedef struct cntx_s pack_t schema_b; pack_t schema_c; + dim_t thrloop[ BLIS_NUM_LOOPS ]; + membrk_t* membrk; } cntx_t; diff --git a/frame/thread/bli_thrcomm.h b/frame/thread/bli_thrcomm.h index 6b4d2de1a..593f8d7fa 100644 --- a/frame/thread/bli_thrcomm.h +++ b/frame/thread/bli_thrcomm.h @@ -41,6 +41,12 @@ #include "bli_thrcomm_openmp.h" #include "bli_thrcomm_pthreads.h" + +// thrcomm_t query (field only) + +#define bli_thrcomm_num_threads( comm ) ( (comm)->n_threads ) + + // Thread communicator prototypes. 
thrcomm_t* bli_thrcomm_create( dim_t n_threads ); void bli_thrcomm_free( thrcomm_t* communicator ); diff --git a/frame/thread/bli_thrcomm_openmp.c b/frame/thread/bli_thrcomm_openmp.c index 7c1fe69f9..68d9d7a29 100644 --- a/frame/thread/bli_thrcomm_openmp.c +++ b/frame/thread/bli_thrcomm_openmp.c @@ -201,7 +201,6 @@ void bli_thrcomm_tree_barrier( barrier_t* barack ) void bli_l3_thread_decorator ( - dim_t n_threads, l3int_t func, obj_t* alpha, obj_t* a, @@ -209,20 +208,28 @@ void bli_l3_thread_decorator obj_t* beta, obj_t* c, cntx_t* cntx, - cntl_t* cntl, - thrinfo_t** thread + cntl_t* cntl ) { + // Query the total number of threads from the context. + dim_t n_threads = bli_cntx_get_num_threads( cntx ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* gl_comm = bli_thrcomm_create( n_threads ); + _Pragma( "omp parallel num_threads(n_threads)" ) { - dim_t omp_id = omp_get_thread_num(); - thrinfo_t* thread_i = thread[omp_id]; + dim_t id = omp_get_thread_num(); cntl_t* cntl_use; + thrinfo_t* thread; // Create a default control tree for the operation, if needed. bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use ); + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread ); + func ( alpha, @@ -232,12 +239,19 @@ void bli_l3_thread_decorator c, cntx, cntl_use, - thread[omp_id] + thread ); // Free the control tree, if one was created locally. - bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i ); + bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread ); + + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( thread ); } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above).
} #endif diff --git a/frame/thread/bli_thrcomm_pthreads.c b/frame/thread/bli_thrcomm_pthreads.c index 0f2707d91..230b63905 100644 --- a/frame/thread/bli_thrcomm_pthreads.c +++ b/frame/thread/bli_thrcomm_pthreads.c @@ -136,7 +136,8 @@ typedef struct thread_data obj_t* c; cntx_t* cntx; cntl_t* cntl; - thrinfo_t* thread; + dim_t id; + thrcomm_t* gl_comm; } thread_data_t; // Entry point for additional threads @@ -151,13 +152,18 @@ void* bli_l3_thread_entry( void* data_void ) obj_t* c = data->c; cntx_t* cntx = data->cntx; cntl_t* cntl = data->cntl; - thrinfo_t* thread_i = data->thread; + dim_t id = data->id; + thrcomm_t* gl_comm = data->gl_comm; cntl_t* cntl_use; + thrinfo_t* thread; // Create a default control tree for the operation, if needed. bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use ); + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread ); + data->func ( alpha, @@ -171,14 +177,16 @@ void* bli_l3_thread_entry( void* data_void ) ); // Free the control tree, if one was created locally. - bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i ); + bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread ); + + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( thread ); return NULL; } void bli_l3_thread_decorator ( - dim_t n_threads, l3int_t func, obj_t* alpha, obj_t* a, @@ -186,50 +194,51 @@ void bli_l3_thread_decorator obj_t* beta, obj_t* c, cntx_t* cntx, - cntl_t* cntl, - thrinfo_t** thread + cntl_t* cntl ) { - pthread_t* pthreads = bli_malloc_intl( sizeof( pthread_t ) * n_threads ); - thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads ); + // Query the total number of threads from the context. + dim_t n_threads = bli_cntx_get_num_threads( cntx ); - for ( int i = 1; i < n_threads; i++ ) + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. 
+ pthread_t* pthreads = bli_malloc_intl( sizeof( pthread_t ) * n_threads ); + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* gl_comm = bli_thrcomm_create( n_threads ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t id = n_threads - 1; 0 <= id; id-- ) { // Set up thread data for additional threads (beyond thread 0). - datas[i].func = func; - datas[i].alpha = alpha; - datas[i].a = a; - datas[i].b = b; - datas[i].beta = beta; - datas[i].c = c; - datas[i].cntx = cntx; - datas[i].cntl = cntl; - datas[i].thread = thread[i]; + datas[id].func = func; + datas[id].alpha = alpha; + datas[id].a = a; + datas[id].b = b; + datas[id].beta = beta; + datas[id].c = c; + datas[id].cntx = cntx; + datas[id].cntl = cntl; + datas[id].id = id; + datas[id].gl_comm = gl_comm; - // Spawn additional threads. - pthread_create( &pthreads[i], NULL, &bli_l3_thread_entry, &datas[i] ); - } - - - // The main thread executes this. - { - cntl_t* cntl_use; - - // Create a default control tree for the operation, if needed. - bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use ); - - // Thread 0 simply executes func. - func( alpha, a, b, beta, c, cntx, cntl, thread[0] ); - - // Free the control tree, if one was created locally. - bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread[0] ); + // Spawn additional threads for ids greater than 0. + if ( id != 0 ) + pthread_create( &pthreads[id], NULL, &bli_l3_thread_entry, &datas[id] ); + else + bli_l3_thread_entry( ( void* )(&datas[0]) ); } + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). // Thread 0 waits for additional threads to finish.
- for ( int i = 1; i < n_threads; i++) + for ( dim_t id = 1; id < n_threads; id++ ) { - pthread_join( pthreads[i], NULL ); + pthread_join( pthreads[id], NULL ); } bli_free_intl( pthreads ); diff --git a/frame/thread/bli_thrcomm_single.c b/frame/thread/bli_thrcomm_single.c index 99de67220..c038f59a0 100644 --- a/frame/thread/bli_thrcomm_single.c +++ b/frame/thread/bli_thrcomm_single.c @@ -73,7 +73,6 @@ void bli_thrcomm_barrier( thrcomm_t* communicator, dim_t t_id ) void bli_l3_thread_decorator ( - dim_t n_threads, l3int_t func, obj_t* alpha, obj_t* a, @@ -81,17 +80,25 @@ void bli_l3_thread_decorator obj_t* beta, obj_t* c, cntx_t* cntx, - cntl_t* cntl, - thrinfo_t** thread + cntl_t* cntl ) { - thrinfo_t* thread_i = thread[0]; + // For sequential execution, we use only one thread. + dim_t n_threads = 1; + dim_t id = 0; + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* gl_comm = bli_thrcomm_create( n_threads ); cntl_t* cntl_use; + thrinfo_t* thread; // Create a default control tree for the operation, if needed. bli_l3_cntl_create_if( a, b, c, cntx, cntl, &cntl_use ); + // Create the root node of the thread's thrinfo_t structure. + bli_l3_thrinfo_create_root( id, gl_comm, cntx, cntl_use, &thread ); + func ( alpha, @@ -101,11 +108,18 @@ void bli_l3_thread_decorator c, cntx, cntl_use, - thread[0] + thread ); // Free the control tree, if one was created locally. - bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread_i ); + bli_l3_cntl_free_if( a, b, c, cntx, cntl, cntl_use, thread ); + + // Free the current thread's thrinfo_t structure. + bli_l3_thrinfo_free( thread ); + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above).
} diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 43f0eaf8b..d42744162 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -78,8 +78,8 @@ void bli_thread_get_range_sub dim_t* end ) { - dim_t n_way = thread->n_way; - dim_t work_id = thread->work_id; + dim_t n_way = bli_thread_n_way( thread ); + dim_t work_id = bli_thread_work_id( thread ); dim_t all_start = 0; dim_t all_end = n; @@ -511,8 +511,8 @@ siz_t bli_thread_get_range_weighted_sub dim_t* j_end_thr ) { - dim_t n_way = thread->n_way; - dim_t my_id = thread->work_id; + dim_t n_way = bli_thread_n_way( thread ); + dim_t my_id = bli_thread_work_id( thread ); dim_t bf_left = n % bf; diff --git a/frame/thread/bli_thread.h b/frame/thread/bli_thread.h index 10097c39e..5b9443587 100644 --- a/frame/thread/bli_thread.h +++ b/frame/thread/bli_thread.h @@ -173,16 +173,14 @@ typedef void (*l3int_t) // Level-3 thread decorator prototype void bli_l3_thread_decorator ( - dim_t n_threads, - l3int_t func, - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl, - thrinfo_t** thread + l3int_t func, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl ); // Miscellaneous prototypes diff --git a/frame/thread/bli_thrinfo.c b/frame/thread/bli_thrinfo.c index 4cf55b3d4..bad5c2772 100644 --- a/frame/thread/bli_thrinfo.c +++ b/frame/thread/bli_thrinfo.c @@ -38,11 +38,9 @@ thrinfo_t* bli_thrinfo_create ( thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, - bool_t free_comms, + bool_t free_comm, thrinfo_t* sub_node ) { @@ -52,9 +50,8 @@ thrinfo_t* bli_thrinfo_create ( thread, ocomm, ocomm_id, - icomm, icomm_id, n_way, work_id, - free_comms, + free_comm, sub_node ); @@ -66,23 +63,19 @@ void bli_thrinfo_init thrinfo_t* thread, thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, - bool_t free_comms, + bool_t 
free_comm, thrinfo_t* sub_node ) { - thread->ocomm = ocomm; - thread->ocomm_id = ocomm_id; - thread->icomm = icomm; - thread->icomm_id = icomm_id; - thread->n_way = n_way; - thread->work_id = work_id; - thread->free_comms = free_comms; + thread->ocomm = ocomm; + thread->ocomm_id = ocomm_id; + thread->n_way = n_way; + thread->work_id = work_id; + thread->free_comm = free_comm; - thread->sub_node = sub_node; + thread->sub_node = sub_node; } void bli_thrinfo_init_single @@ -94,7 +87,6 @@ void bli_thrinfo_init_single ( thread, &BLIS_SINGLE_COMM, 0, - &BLIS_SINGLE_COMM, 0, 1, 0, FALSE, @@ -102,3 +94,178 @@ void bli_thrinfo_init_single ); } +// ----------------------------------------------------------------------------- + +#include "assert.h" + +#define BLIS_NUM_STATIC_COMMS 18 + +thrinfo_t* bli_thrinfo_create_for_cntl + ( + cntx_t* cntx, + cntl_t* cntl_par, + cntl_t* cntl_chl, + thrinfo_t* thread_par + ) +{ + thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ]; + thrcomm_t** new_comms = NULL; + + thrinfo_t* thread_chl; + + bszid_t bszid_chl = bli_cntl_bszid( cntl_chl ); + + dim_t parent_nt_in = bli_thread_num_threads( thread_par ); + dim_t parent_n_way = bli_thread_n_way( thread_par ); + dim_t parent_comm_id = bli_thread_ocomm_id( thread_par ); + dim_t parent_work_id = bli_thread_work_id( thread_par ); + + dim_t child_nt_in; + dim_t child_comm_id; + dim_t child_n_way; + dim_t child_work_id; + + // Sanity check: make sure the number of threads in the parent's + // communicator is divisible by the number of new sub-groups. + assert( parent_nt_in % parent_n_way == 0 ); + + // Compute: + // - the number of threads inside the new child comm, + // - the current thread's id within the new communicator, + // - the current thread's work id, given the ways of parallelism + // to be obtained within the next loop. 
+ child_nt_in = bli_cntx_get_num_threads_in( cntx, cntl_chl ); + child_n_way = bli_cntx_way_for_bszid( bszid_chl, cntx ); + child_comm_id = parent_comm_id % child_nt_in; + child_work_id = child_comm_id / ( child_nt_in / child_n_way ); + + // The parent's chief thread creates a temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); + else + new_comms = static_comms; + } + + // Broadcast the temporary array to all threads in the parent's + // communicator. + new_comms = bli_thread_obroadcast( thread_par, new_comms ); + + // Chiefs in the child communicator allocate the communicator + // object and store it in the array element corresponding to the + // parent's work id. + if ( child_comm_id == 0 ) + new_comms[ parent_work_id ] = bli_thrcomm_create( child_nt_in ); + + bli_thread_obarrier( thread_par ); + + // All threads create a new thrinfo_t node using the communicator + // that was created by their chief, as identified by parent_work_id. + thread_chl = bli_thrinfo_create + ( + new_comms[ parent_work_id ], + child_comm_id, + child_n_way, + child_work_id, + TRUE, + NULL + ); + + bli_thread_obarrier( thread_par ); + + // The parent's chief thread frees the temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + bli_free_intl( new_comms ); + } + + return thread_chl; +} + +void bli_thrinfo_grow + ( + cntx_t* cntx, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + // If the sub-node of the thrinfo_t object is non-NULL, we don't + // need to create it, and will just use the existing sub-node as-is. + if ( bli_thrinfo_sub_node( thread ) != NULL ) return; + + // Create a new node (or, if needed, multiple nodes) and return the + // pointer to the (eldest) child. 
+ thrinfo_t* thread_child = bli_thrinfo_rgrow + ( + cntx, + cntl, + bli_cntl_sub_node( cntl ), + thread + ); + + // Attach the child thrinfo_t node to its parent structure. + bli_thrinfo_set_sub_node( thread_child, thread ); +} + +thrinfo_t* bli_thrinfo_rgrow + ( + cntx_t* cntx, + cntl_t* cntl_par, + cntl_t* cntl_cur, + thrinfo_t* thread_par + ) +{ + thrinfo_t* thread_cur; + + // We must handle two cases: those where the next node in the + // control tree is a partitioning node, and those where it is + // a non-partitioning (ie: packing) node. + if ( bli_cntl_bszid( cntl_cur ) != BLIS_NO_PART ) + { + // Create the child thrinfo_t node corresponding to cntl_cur, + // with cntl_par being the parent. + thread_cur = bli_thrinfo_create_for_cntl + ( + cntx, + cntl_par, + cntl_cur, + thread_par + ); + } + else // if ( bli_cntl_bszid( cntl_cur ) == BLIS_NO_PART ) + { + // Recursively grow the thread structure and return the top-most + // thrinfo_t node of that segment. + thrinfo_t* thread_seg = bli_thrinfo_rgrow + ( + cntx, + cntl_par, + bli_cntl_sub_node( cntl_cur ), + thread_par + ); + + // Create a thrinfo_t node corresponding to cntl_cur. Notice that + // the free_comm field is set to FALSE, since cntl_cur is a + // non-partitioning node. The communicator used here will be + // freed when thread_seg, or one of its descendents, is freed. + thread_cur = bli_thrinfo_create + ( + bli_thrinfo_ocomm( thread_seg ), + bli_thread_ocomm_id( thread_seg ), + bli_cntx_get_num_threads_in( cntx, cntl_cur ), + bli_thread_ocomm_id( thread_seg ), + FALSE, + thread_seg + ); + + // Attach the child thrinfo_t node to its parent structure. + bli_thrinfo_set_sub_node( thread_cur, thread_par ); + } + + return thread_cur; +} + diff --git a/frame/thread/bli_thrinfo.h b/frame/thread/bli_thrinfo.h index 9c0b28575..93bf19e50 100644 --- a/frame/thread/bli_thrinfo.h +++ b/frame/thread/bli_thrinfo.h @@ -45,13 +45,6 @@ struct thrinfo_s // Our thread id within the ocomm thread communicator. 
dim_t ocomm_id; - // The thread communicator for the other threads sharing the same work - // at this level. - thrcomm_t* icomm; - - // Our thread id within the icomm thread communicator. - dim_t icomm_id; - // The number of distinct threads used to parallelize the loop. dim_t n_way; @@ -62,7 +55,7 @@ struct thrinfo_s // this is field is true, but when nodes are created that share the same // communicators as other nodes (such as with packm nodes), this is set // to false. - bool_t free_comms; + bool_t free_comm; struct thrinfo_s* sub_node; }; @@ -71,30 +64,40 @@ typedef struct thrinfo_s thrinfo_t; // // thrinfo_t macros // NOTE: The naming of these should be made consistent at some point. +// (ie: bli_thrinfo_ vs. bli_thread_) // -#define bli_thread_num_threads( t ) ( (t)->ocomm->n_threads ) +// thrinfo_t query (field only) -#define bli_thread_n_way( t ) ( (t)->n_way ) -#define bli_thread_work_id( t ) ( (t)->work_id ) +#define bli_thread_num_threads( t ) ( (t)->ocomm->n_threads ) -#define bli_thread_am_ochief( t ) ( (t)->ocomm_id == 0 ) -#define bli_thread_am_ichief( t ) ( (t)->icomm_id == 0 ) +#define bli_thread_n_way( t ) ( (t)->n_way ) +#define bli_thread_work_id( t ) ( (t)->work_id ) +#define bli_thread_ocomm_id( t ) ( (t)->ocomm_id ) + +#define bli_thrinfo_ocomm( t ) ( (t)->ocomm ) +#define bli_thrinfo_needs_free_comm( t ) ( (t)->free_comm ) + +#define bli_thrinfo_sub_node( t ) ( (t)->sub_node ) + +// thrinfo_t query (complex) + +#define bli_thread_am_ochief( t ) ( (t)->ocomm_id == 0 ) + +// thrinfo_t modification + +#define bli_thrinfo_set_sub_node( _sub_node, thread ) \ +{ \ + (thread)->sub_node = _sub_node; \ +} + +// other thrinfo_t-related macros #define bli_thread_obroadcast( t, p ) bli_thrcomm_bcast( (t)->ocomm, \ (t)->ocomm_id, p ) -#define bli_thread_ibroadcast( t, p ) bli_thrcomm_bcast( (t)->icomm, \ - (t)->icomm_id, p ) #define bli_thread_obarrier( t ) bli_thrcomm_barrier( (t)->ocomm, \ (t)->ocomm_id ) -#define bli_thread_ibarrier( t ) 
bli_thrcomm_barrier( (t)->icomm, \ - (t)->icomm_id ) -#define bli_thrinfo_ocomm( t ) ( (t)->ocomm ) -#define bli_thrinfo_icomm( t ) ( (t)->icomm ) -#define bli_thrinfo_needs_free_comms( t ) ( (t)->free_comms ) - -#define bli_thrinfo_sub_node( t ) ( (t)->sub_node ) // // Prototypes for level-3 thrinfo functions not specific to any operation. @@ -104,11 +107,9 @@ thrinfo_t* bli_thrinfo_create ( thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, - bool_t free_comms, + bool_t free_comm, thrinfo_t* sub_node ); @@ -117,11 +118,9 @@ void bli_thrinfo_init thrinfo_t* thread, thrcomm_t* ocomm, dim_t ocomm_id, - thrcomm_t* icomm, - dim_t icomm_id, dim_t n_way, dim_t work_id, - bool_t free_comms, + bool_t free_comm, thrinfo_t* sub_node ); @@ -130,9 +129,29 @@ void bli_thrinfo_init_single thrinfo_t* thread ); -void bli_thrinfo_free +// ----------------------------------------------------------------------------- + +thrinfo_t* bli_thrinfo_create_for_cntl ( + cntx_t* cntx, + cntl_t* cntl_par, + cntl_t* cntl_chl, + thrinfo_t* thread_par + ); + +void bli_thrinfo_grow + ( + cntx_t* cntx, + cntl_t* cntl, thrinfo_t* thread ); +thrinfo_t* bli_thrinfo_rgrow + ( + cntx_t* cntx, + cntl_t* cntl_par, + cntl_t* cntl_cur, + thrinfo_t* thread_par + ); + #endif