Mirror of https://github.com/amd/blis.git (synced 2026-05-11 09:39:59 +00:00)
Some TRSM threading fixes/additions
This commit is contained in:
@@ -125,7 +125,7 @@ void bli_trmm_front( side_t side,
|
||||
if ( bli_is_left( side ) ) cntl = l_cntl;
|
||||
else cntl = r_cntl;
|
||||
|
||||
trmm_thrinfo_t** infos = bli_create_trmm_thrinfo_paths( !bli_is_left( side ) );
|
||||
trmm_thrinfo_t** infos = bli_create_trmm_thrinfo_paths( bli_is_right( side ) );
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
|
||||
@@ -83,15 +83,16 @@ void bli_trsm_blk_var2b( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
bli_get_range_weighted( thread, 0, n_trans,
|
||||
bli_determine_reg_blocksize( b, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_upper( *c ), &start, &end );
|
||||
bli_get_range( thread, 0, n_trans,
|
||||
//bli_determine_reg_blocksize( b, cntl_blocksize( cntl ) ),
|
||||
8,
|
||||
&start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
{
|
||||
// Determine the current algorithmic blocksize.
|
||||
b_alg = bli_determine_blocksize_b( i, n_trans, b,
|
||||
b_alg = bli_determine_blocksize_b( i, end, b,
|
||||
cntl_blocksize( cntl ) );
|
||||
|
||||
// Acquire partitions for B1 and C1.
|
||||
|
||||
@@ -83,15 +83,16 @@ void bli_trsm_blk_var2f( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
bli_get_range_weighted( thread, 0, n_trans,
|
||||
bli_determine_reg_blocksize( b, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_lower( *c ), &start, &end );
|
||||
bli_get_range( thread, 0, n_trans,
|
||||
//bli_determine_reg_blocksize( b, cntl_blocksize( cntl ) ),
|
||||
8,
|
||||
&start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
{
|
||||
// Determine the current algorithmic blocksize.
|
||||
b_alg = bli_determine_blocksize_f( i, n_trans, b,
|
||||
b_alg = bli_determine_blocksize_f( i, end, b,
|
||||
cntl_blocksize( cntl ) );
|
||||
|
||||
// Acquire partitions for B1 and C1.
|
||||
|
||||
@@ -125,9 +125,9 @@ void bli_trsm_front( side_t side,
|
||||
if ( bli_is_left( side ) ) cntl = l_cntl;
|
||||
else cntl = r_cntl;
|
||||
|
||||
trsm_thrinfo_t** infos = bli_create_trsm_thrinfo_paths();
|
||||
trsm_thrinfo_t** infos = bli_create_trsm_thrinfo_paths( bli_is_right( side ) );
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
|
||||
// Invoke the internal back-end.
|
||||
bli_level3_thread_decorator( n_threads,
|
||||
(level3_int_t) bli_trsm_int,
|
||||
|
||||
@@ -131,17 +131,14 @@ void bli_trsm_int( obj_t* alpha,
|
||||
// packed, this is our last chance to handle the transposition.
|
||||
if ( cntl_is_leaf( cntl ) && bli_obj_has_trans( *c ) )
|
||||
{
|
||||
if( thread_am_ochief( thread ) ) {
|
||||
bli_obj_induce_trans( c_local );
|
||||
bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, c_local );
|
||||
}
|
||||
bli_obj_induce_trans( c_local );
|
||||
bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, c_local );
|
||||
}
|
||||
|
||||
// If beta is non-unit, apply it to the scalar attached to C.
|
||||
if ( !bli_obj_equals( beta, &BLIS_ONE ) )
|
||||
{
|
||||
if( thread_am_ochief( thread ) )
|
||||
bli_obj_scalar_apply_scalar( beta, &c_local );
|
||||
bli_obj_scalar_apply_scalar( beta, &c_local );
|
||||
}
|
||||
|
||||
// Set two bools: one based on the implied side parameter (the structure
|
||||
@@ -157,8 +154,7 @@ void bli_trsm_int( obj_t* alpha,
|
||||
// attached to B (the non-triangular matrix).
|
||||
if ( !bli_obj_equals( alpha, &BLIS_ONE ) )
|
||||
{
|
||||
if( thread_am_ochief( thread ) )
|
||||
bli_obj_scalar_apply_scalar( alpha, &b_local );
|
||||
bli_obj_scalar_apply_scalar( alpha, &b_local );
|
||||
}
|
||||
}
|
||||
else // if ( bli_obj_root_is_triangular( *b ) )
|
||||
@@ -172,8 +168,7 @@ void bli_trsm_int( obj_t* alpha,
|
||||
// attached to A (the non-triangular matrix).
|
||||
if ( !bli_obj_equals( alpha, &BLIS_ONE ) )
|
||||
{
|
||||
if( thread_am_ochief( thread ) )
|
||||
bli_obj_scalar_apply_scalar( alpha, &a_local );
|
||||
bli_obj_scalar_apply_scalar( alpha, &a_local );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -107,23 +107,26 @@ void bli_trsm_thrinfo_free_paths( trsm_thrinfo_t** threads, dim_t num )
|
||||
bli_free( threads );
|
||||
}
|
||||
|
||||
trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( )
|
||||
trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( bool_t right_sided )
|
||||
{
|
||||
dim_t jc_way = 1;
|
||||
dim_t kc_way = 1;
|
||||
dim_t ic_way = 1;
|
||||
dim_t jr_way = 1;
|
||||
dim_t ir_way = 1;
|
||||
|
||||
#ifdef BLIS_ENABLE_MULTITHREADING
|
||||
dim_t jc_in = bli_read_nway_from_env( "BLIS_JC_NT" );
|
||||
/*dim_t kc_in = bli_read_nway_from_env( "BLIS_KC_NT" );*/
|
||||
dim_t ic_in = bli_read_nway_from_env( "BLIS_IC_NT" );
|
||||
dim_t jr_in = bli_read_nway_from_env( "BLIS_JR_NT" );
|
||||
dim_t ir_in = bli_read_nway_from_env( "BLIS_IR_NT" );
|
||||
|
||||
dim_t jr_way = jc_in * ic_in * jr_in * ir_in;
|
||||
#else
|
||||
dim_t jr_way = 1;
|
||||
|
||||
if(!right_sided){
|
||||
jc_way = jc_in;
|
||||
jr_way = jr_in * ic_in * ir_in;
|
||||
}
|
||||
#endif
|
||||
dim_t jc_way = 1;
|
||||
dim_t kc_way = 1;
|
||||
dim_t ic_way = 1;
|
||||
dim_t ir_way = 1;
|
||||
|
||||
dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way;
|
||||
assert( global_num_threads != 0 );
|
||||
@@ -171,12 +174,12 @@ trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( )
|
||||
NULL, NULL, ir_info);
|
||||
|
||||
packm_thrinfo_t* packb = bli_create_packm_thread_info( kc_comm, kc_comm_id,
|
||||
ic_comm, ic_comm_id,
|
||||
kc_nt, kc_comm_id );
|
||||
ic_comm, ic_comm_id,
|
||||
kc_nt, kc_comm_id );
|
||||
|
||||
packm_thrinfo_t* packa = bli_create_packm_thread_info( ic_comm, ic_comm_id,
|
||||
jr_comm, jr_comm_id,
|
||||
ic_nt, ic_comm_id );
|
||||
jr_comm, jr_comm_id,
|
||||
ic_nt, ic_comm_id );
|
||||
|
||||
trsm_thrinfo_t* ic_info = bli_create_trsm_thrinfo_node( kc_comm, kc_comm_id,
|
||||
ic_comm, ic_comm_id,
|
||||
|
||||
@@ -55,7 +55,7 @@ typedef struct trsm_thrinfo_s trsm_thrinfo_t;
|
||||
|
||||
#define trsm_my_iter( index, thread ) ( index % thread->n_way == thread->work_id % thread->n_way )
|
||||
|
||||
trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( );
|
||||
trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( bool_t right_sided );
|
||||
void bli_trsm_thrinfo_free_paths( trsm_thrinfo_t** info, dim_t n_threads );
|
||||
|
||||
void bli_setup_trsm_thrinfo_node( trsm_thrinfo_t* thread,
|
||||
|
||||
Reference in New Issue
Block a user