Added decorator for calling parallelized internal functions

Will allow for easy support for different threading models
This commit is contained in:
Tyler Smith
2014-03-18 13:26:27 -05:00
parent 5296f58975
commit 0ac534cdf6
15 changed files with 177 additions and 125 deletions

View File

@@ -55,11 +55,6 @@ gemm_t* gemm_cntl_vl_mm;
gemm_t* gemm_cntl;
// Thin pass-through: returns whatever bli_create_gemm_thrinfo_paths()
// produces, with no logic of its own.
// NOTE(review): bli_create_gemm_thrinfo_paths() also appears in this text
// with a gemm_thrinfo_t** return type -- confirm that the return type here
// still matches the function it forwards to.
gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos()
{
return bli_create_gemm_thrinfo_paths( );
}
void bli_gemm_cntl_init()
{
// Create blocksize objects for each dimension.

View File

@@ -65,4 +65,3 @@ gemm_t* bli_gemm_cntl_obj_create( impl_t impl_type,
gemm_t* sub_gemm,
unpackm_t* sub_unpack_c );
gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos();

View File

@@ -74,22 +74,19 @@ void bli_gemm_front( obj_t* alpha,
bli_obj_induce_trans( c_local );
}
gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
dim_t n_threads = thread_num_threads( (&infos[0]) );
gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
// Invoke the internal back-end.
_Pragma( "omp parallel num_threads(n_threads)" )
{
dim_t omp_id = omp_get_thread_num();
bli_gemm_int( alpha,
&a_local,
&b_local,
beta,
&c_local,
cntl,
&infos[omp_id] );
}
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_gemm_int,
alpha,
&a_local,
&b_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_gemm_thrinfo_free_paths( infos );
}

View File

@@ -95,11 +95,11 @@ dim_t read_env( char* env )
return number;
}
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t* threads )
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t** threads )
{
}
gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( )
gemm_thrinfo_t** bli_create_gemm_thrinfo_paths( )
{
dim_t jc_way = read_env( "BLIS_JC_NT" );
dim_t kc_way = read_env( "BLIS_KC_NT" );
@@ -117,7 +117,7 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( )
dim_t ir_nt = 1;
gemm_thrinfo_t* paths = (gemm_thrinfo_t*) malloc( global_num_threads * sizeof( gemm_thrinfo_t ) );
gemm_thrinfo_t** paths = (gemm_thrinfo_t**) malloc( global_num_threads * sizeof( gemm_thrinfo_t* ) );
thread_comm_t* global_comm = bli_create_communicator( global_num_threads );
for( int a = 0; a < jc_way; a++ )
@@ -170,11 +170,11 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( )
kc_way, b,
NULL, NULL, ic_info);
gemm_thrinfo_t* jc_info = &paths[global_comm_id];
bli_setup_gemm_thrinfo_node( jc_info, global_comm, global_comm_id,
jc_comm, jc_comm_id,
jc_way, a,
NULL, NULL, kc_info);
gemm_thrinfo_t* jc_info = bli_create_gemm_thrinfo_node( global_comm, global_comm_id,
jc_comm, jc_comm_id,
jc_way, a,
NULL, NULL, kc_info);
paths[global_comm_id] = jc_info;
}
}
}

View File

@@ -53,8 +53,8 @@ typedef struct gemm_thrinfo_s gemm_thrinfo_t;
#define gemm_thread_sub_opackm( thread ) thread->opackm
#define gemm_thread_sub_ipackm( thread ) thread->ipackm
gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( );
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t* );
gemm_thrinfo_t** bli_create_gemm_thrinfo_paths( );
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t** );
void bli_setup_gemm_thrinfo_node( gemm_thrinfo_t* thread,
thread_comm_t* ocomm, dim_t ocomm_id,

View File

@@ -80,23 +80,19 @@ void bli_hemm_front( side_t side,
bli_obj_swap( a_local, b_local );
}
gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
dim_t n_threads = thread_num_threads( (&infos[0]) );
gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
// Invoke the internal back-end.
_Pragma( "omp parallel num_threads(n_threads)" )
{
dim_t omp_id = omp_get_thread_num();
// Invoke the internal back-end.
bli_gemm_int( alpha,
&a_local,
&b_local,
beta,
&c_local,
cntl,
&infos[omp_id] );
}
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_gemm_int,
alpha,
&a_local,
&b_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_gemm_thrinfo_free_paths( infos );
}

View File

@@ -109,22 +109,34 @@ void bli_her2k_front( obj_t* alpha,
&c_local,
cntl );
#else
// Invoke herk twice, using beta only the first time.
bli_herk_int( alpha,
&a_local,
&bh_local,
beta,
&c_local,
cntl,
&BLIS_HERK_SINGLE_THREADED );
bli_herk_int( &alpha_conj,
&b_local,
&ah_local,
&BLIS_ONE,
&c_local,
cntl,
&BLIS_HERK_SINGLE_THREADED );
// Invoke herk twice, using beta only the first time.
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
// Invoke the internal back-end.
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_herk_int,
alpha,
&a_local,
&bh_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_herk_int,
&alpha_conj,
&b_local,
&ah_local,
&BLIS_ONE,
&c_local,
(void*) cntl,
(void**) infos );
bli_herk_thrinfo_free_paths( infos );
#endif
}

View File

@@ -77,24 +77,20 @@ void bli_herk_front( obj_t* alpha,
bli_obj_induce_trans( c_local );
}
herk_thrinfo_t* infos = bli_create_herk_thrinfo_paths();
dim_t n_threads = thread_num_threads( (&infos[0]) );
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
// Invoke the internal back-end.
_Pragma( "omp parallel num_threads(n_threads)" )
{
dim_t omp_id = omp_get_thread_num();
// Invoke the internal back-end.
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_herk_int,
alpha,
&a_local,
&ah_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_herk_int( alpha,
&a_local,
&ah_local,
beta,
&c_local,
cntl,
&infos[omp_id] );
}
bli_herk_thrinfo_free_paths( infos );
}

View File

@@ -84,11 +84,11 @@ herk_thrinfo_t* bli_create_herk_thrinfo_node( thread_comm_t* ocomm, dim_t ocomm_
return thread;
}
void bli_herk_thrinfo_free_paths( herk_thrinfo_t* threads )
void bli_herk_thrinfo_free_paths( herk_thrinfo_t** threads )
{
}
herk_thrinfo_t* bli_create_herk_thrinfo_paths( )
herk_thrinfo_t** bli_create_herk_thrinfo_paths( )
{
dim_t jc_way = read_env( "BLIS_JC_NT" );
dim_t kc_way = read_env( "BLIS_KC_NT" );
@@ -106,7 +106,7 @@ herk_thrinfo_t* bli_create_herk_thrinfo_paths( )
dim_t ir_nt = 1;
herk_thrinfo_t* paths = (herk_thrinfo_t*) malloc( global_num_threads * sizeof( herk_thrinfo_t ) );
herk_thrinfo_t** paths = (herk_thrinfo_t**) malloc( global_num_threads * sizeof( herk_thrinfo_t* ) );
thread_comm_t* global_comm = bli_create_communicator( global_num_threads );
for( int a = 0; a < jc_way; a++ )
@@ -159,11 +159,12 @@ herk_thrinfo_t* bli_create_herk_thrinfo_paths( )
kc_way, b,
NULL, NULL, ic_info);
herk_thrinfo_t* jc_info = &paths[global_comm_id];
bli_setup_herk_thrinfo_node( jc_info, global_comm, global_comm_id,
jc_comm, jc_comm_id,
jc_way, a,
NULL, NULL, kc_info);
herk_thrinfo_t* jc_info = bli_create_herk_thrinfo_node( global_comm, global_comm_id,
jc_comm, jc_comm_id,
jc_way, a,
NULL, NULL, kc_info);
paths[global_comm_id] = jc_info;
}
}
}

View File

@@ -53,8 +53,8 @@ typedef struct herk_thrinfo_s herk_thrinfo_t;
#define herk_thread_sub_opackm( thread ) thread->opackm
#define herk_thread_sub_ipackm( thread ) thread->ipackm
herk_thrinfo_t* bli_herk_create_thrinfo_paths( );
void bli_herk_thrinfo_free_paths();
herk_thrinfo_t** bli_create_herk_thrinfo_paths( );
void bli_herk_thrinfo_free_paths( herk_thrinfo_t** paths );
void bli_setup_herk_thrinfo_node( herk_thrinfo_t* thread,
thread_comm_t* ocomm, dim_t ocomm_id,

View File

@@ -79,23 +79,19 @@ void bli_symm_front( side_t side,
bli_obj_swap( a_local, b_local );
}
gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
dim_t n_threads = thread_num_threads( (&infos[0]) );
// Invoke the internal back-end.
_Pragma( "omp parallel num_threads(n_threads)" )
{
dim_t omp_id = omp_get_thread_num();
bli_gemm_int( alpha,
&a_local,
&b_local,
beta,
&c_local,
cntl,
&infos[omp_id] );
}
gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
// Invoke the internal back-end.
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_gemm_int,
alpha,
&a_local,
&b_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_gemm_thrinfo_free_paths( infos );
}

View File

@@ -93,21 +93,31 @@ void bli_syr2k_front( obj_t* alpha,
cntl );
#else
// Invoke herk twice, using beta only the first time.
bli_herk_int( alpha,
&a_local,
&bt_local,
beta,
&c_local,
cntl,
&BLIS_HERK_SINGLE_THREADED );
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
bli_herk_int( alpha,
&b_local,
&at_local,
&BLIS_ONE,
&c_local,
cntl,
&BLIS_HERK_SINGLE_THREADED );
// Invoke the internal back-end.
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_herk_int,
alpha,
&a_local,
&bt_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_herk_int,
alpha,
&b_local,
&at_local,
&BLIS_ONE,
&c_local,
(void*) cntl,
(void**) infos );
bli_herk_thrinfo_free_paths( infos );
#endif
}

View File

@@ -72,14 +72,21 @@ void bli_syrk_front( obj_t* alpha,
{
bli_obj_induce_trans( c_local );
}
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
dim_t n_threads = thread_num_threads( infos[0] );
// Invoke the internal back-end.
bli_herk_int( alpha,
&a_local,
&at_local,
beta,
&c_local,
cntl,
&BLIS_HERK_SINGLE_THREADED );
// Invoke the internal back-end.
bli_level3_thread_decorator( n_threads,
(level3_int_t*) bli_herk_int,
alpha,
&a_local,
&at_local,
beta,
&c_local,
(void*) cntl,
(void**) infos );
bli_herk_thrinfo_free_paths( infos );
}

View File

@@ -95,10 +95,14 @@ void bli_barrier( thread_comm_t* communicator, dim_t t_id )
bool_t my_sense = communicator->barrier_sense;
dim_t my_threads_arrived;
_Pragma("omp atomic capture")
my_threads_arrived = communicator->barrier_threads_arrived++;
/*
bli_set_lock(&communicator->barrier_lock);
my_threads_arrived = communicator->barrier_threads_arrived + 1;
communicator->barrier_threads_arrived = my_threads_arrived;
bli_unset_lock(&communicator->barrier_lock);
*/
if( my_threads_arrived == communicator->n_threads ) {
@@ -223,3 +227,31 @@ void bli_get_range( void* thr, dim_t size, dim_t block_factor, dim_t* start, dim
*start = work_id * n_pt;
*end = bli_min( *start + n_pt, size );
}
// Placeholder with an empty body: none of the arguments are read and nothing
// is written to *start or *end, so callers must not rely on those outputs
// being set after this returns.
// Presumably intended as a weighted counterpart of bli_get_range for
// triangular operands, with `forward` selecting the traversal direction --
// TODO confirm the intended semantics and implement.
void bli_get_range_tri_weighted( void* thr, dim_t size, dim_t block_factor, bool_t forward, dim_t* start, dim_t* end)
{
}
// Runs a level-3 internal back-end once on every thread of an OpenMP
// parallel region.
//
// n_threads  - number of threads requested for the parallel region.
// func       - address of the back-end to run. Callers pass the function
//              itself cast to level3_int_t* (e.g. (level3_int_t*) bli_gemm_int).
// alpha, a, b, beta, c
//            - operands forwarded unchanged to every invocation of func.
// cntl       - control object forwarded unchanged to every invocation.
// thread     - array of per-thread info pointers, indexed by the OpenMP
//              thread number; entry i is handed to the call made by thread i.
void bli_level3_thread_decorator( dim_t n_threads,
                                  level3_int_t* func,
                                  obj_t* alpha,
                                  obj_t* a,
                                  obj_t* b,
                                  obj_t* beta,
                                  obj_t* c,
                                  void* cntl,
                                  void** thread )
{
    // level3_int_t is already a pointer-to-function type, and func holds the
    // address of the back-end itself rather than the address of a variable
    // storing that pointer. Dereferencing func would therefore read the
    // function's machine code as if it were a stored function pointer.
    // Convert the value back to the function-pointer type and call through it.
    level3_int_t kernel = (level3_int_t) func;

    _Pragma( "omp parallel num_threads(n_threads)" )
    {
        // Each thread picks its own info entry by its OpenMP thread number.
        dim_t omp_id = omp_get_thread_num();

        kernel( alpha,
                a,
                b,
                beta,
                c,
                cntl,
                thread[omp_id] );
    }
}

View File

@@ -99,4 +99,15 @@ void bli_setup_thread_info( thrinfo_t* thread, thread_comm_t* ocomm, dim_t ocomm
#include "bli_gemm_threading.h"
#include "bli_herk_threading.h"
// Pointer-to-function type for the level-3 internal back-ends that are
// dispatched through bli_level3_thread_decorator. The control object and the
// per-thread info are passed as untyped pointers.
typedef void (*level3_int_t) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, void* cntl, void* thread );
// Invokes func with the given operands and control object, handing each call
// one entry of the `thread` array.
// NOTE(review): func is declared as level3_int_t*, i.e. a pointer to a
// function pointer, while callers pass a function address cast to that type
// (e.g. (level3_int_t*) bli_gemm_int) -- confirm whether the extra level of
// indirection in this parameter is intended.
void bli_level3_thread_decorator( dim_t num_threads,
level3_int_t* func,
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
void* cntl,
void** thread );
#endif