mirror of
https://github.com/amd/blis.git
synced 2026-05-25 02:44:31 +00:00
Added decorator for calling parallelized internal functions
Will allow for easy support for different threading models
This commit is contained in:
@@ -55,11 +55,6 @@ gemm_t* gemm_cntl_vl_mm;
|
||||
|
||||
gemm_t* gemm_cntl;
|
||||
|
||||
gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos()
|
||||
{
|
||||
return bli_create_gemm_thrinfo_paths( );
|
||||
}
|
||||
|
||||
void bli_gemm_cntl_init()
|
||||
{
|
||||
// Create blocksize objects for each dimension.
|
||||
|
||||
@@ -65,4 +65,3 @@ gemm_t* bli_gemm_cntl_obj_create( impl_t impl_type,
|
||||
gemm_t* sub_gemm,
|
||||
unpackm_t* sub_unpack_c );
|
||||
|
||||
gemm_thrinfo_t* bli_gemm_cntl_get_thrinfos();
|
||||
|
||||
@@ -74,22 +74,19 @@ void bli_gemm_front( obj_t* alpha,
|
||||
bli_obj_induce_trans( c_local );
|
||||
}
|
||||
|
||||
gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
|
||||
dim_t n_threads = thread_num_threads( (&infos[0]) );
|
||||
gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
dim_t omp_id = omp_get_thread_num();
|
||||
|
||||
bli_gemm_int( alpha,
|
||||
&a_local,
|
||||
&b_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&infos[omp_id] );
|
||||
}
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&b_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_gemm_thrinfo_free_paths( infos );
|
||||
}
|
||||
|
||||
@@ -95,11 +95,11 @@ dim_t read_env( char* env )
|
||||
return number;
|
||||
}
|
||||
|
||||
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t* threads )
|
||||
// Stub: the body is empty, so nothing reachable from `threads` is released.
// NOTE(review): bli_create_gemm_thrinfo_paths() malloc()s the pointer array
// it returns and fills it with separately created nodes; confirm whether this
// routine is meant to free them, since callers invoke it after every operation.
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t** threads )
|
||||
{
|
||||
}
|
||||
|
||||
gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( )
|
||||
gemm_thrinfo_t** bli_create_gemm_thrinfo_paths( )
|
||||
{
|
||||
dim_t jc_way = read_env( "BLIS_JC_NT" );
|
||||
dim_t kc_way = read_env( "BLIS_KC_NT" );
|
||||
@@ -117,7 +117,7 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( )
|
||||
dim_t ir_nt = 1;
|
||||
|
||||
|
||||
gemm_thrinfo_t* paths = (gemm_thrinfo_t*) malloc( global_num_threads * sizeof( gemm_thrinfo_t ) );
|
||||
gemm_thrinfo_t** paths = (gemm_thrinfo_t**) malloc( global_num_threads * sizeof( gemm_thrinfo_t* ) );
|
||||
|
||||
thread_comm_t* global_comm = bli_create_communicator( global_num_threads );
|
||||
for( int a = 0; a < jc_way; a++ )
|
||||
@@ -170,11 +170,11 @@ gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( )
|
||||
kc_way, b,
|
||||
NULL, NULL, ic_info);
|
||||
|
||||
gemm_thrinfo_t* jc_info = &paths[global_comm_id];
|
||||
bli_setup_gemm_thrinfo_node( jc_info, global_comm, global_comm_id,
|
||||
jc_comm, jc_comm_id,
|
||||
jc_way, a,
|
||||
NULL, NULL, kc_info);
|
||||
gemm_thrinfo_t* jc_info = bli_create_gemm_thrinfo_node( global_comm, global_comm_id,
|
||||
jc_comm, jc_comm_id,
|
||||
jc_way, a,
|
||||
NULL, NULL, kc_info);
|
||||
paths[global_comm_id] = jc_info;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,8 +53,8 @@ typedef struct gemm_thrinfo_s gemm_thrinfo_t;
|
||||
#define gemm_thread_sub_opackm( thread ) thread->opackm
|
||||
#define gemm_thread_sub_ipackm( thread ) thread->ipackm
|
||||
|
||||
gemm_thrinfo_t* bli_create_gemm_thrinfo_paths( );
|
||||
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t* );
|
||||
gemm_thrinfo_t** bli_create_gemm_thrinfo_paths( );
|
||||
void bli_gemm_thrinfo_free_paths( gemm_thrinfo_t** );
|
||||
|
||||
void bli_setup_gemm_thrinfo_node( gemm_thrinfo_t* thread,
|
||||
thread_comm_t* ocomm, dim_t ocomm_id,
|
||||
|
||||
@@ -80,23 +80,19 @@ void bli_hemm_front( side_t side,
|
||||
bli_obj_swap( a_local, b_local );
|
||||
}
|
||||
|
||||
gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
|
||||
dim_t n_threads = thread_num_threads( (&infos[0]) );
|
||||
gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
dim_t omp_id = omp_get_thread_num();
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_gemm_int( alpha,
|
||||
&a_local,
|
||||
&b_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&infos[omp_id] );
|
||||
}
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&b_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_gemm_thrinfo_free_paths( infos );
|
||||
}
|
||||
|
||||
@@ -109,22 +109,34 @@ void bli_her2k_front( obj_t* alpha,
|
||||
&c_local,
|
||||
cntl );
|
||||
#else
|
||||
// Invoke herk twice, using beta only the first time.
|
||||
bli_herk_int( alpha,
|
||||
&a_local,
|
||||
&bh_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&BLIS_HERK_SINGLE_THREADED );
|
||||
|
||||
bli_herk_int( &alpha_conj,
|
||||
&b_local,
|
||||
&ah_local,
|
||||
&BLIS_ONE,
|
||||
&c_local,
|
||||
cntl,
|
||||
&BLIS_HERK_SINGLE_THREADED );
|
||||
// Invoke herk twice, using beta only the first time.
|
||||
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_herk_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&bh_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_herk_int,
|
||||
&alpha_conj,
|
||||
&b_local,
|
||||
&ah_local,
|
||||
&BLIS_ONE,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_herk_thrinfo_free_paths( infos );
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -77,24 +77,20 @@ void bli_herk_front( obj_t* alpha,
|
||||
bli_obj_induce_trans( c_local );
|
||||
}
|
||||
|
||||
herk_thrinfo_t* infos = bli_create_herk_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( (&infos[0]) );
|
||||
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
dim_t omp_id = omp_get_thread_num();
|
||||
// Invoke the internal back-end.
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_herk_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&ah_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
|
||||
bli_herk_int( alpha,
|
||||
&a_local,
|
||||
&ah_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&infos[omp_id] );
|
||||
}
|
||||
|
||||
bli_herk_thrinfo_free_paths( infos );
|
||||
}
|
||||
|
||||
|
||||
@@ -84,11 +84,11 @@ herk_thrinfo_t* bli_create_herk_thrinfo_node( thread_comm_t* ocomm, dim_t ocomm_
|
||||
return thread;
|
||||
}
|
||||
|
||||
void bli_herk_thrinfo_free_paths( herk_thrinfo_t* threads )
|
||||
// Stub: the body is empty, so nothing reachable from `threads` is released.
// NOTE(review): bli_create_herk_thrinfo_paths() malloc()s the pointer array
// it returns and fills it with separately created nodes; confirm whether this
// routine is meant to free them, since callers invoke it after every operation.
void bli_herk_thrinfo_free_paths( herk_thrinfo_t** threads )
|
||||
{
|
||||
}
|
||||
|
||||
herk_thrinfo_t* bli_create_herk_thrinfo_paths( )
|
||||
herk_thrinfo_t** bli_create_herk_thrinfo_paths( )
|
||||
{
|
||||
dim_t jc_way = read_env( "BLIS_JC_NT" );
|
||||
dim_t kc_way = read_env( "BLIS_KC_NT" );
|
||||
@@ -106,7 +106,7 @@ herk_thrinfo_t* bli_create_herk_thrinfo_paths( )
|
||||
dim_t ir_nt = 1;
|
||||
|
||||
|
||||
herk_thrinfo_t* paths = (herk_thrinfo_t*) malloc( global_num_threads * sizeof( herk_thrinfo_t ) );
|
||||
herk_thrinfo_t** paths = (herk_thrinfo_t**) malloc( global_num_threads * sizeof( herk_thrinfo_t* ) );
|
||||
|
||||
thread_comm_t* global_comm = bli_create_communicator( global_num_threads );
|
||||
for( int a = 0; a < jc_way; a++ )
|
||||
@@ -159,11 +159,12 @@ herk_thrinfo_t* bli_create_herk_thrinfo_paths( )
|
||||
kc_way, b,
|
||||
NULL, NULL, ic_info);
|
||||
|
||||
herk_thrinfo_t* jc_info = &paths[global_comm_id];
|
||||
bli_setup_herk_thrinfo_node( jc_info, global_comm, global_comm_id,
|
||||
jc_comm, jc_comm_id,
|
||||
jc_way, a,
|
||||
NULL, NULL, kc_info);
|
||||
herk_thrinfo_t* jc_info = bli_create_herk_thrinfo_node( global_comm, global_comm_id,
|
||||
jc_comm, jc_comm_id,
|
||||
jc_way, a,
|
||||
NULL, NULL, kc_info);
|
||||
|
||||
paths[global_comm_id] = jc_info;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,8 +53,8 @@ typedef struct herk_thrinfo_s herk_thrinfo_t;
|
||||
#define herk_thread_sub_opackm( thread ) thread->opackm
|
||||
#define herk_thread_sub_ipackm( thread ) thread->ipackm
|
||||
|
||||
herk_thrinfo_t* bli_herk_create_thrinfo_paths( );
|
||||
void bli_herk_thrinfo_free_paths();
|
||||
herk_thrinfo_t** bli_create_herk_thrinfo_paths( );
|
||||
void bli_herk_thrinfo_free_paths( herk_thrinfo_t** paths );
|
||||
|
||||
void bli_setup_herk_thrinfo_node( herk_thrinfo_t* thread,
|
||||
thread_comm_t* ocomm, dim_t ocomm_id,
|
||||
|
||||
@@ -79,23 +79,19 @@ void bli_symm_front( side_t side,
|
||||
bli_obj_swap( a_local, b_local );
|
||||
}
|
||||
|
||||
gemm_thrinfo_t* infos = bli_gemm_cntl_get_thrinfos();
|
||||
dim_t n_threads = thread_num_threads( (&infos[0]) );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
dim_t omp_id = omp_get_thread_num();
|
||||
|
||||
|
||||
bli_gemm_int( alpha,
|
||||
&a_local,
|
||||
&b_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&infos[omp_id] );
|
||||
}
|
||||
gemm_thrinfo_t** infos = bli_create_gemm_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_gemm_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&b_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_gemm_thrinfo_free_paths( infos );
|
||||
}
|
||||
|
||||
@@ -93,21 +93,31 @@ void bli_syr2k_front( obj_t* alpha,
|
||||
cntl );
|
||||
#else
|
||||
// Invoke herk twice, using beta only the first time.
|
||||
bli_herk_int( alpha,
|
||||
&a_local,
|
||||
&bt_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&BLIS_HERK_SINGLE_THREADED );
|
||||
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
bli_herk_int( alpha,
|
||||
&b_local,
|
||||
&at_local,
|
||||
&BLIS_ONE,
|
||||
&c_local,
|
||||
cntl,
|
||||
&BLIS_HERK_SINGLE_THREADED );
|
||||
// Invoke the internal back-end.
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_herk_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&bt_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_herk_int,
|
||||
alpha,
|
||||
&b_local,
|
||||
&at_local,
|
||||
&BLIS_ONE,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_herk_thrinfo_free_paths( infos );
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -72,14 +72,21 @@ void bli_syrk_front( obj_t* alpha,
|
||||
{
|
||||
bli_obj_induce_trans( c_local );
|
||||
}
|
||||
|
||||
herk_thrinfo_t** infos = bli_create_herk_thrinfo_paths();
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_herk_int( alpha,
|
||||
&a_local,
|
||||
&at_local,
|
||||
beta,
|
||||
&c_local,
|
||||
cntl,
|
||||
&BLIS_HERK_SINGLE_THREADED );
|
||||
// Invoke the internal back-end.
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t*) bli_herk_int,
|
||||
alpha,
|
||||
&a_local,
|
||||
&at_local,
|
||||
beta,
|
||||
&c_local,
|
||||
(void*) cntl,
|
||||
(void**) infos );
|
||||
|
||||
bli_herk_thrinfo_free_paths( infos );
|
||||
}
|
||||
|
||||
|
||||
@@ -95,10 +95,14 @@ void bli_barrier( thread_comm_t* communicator, dim_t t_id )
|
||||
bool_t my_sense = communicator->barrier_sense;
|
||||
dim_t my_threads_arrived;
|
||||
|
||||
_Pragma("omp atomic capture")
|
||||
my_threads_arrived = communicator->barrier_threads_arrived++;
|
||||
/*
|
||||
bli_set_lock(&communicator->barrier_lock);
|
||||
my_threads_arrived = communicator->barrier_threads_arrived + 1;
|
||||
communicator->barrier_threads_arrived = my_threads_arrived;
|
||||
bli_unset_lock(&communicator->barrier_lock);
|
||||
*/
|
||||
|
||||
if( my_threads_arrived == communicator->n_threads ) {
|
||||
|
||||
@@ -223,3 +227,31 @@ void bli_get_range( void* thr, dim_t size, dim_t block_factor, dim_t* start, dim
|
||||
*start = work_id * n_pt;
|
||||
*end = bli_min( *start + n_pt, size );
|
||||
}
|
||||
|
||||
// Placeholder with an empty body: *start and *end are left unmodified, so
// callers must not rely on them being set by this routine.
// NOTE(review): presumably intended as a triangular-aware counterpart of
// bli_get_range(); the partitioning logic is not implemented here.
void bli_get_range_tri_weighted( void* thr, dim_t size, dim_t block_factor, bool_t forward, dim_t* start, dim_t* end)
|
||||
{
|
||||
}
|
||||
|
||||
void bli_level3_thread_decorator( dim_t n_threads,
|
||||
level3_int_t* func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
void* cntl,
|
||||
void** thread )
|
||||
{
|
||||
_Pragma( "omp parallel num_threads(n_threads)" )
|
||||
{
|
||||
dim_t omp_id = omp_get_thread_num();
|
||||
|
||||
(*func) ( alpha,
|
||||
a,
|
||||
b,
|
||||
beta,
|
||||
c,
|
||||
cntl,
|
||||
thread[omp_id] );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,4 +99,15 @@ void bli_setup_thread_info( thrinfo_t* thread, thread_comm_t* ocomm, dim_t ocomm
|
||||
#include "bli_gemm_threading.h"
|
||||
#include "bli_herk_threading.h"
|
||||
|
||||
typedef void (*level3_int_t) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, void* cntl, void* thread );
|
||||
void bli_level3_thread_decorator( dim_t num_threads,
|
||||
level3_int_t* func,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* beta,
|
||||
obj_t* c,
|
||||
void* cntl,
|
||||
void** thread );
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user