diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c
index d10c2daf6..753182a8f 100644
--- a/frame/3/gemm/bli_gemm_cntl.c
+++ b/frame/3/gemm/bli_gemm_cntl.c
@@ -55,11 +55,9 @@ gemm_t* gemm_cntl_vl_mm;
 
 gemm_t* gemm_cntl;
 
-dim_t gemm_caucuses_at_level[5] = {1, 1, 2, 1, 1};
-
 gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos()
 {
-	return bli_create_gemm_thrinfo_paths( gemm_caucuses_at_level, 5 );
+	return bli_create_gemm_thrinfo_paths( );
 }
 
 void bli_gemm_cntl_free_thrinfos(thrinfo_t* tofree)
diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c
index af93b6079..1c26681af 100644
--- a/frame/3/gemm/bli_gemm_front.c
+++ b/frame/3/gemm/bli_gemm_front.c
@@ -90,5 +90,7 @@ void bli_gemm_front( obj_t* alpha,
 		              cntl,
 		              &infos[omp_id] );
 	}
+
+	bli_gemm_cntl_free_thrinfos( infos );
 }
 
diff --git a/frame/3/gemm/bli_gemm_threading.c b/frame/3/gemm/bli_gemm_threading.c
index 5c0a337ff..15c3aa84b 100644
--- a/frame/3/gemm/bli_gemm_threading.c
+++ b/frame/3/gemm/bli_gemm_threading.c
@@ -84,15 +84,32 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_node( thread_comm_t* ocomm, dim_t ocomm_
 	return thread;
 }
 
-gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( dim_t* threads_at_level, dim_t n_levels )
+// Read a thread count from the environment variable named by env.
+// An unset variable, or a value that strtol() parses as less than 1
+// (including non-numeric text), yields 1.
+dim_t read_env( char* env )
 {
-	assert(n_levels == 5);
+	dim_t number = 1;
+	char* str = getenv( env );
+	if( str != NULL )
+	{
+		long value = strtol( str, NULL, 10 );
+		if( value > 0 )
+		{
+			number = value;
+		}
+	}
+	return number;
+}
 
-	dim_t jc_way = threads_at_level[0];
-	dim_t kc_way = threads_at_level[1];
-	dim_t ic_way = threads_at_level[2];
-	dim_t jr_way = threads_at_level[3];
-	dim_t ir_way = threads_at_level[4];
+// Per-loop thread counts come from BLIS_{JC,KC,IC,JR,IR}_NT; each defaults to 1.
+gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( void )
+{
+	dim_t jc_way = read_env( "BLIS_JC_NT" );
+	dim_t kc_way = read_env( "BLIS_KC_NT" );
+	dim_t ic_way = read_env( "BLIS_IC_NT" );
+	dim_t jr_way = read_env( "BLIS_JR_NT" );
+	dim_t ir_way = read_env( "BLIS_IR_NT" );
 
 	dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way;
 	assert( global_num_threads != 0 );
diff --git a/frame/3/gemm/bli_gemm_threading.h b/frame/3/gemm/bli_gemm_threading.h
index 784a4b9ef..d046608da 100644
--- a/frame/3/gemm/bli_gemm_threading.h
+++ b/frame/3/gemm/bli_gemm_threading.h
@@ -53,7 +53,7 @@ typedef struct gemm_thrinfo_s gemm_thrinfo_t;
 #define gemm_thread_sub_opackm( thread ) thread->opackm
 #define gemm_thread_sub_ipackm( thread ) thread->ipackm
 
-gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( dim_t* threads_at_level, dim_t n_levels );
+gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( void );
 
 void bli_setup_gemm_thrinfo_node( gemm_thrinfo_t* thread,
                                   thread_comm_t* ocomm, dim_t ocomm_id,
diff --git a/frame/3/hemm/bli_hemm_front.c b/frame/3/hemm/bli_hemm_front.c
index a99869dd2..4613857b8 100644
--- a/frame/3/hemm/bli_hemm_front.c
+++ b/frame/3/hemm/bli_hemm_front.c
@@ -80,13 +80,24 @@ void bli_hemm_front( side_t side,
 		bli_obj_swap( a_local, b_local );
 	}
 
-	// Invoke the internal back-end.
-	bli_gemm_int( alpha,
-	              &a_local,
-	              &b_local,
-	              beta,
-	              &c_local,
-	              cntl,
-	              &BLIS_GEMM_SINGLE_THREADED );
+	gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
+	dim_t n_threads = thread_num_threads( (&infos[0]) );
+
+	// Invoke the internal back-end.
+	_Pragma( "omp parallel num_threads(n_threads)" )
+	{
+		dim_t omp_id = omp_get_thread_num();
+
+		// Invoke the internal back-end.
+		bli_gemm_int( alpha,
+		              &a_local,
+		              &b_local,
+		              beta,
+		              &c_local,
+		              cntl,
+		              &infos[omp_id] );
+	}
+
+	bli_gemm_cntl_free_thrinfos( infos );
 }
 
diff --git a/frame/3/symm/bli_symm_front.c b/frame/3/symm/bli_symm_front.c
index 5043f1355..abc7930a3 100644
--- a/frame/3/symm/bli_symm_front.c
+++ b/frame/3/symm/bli_symm_front.c
@@ -79,13 +79,24 @@ void bli_symm_front( side_t side,
 		bli_obj_swap( a_local, b_local );
 	}
 
+	gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
+	dim_t n_threads = thread_num_threads( (&infos[0]) );
+
 	// Invoke the internal back-end.
-	bli_gemm_int( alpha,
-	              &a_local,
-	              &b_local,
-	              beta,
-	              &c_local,
-	              cntl,
-	              &BLIS_GEMM_SINGLE_THREADED );
+	_Pragma( "omp parallel num_threads(n_threads)" )
+	{
+		dim_t omp_id = omp_get_thread_num();
+
+
+		bli_gemm_int( alpha,
+		              &a_local,
+		              &b_local,
+		              beta,
+		              &c_local,
+		              cntl,
+		              &infos[omp_id] );
+	}
+
+	bli_gemm_cntl_free_thrinfos( infos );
 }
 