// PATCH OVERVIEW (review preamble; text ahead of the first "diff --git" header
// is not part of any hunk).
//
// Purpose: route the level-3 front ends through one shared threading helper,
// bli_level3_thread_decorator(), instead of each front end opening its own
// "omp parallel" region.
//
// What the hunks below do:
//   * bli_gemm_cntl.c/.h: remove the bli_gemm_cntl_get_thrinfos() wrapper and
//     its prototype; callers now call bli_create_gemm_thrinfo_paths() directly.
//   * bli_gemm_threading.c/.h, bli_herk_threading.c/.h: the create/free "paths"
//     routines switch from an array of thrinfo structs (T*) to an array of
//     pointers (T**). The array is malloc'ed with sizeof(T*) per entry and each
//     jc-level node returned by bli_create_*_thrinfo_node() is stored in
//     paths[global_comm_id]. The herk header prototype is also renamed from
//     bli_herk_create_thrinfo_paths to bli_create_herk_thrinfo_paths.
//   * gemm/hemm/symm/herk front ends: the inline _Pragma("omp parallel ...")
//     block is replaced by a bli_level3_thread_decorator() call.
//   * her2k/syr2k/syrk front ends: calls that passed &BLIS_HERK_SINGLE_THREADED
//     now build herk thrinfo paths and go through the decorator; her2k and
//     syr2k call the decorator twice with the same infos array.
//   * frame/base/bli_threading.c/.h: bli_barrier's lock-protected counter
//     update is commented out in favour of "omp atomic capture"; an empty
//     bli_get_range_tri_weighted() stub is added; the level3_int_t typedef and
//     bli_level3_thread_decorator() are added. The decorator opens the
//     parallel region with num_threads(n_threads) and calls func with
//     thread[omp_get_thread_num()].
//
// NOTE(review): bli_barrier -- the removed code assigned
//   barrier_threads_arrived + 1 to my_threads_arrived, i.e. the count AFTER
//   this thread arrived. The new statement uses post-increment, so it captures
//   the value BEFORE the increment. The following test
//   "my_threads_arrived == communicator->n_threads" then looks unreachable
//   (assuming the counter starts at 0 -- confirm); "++x" was presumably meant.
// NOTE(review): level3_int_t is already a pointer-to-function type, yet the
//   decorator takes "level3_int_t* func", calls "(*func)( ... )", and callers
//   pass "(level3_int_t*) bli_gemm_int". The dereference appears to load a
//   function pointer from the function's own address instead of calling it;
//   presumably the parameter should be "level3_int_t func". Confirm.
// NOTE(review): bli_gemm_thrinfo_free_paths() and bli_herk_thrinfo_free_paths()
//   have empty bodies in these hunks, so the malloc'ed pointer array is not
//   released here. TODO: confirm whether the nodes are freed elsewhere.
// NOTE(review): the malloc() result in both create routines is used unchecked.
// NOTE(review): the decorator indexes thread[] by omp_get_thread_num() without
//   checking how many threads the runtime actually provided -- confirm that
//   n_threads is always honoured.
// NOTE(review): line breaks and indentation of this patch text are collapsed;
//   the @@ hunk counts refer to the original line layout.
diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index fd6f92c14..2fccb5fc7 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -55,11 +55,6 @@ gemm_t* gemm_cntl_vl_mm; gemm_t* gemm_cntl; -gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos() -{ - return bli_create_gemm_thrinfo_paths( ); -} - void bli_gemm_cntl_init() { // Create blocksize objects for each dimension. diff --git a/frame/3/gemm/bli_gemm_cntl.h b/frame/3/gemm/bli_gemm_cntl.h index 136f89ef5..882b746eb 100644 --- a/frame/3/gemm/bli_gemm_cntl.h +++ b/frame/3/gemm/bli_gemm_cntl.h @@ -65,4 +65,3 @@ gemm_t* bli_gemm_cntl_obj_create( impl_t impl_type, gemm_t* sub_gemm, unpackm_t* sub_unpack_c ); -gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos(); diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 88bc32d9a..a17a600b5 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -74,22 +74,19 @@ void bli_gemm_front( obj_t* alpha, bli_obj_induce_trans( c_local ); } - gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos(); - dim_t n_threads = thread_num_threads( (&infos[0]) ); + gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); // Invoke the internal back-end. 
- _Pragma( "omp parallel num_threads(n_threads)" ) - { - dim_t omp_id = omp_get_thread_num(); - - bli_gemm_int( alpha, - &a_local, - &b_local, - beta, - &c_local, - cntl, - &infos[omp_id] ); - } + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_gemm_int, + alpha, + &a_local, + &b_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); bli_gemm_thrinfo_free_paths( infos ); } diff --git a/frame/3/gemm/bli_gemm_threading.c b/frame/3/gemm/bli_gemm_threading.c index 627df7f9a..6d2ec5f1b 100644 --- a/frame/3/gemm/bli_gemm_threading.c +++ b/frame/3/gemm/bli_gemm_threading.c @@ -95,11 +95,11 @@ dim_t read_env( char* env ) return number; } -void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t* threads ) +void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t** threads ) { } -gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( ) +gemm_thrinfo_t** bli_create_gemm_thrinfo_paths( ) { dim_t jc_way = read_env( "BLIS_JC_NT" ); dim_t kc_way = read_env( "BLIS_KC_NT" ); @@ -117,7 +117,7 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( ) dim_t ir_nt = 1; - gemm_thrinfo_t* paths = (gemm_thrinfo_t*) malloc( global_num_threads * sizeof( gemm_thrinfo_t ) ); + gemm_thrinfo_t** paths = (gemm_thrinfo_t**) malloc( global_num_threads * sizeof( gemm_thrinfo_t* ) ); thread_comm_t* global_comm = bli_create_communicator( global_num_threads ); for( int a = 0; a < jc_way; a++ ) @@ -170,11 +170,11 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( ) kc_way, b, NULL, NULL, ic_info); - gemm_thrinfo_t* jc_info = &paths[global_comm_id]; - bli_setup_gemm_thrinfo_node( jc_info, global_comm, global_comm_id, - jc_comm, jc_comm_id, - jc_way, a, - NULL, NULL, kc_info); + gemm_thrinfo_t* jc_info = bli_create_gemm_thrinfo_node( global_comm, global_comm_id, + jc_comm, jc_comm_id, + jc_way, a, + NULL, NULL, kc_info); + paths[global_comm_id] = jc_info; } } } diff --git a/frame/3/gemm/bli_gemm_threading.h b/frame/3/gemm/bli_gemm_threading.h index 280ba96ad..54a8f4884 100644 --- 
a/frame/3/gemm/bli_gemm_threading.h +++ b/frame/3/gemm/bli_gemm_threading.h @@ -53,8 +53,8 @@ typedef struct gemm_thrinfo_s gemm_thrinfo_t; #define gemm_thread_sub_opackm( thread ) thread->opackm #define gemm_thread_sub_ipackm( thread ) thread->ipackm -gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( ); -void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t* ); +gemm_thrinfo_t** bli_create_gemm_thrinfo_paths( ); +void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t** ); void bli_setup_gemm_thrinfo_node( gemm_thrinfo_t* thread, thread_comm_t* ocomm, dim_t ocomm_id, diff --git a/frame/3/hemm/bli_hemm_front.c b/frame/3/hemm/bli_hemm_front.c index fde8f9f70..9d1a7ea5c 100644 --- a/frame/3/hemm/bli_hemm_front.c +++ b/frame/3/hemm/bli_hemm_front.c @@ -80,23 +80,19 @@ void bli_hemm_front( side_t side, bli_obj_swap( a_local, b_local ); } - gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos(); - dim_t n_threads = thread_num_threads( (&infos[0]) ); + gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); // Invoke the internal back-end. - _Pragma( "omp parallel num_threads(n_threads)" ) - { - dim_t omp_id = omp_get_thread_num(); - - // Invoke the internal back-end. - bli_gemm_int( alpha, - &a_local, - &b_local, - beta, - &c_local, - cntl, - &infos[omp_id] ); - } + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_gemm_int, + alpha, + &a_local, + &b_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); bli_gemm_thrinfo_free_paths( infos ); } diff --git a/frame/3/her2k/bli_her2k_front.c b/frame/3/her2k/bli_her2k_front.c index 1097c338c..6d019fe57 100644 --- a/frame/3/her2k/bli_her2k_front.c +++ b/frame/3/her2k/bli_her2k_front.c @@ -109,22 +109,34 @@ void bli_her2k_front( obj_t* alpha, &c_local, cntl ); #else - // Invoke herk twice, using beta only the first time. 
- bli_herk_int( alpha, - &a_local, - &bh_local, - beta, - &c_local, - cntl, - &BLIS_HERK_SINGLE_THREADED ); - bli_herk_int( &alpha_conj, - &b_local, - &ah_local, - &BLIS_ONE, - &c_local, - cntl, - &BLIS_HERK_SINGLE_THREADED ); + // Invoke herk twice, using beta only the first time. + herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); + + // Invoke the internal back-end. + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_herk_int, + alpha, + &a_local, + &bh_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); + + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_herk_int, + &alpha_conj, + &b_local, + &ah_local, + &BLIS_ONE, + &c_local, + (void*) cntl, + (void**) infos ); + + bli_herk_thrinfo_free_paths( infos ); + #endif } diff --git a/frame/3/herk/bli_herk_front.c b/frame/3/herk/bli_herk_front.c index 19a033a57..33c36fd3b 100644 --- a/frame/3/herk/bli_herk_front.c +++ b/frame/3/herk/bli_herk_front.c @@ -77,24 +77,20 @@ void bli_herk_front( obj_t* alpha, bli_obj_induce_trans( c_local ); } - herk_thrinfo_t* infos = bli_create_herk_thrinfo_paths(); - dim_t n_threads = thread_num_threads( (&infos[0]) ); + herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); - // Invoke the internal back-end. - _Pragma( "omp parallel num_threads(n_threads)" ) - { - dim_t omp_id = omp_get_thread_num(); + // Invoke the internal back-end. 
+ bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_herk_int, + alpha, + &a_local, + &ah_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); - - bli_herk_int( alpha, - &a_local, - &ah_local, - beta, - &c_local, - cntl, - &infos[omp_id] ); - } - bli_herk_thrinfo_free_paths( infos ); } diff --git a/frame/3/herk/bli_herk_threading.c b/frame/3/herk/bli_herk_threading.c index ec6c9d31c..942014883 100644 --- a/frame/3/herk/bli_herk_threading.c +++ b/frame/3/herk/bli_herk_threading.c @@ -84,11 +84,11 @@ herk_thrinfo_t* bli_create_herk_thrinfo_node( thread_comm_t* ocomm, dim_t ocomm_ return thread; } -void bli_herk_thrinfo_free_paths( herk_thrinfo_t* threads ) +void bli_herk_thrinfo_free_paths( herk_thrinfo_t** threads ) { } -herk_thrinfo_t* bli_create_herk_thrinfo_paths( ) +herk_thrinfo_t** bli_create_herk_thrinfo_paths( ) { dim_t jc_way = read_env( "BLIS_JC_NT" ); dim_t kc_way = read_env( "BLIS_KC_NT" ); @@ -106,7 +106,7 @@ herk_thrinfo_t* bli_create_herk_thrinfo_paths( ) dim_t ir_nt = 1; - herk_thrinfo_t* paths = (herk_thrinfo_t*) malloc( global_num_threads * sizeof( herk_thrinfo_t ) ); + herk_thrinfo_t** paths = (herk_thrinfo_t**) malloc( global_num_threads * sizeof( herk_thrinfo_t* ) ); thread_comm_t* global_comm = bli_create_communicator( global_num_threads ); for( int a = 0; a < jc_way; a++ ) @@ -159,11 +159,12 @@ herk_thrinfo_t* bli_create_herk_thrinfo_paths( ) kc_way, b, NULL, NULL, ic_info); - herk_thrinfo_t* jc_info = &paths[global_comm_id]; - bli_setup_herk_thrinfo_node( jc_info, global_comm, global_comm_id, - jc_comm, jc_comm_id, - jc_way, a, - NULL, NULL, kc_info); + herk_thrinfo_t* jc_info = bli_create_herk_thrinfo_node( global_comm, global_comm_id, + jc_comm, jc_comm_id, + jc_way, a, + NULL, NULL, kc_info); + + paths[global_comm_id] = jc_info; } } } diff --git a/frame/3/herk/bli_herk_threading.h b/frame/3/herk/bli_herk_threading.h index f0e206cc7..05e038aab 100644 --- a/frame/3/herk/bli_herk_threading.h +++ 
b/frame/3/herk/bli_herk_threading.h @@ -53,8 +53,8 @@ typedef struct herk_thrinfo_s herk_thrinfo_t; #define herk_thread_sub_opackm( thread ) thread->opackm #define herk_thread_sub_ipackm( thread ) thread->ipackm -herk_thrinfo_t* bli_herk_create_thrinfo_paths( ); -void bli_herk_thrinfo_free_paths(); +herk_thrinfo_t** bli_create_herk_thrinfo_paths( ); +void bli_herk_thrinfo_free_paths( herk_thrinfo_t** paths ); void bli_setup_herk_thrinfo_node( herk_thrinfo_t* thread, thread_comm_t* ocomm, dim_t ocomm_id, diff --git a/frame/3/symm/bli_symm_front.c b/frame/3/symm/bli_symm_front.c index 99c628c88..cce25b4c8 100644 --- a/frame/3/symm/bli_symm_front.c +++ b/frame/3/symm/bli_symm_front.c @@ -79,23 +79,19 @@ void bli_symm_front( side_t side, bli_obj_swap( a_local, b_local ); } - gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos(); - dim_t n_threads = thread_num_threads( (&infos[0]) ); - - // Invoke the internal back-end. - _Pragma( "omp parallel num_threads(n_threads)" ) - { - dim_t omp_id = omp_get_thread_num(); - - - bli_gemm_int( alpha, - &a_local, - &b_local, - beta, - &c_local, - cntl, - &infos[omp_id] ); - } + gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); + + // Invoke the internal back-end. + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_gemm_int, + alpha, + &a_local, + &b_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); bli_gemm_thrinfo_free_paths( infos ); } diff --git a/frame/3/syr2k/bli_syr2k_front.c b/frame/3/syr2k/bli_syr2k_front.c index ab2d0d700..fb5d4f0f6 100644 --- a/frame/3/syr2k/bli_syr2k_front.c +++ b/frame/3/syr2k/bli_syr2k_front.c @@ -93,21 +93,31 @@ void bli_syr2k_front( obj_t* alpha, cntl ); #else // Invoke herk twice, using beta only the first time. 
- bli_herk_int( alpha, - &a_local, - &bt_local, - beta, - &c_local, - cntl, - &BLIS_HERK_SINGLE_THREADED ); + herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); - bli_herk_int( alpha, - &b_local, - &at_local, - &BLIS_ONE, - &c_local, - cntl, - &BLIS_HERK_SINGLE_THREADED ); + // Invoke the internal back-end. + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_herk_int, + alpha, + &a_local, + &bt_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); + + bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_herk_int, + alpha, + &b_local, + &at_local, + &BLIS_ONE, + &c_local, + (void*) cntl, + (void**) infos ); + + bli_herk_thrinfo_free_paths( infos ); #endif } diff --git a/frame/3/syrk/bli_syrk_front.c b/frame/3/syrk/bli_syrk_front.c index 9022c9442..d9039cdb0 100644 --- a/frame/3/syrk/bli_syrk_front.c +++ b/frame/3/syrk/bli_syrk_front.c @@ -72,14 +72,21 @@ void bli_syrk_front( obj_t* alpha, { bli_obj_induce_trans( c_local ); } + + herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths(); + dim_t n_threads = thread_num_threads( infos[0] ); - // Invoke the internal back-end. - bli_herk_int( alpha, - &a_local, - &at_local, - beta, - &c_local, - cntl, - &BLIS_HERK_SINGLE_THREADED ); + // Invoke the internal back-end. 
+ bli_level3_thread_decorator( n_threads, + (level3_int_t*) bli_herk_int, + alpha, + &a_local, + &at_local, + beta, + &c_local, + (void*) cntl, + (void**) infos ); + + bli_herk_thrinfo_free_paths( infos ); } diff --git a/frame/base/bli_threading.c b/frame/base/bli_threading.c index f830ebc2d..53405bd96 100644 --- a/frame/base/bli_threading.c +++ b/frame/base/bli_threading.c @@ -95,10 +95,14 @@ void bli_barrier( thread_comm_t* communicator, dim_t t_id ) bool_t my_sense = communicator->barrier_sense; dim_t my_threads_arrived; + _Pragma("omp atomic capture") + my_threads_arrived = communicator->barrier_threads_arrived++; +/* bli_set_lock(&communicator->barrier_lock); my_threads_arrived = communicator->barrier_threads_arrived + 1; communicator->barrier_threads_arrived = my_threads_arrived; bli_unset_lock(&communicator->barrier_lock); +*/ if( my_threads_arrived == communicator->n_threads ) { @@ -223,3 +227,31 @@ void bli_get_range( void* thr, dim_t size, dim_t block_factor, dim_t* start, dim *start = work_id * n_pt; *end = bli_min( *start + n_pt, size ); } + +void bli_get_range_tri_weighted( void* thr, dim_t size, dim_t block_factor, bool_t forward, dim_t* start, dim_t* end) +{ +} + +void bli_level3_thread_decorator( dim_t n_threads, + level3_int_t* func, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + void* cntl, + void** thread ) +{ + _Pragma( "omp parallel num_threads(n_threads)" ) + { + dim_t omp_id = omp_get_thread_num(); + + (*func) ( alpha, + a, + b, + beta, + c, + cntl, + thread[omp_id] ); + } +} diff --git a/frame/base/bli_threading.h b/frame/base/bli_threading.h index b944457b5..fdd3ae32a 100644 --- a/frame/base/bli_threading.h +++ b/frame/base/bli_threading.h @@ -99,4 +99,15 @@ void bli_setup_thread_info( thrinfo_t* thread, thread_comm_t* ocomm, dim_t ocomm #include "bli_gemm_threading.h" #include "bli_herk_threading.h" +typedef void (*level3_int_t) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, void* cntl, void* thread ); 
+void bli_level3_thread_decorator( dim_t num_threads, + level3_int_t* func, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + void* cntl, + void** thread ); + #endif