Added support for IC loop parallelism to trsm.

Details:
- Parallelism within the IC loop (3rd loop around the microkernel) is
  now supported within the trsm operation. This is done via a new branch
  on each of the control and thread trees, which guide execution of a
  new trsm-only subproblem from within bli_trsm_blk_var1(). This trsm
  subproblem corresponds to the macrokernel computation on only the
  block of A that contains the diagonal (labeled as A11 in algorithms
  with FLAME-like partitioning), and the corresponding row panel of C.
  During the trsm subproblem, all threads within the JC communicator
  participate and parallelize along the JR loop, including any
  parallelism that was specified for the IC loop. (IR loop parallelism
  is not supported for trsm due to inter-iteration dependencies.) After
  this trsm subproblem is complete, a barrier synchronizes all
  participating threads and then they proceed to apply the prescribed
  BLIS_IC_NT (or equivalent) ways of parallelism (and any BLIS_JR_NT
  parallelism specified within) to the remaining gemm subproblem (the
  rank-k update that is performed using the newly updated row-panel of
  B). Thus, trsm now supports JC, IC, and JR loop parallelism.
- Modified bli_trsm_l_cntl_create() to create the new "prenode" branch
  of the trsm_l cntl_t tree. The trsm_r tree was left unchanged, for
  now, since it is not currently used. (All trsm problems are cast in
  terms of left-side trsm.)
- Updated bli_cntl_free_w_thrinfo() to be able to free the newly shaped
  trsm cntl_t trees. Fixed a potential latent bug whereby a cntl_t
  subnode was only recursed upon if a corresponding thrinfo_t node
  existed, which may not always be the case (for problems too small
  to employ full parallelization due to the minimum granularity imposed
  by micropanels).
- Updated other functions in frame/base/bli_cntl.c, such as
  bli_cntl_copy() and bli_cntl_mark_family(), to recurse on sub-prenodes
  if they exist.
- Updated bli_thrinfo_free() to recurse into sub-nodes and prenodes
  when they exist, and added support for growing a prenode branch to
  bli_thrinfo_grow() via a corresponding set of helper functions named
  with the _prenode() suffix.
- Added a bszid_t field to thrinfo_t nodes. This field comes in handy when
  debugging the allocation/release of thrinfo_t nodes, as it helps trace
  the "identity" of each node as it is created/destroyed.
- Renamed
    bli_l3_thrinfo_print_paths() -> bli_l3_thrinfo_print_gemm_paths()
  and created a separate bli_l3_thrinfo_print_trsm_paths() function to
  print out the newly reconfigured thrinfo_t trees for the trsm
  operation.
- Trivial changes to bli_gemm_blk_var?.c and bli_trsm_blk_var?.c
  regarding variable declarations.
- Removed subpart_t enum values BLIS_SUBPART1T, BLIS_SUBPART1B,
  BLIS_SUBPART1L, BLIS_SUBPART1R. Then added support for two new labels
  (semantically speaking): BLIS_SUBPART1A and BLIS_SUBPART1B, which
  represent the subpartition ahead of and behind, respectively,
  BLIS_SUBPART1. Updated check functions in bli_check.c accordingly.
- Shuffled layering/APIs for bli_acquire_mpart_[mn]dim() and
  bli_acquire_mpart_t2b/b2t(), _l2r/r2l().
- Deprecated old functions in frame/3/bli_l3_thrinfo.c.
This commit is contained in:
Field G. Van Zee
2019-02-14 18:52:45 -06:00
parent 78bc0bc8b6
commit 075143dfd9
23 changed files with 1253 additions and 404 deletions

View File

@@ -42,6 +42,7 @@ void bli_packm_thrinfo_init
dim_t ocomm_id,
dim_t n_way,
dim_t work_id,
bszid_t bszid,
thrinfo_t* sub_node
)
{
@@ -51,6 +52,7 @@ void bli_packm_thrinfo_init
ocomm, ocomm_id,
n_way, work_id,
FALSE,
BLIS_NO_PART,
sub_node
);
}
@@ -66,6 +68,7 @@ void bli_packm_thrinfo_init_single
&BLIS_SINGLE_COMM, 0,
1,
0,
BLIS_NO_PART,
NULL
);
}

View File

@@ -87,6 +87,7 @@ void bli_packm_thrinfo_init
dim_t ocomm_id,
dim_t n_way,
dim_t work_id,
bszid_t bszid,
thrinfo_t* sub_node
);

View File

@@ -36,48 +36,6 @@
#include "blis.h"
#include "assert.h"
#if 0
thrinfo_t* bli_l3_thrinfo_create
(
thrcomm_t* ocomm,
dim_t ocomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
)
{
return bli_thrinfo_create
(
ocomm, ocomm_id,
n_way,
work_id,
TRUE,
sub_node
);
}
#endif
void bli_l3_thrinfo_init
(
thrinfo_t* thread,
thrcomm_t* ocomm,
dim_t ocomm_id,
dim_t n_way,
dim_t work_id,
thrinfo_t* sub_node
)
{
bli_thrinfo_init
(
thread,
ocomm, ocomm_id,
n_way,
work_id,
TRUE,
sub_node
);
}
void bli_l3_thrinfo_init_single
(
thrinfo_t* thread
@@ -129,13 +87,14 @@ void bli_l3_thrinfo_create_root
xx_way,
work_id,
TRUE,
bszid,
NULL
);
}
// -----------------------------------------------------------------------------
void bli_l3_thrinfo_print_paths
void bli_l3_thrinfo_print_gemm_paths
(
thrinfo_t** threads
)
@@ -159,26 +118,23 @@ void bli_l3_thrinfo_print_paths
dim_t jr_way = bli_thread_n_way( jr_info );
dim_t ir_way = bli_thread_n_way( ir_info );
dim_t gl_nt = bli_thread_num_threads( jc_info );
dim_t jc_nt = bli_thread_num_threads( pc_info );
dim_t pc_nt = bli_thread_num_threads( pb_info );
dim_t pb_nt = bli_thread_num_threads( ic_info );
dim_t ic_nt = bli_thread_num_threads( pa_info );
dim_t pa_nt = bli_thread_num_threads( jr_info );
dim_t jr_nt = bli_thread_num_threads( ir_info );
dim_t jc_nt = bli_thread_num_threads( jc_info );
dim_t pc_nt = bli_thread_num_threads( pc_info );
dim_t pb_nt = bli_thread_num_threads( pb_info );
dim_t ic_nt = bli_thread_num_threads( ic_info );
dim_t pa_nt = bli_thread_num_threads( pa_info );
dim_t jr_nt = bli_thread_num_threads( jr_info );
dim_t ir_nt = bli_thread_num_threads( ir_info );
printf( " gl jc kc pb ic pa jr ir\n" );
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
( unsigned long )gl_nt,
printf( " jc kc pb ic pa jr ir\n" );
printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
( unsigned long )jc_nt,
( unsigned long )pc_nt,
( unsigned long )pb_nt,
( unsigned long )ic_nt,
( unsigned long )pa_nt,
( unsigned long )jr_nt,
( unsigned long )1 );
printf( "\n" );
printf( " jc kc pb ic pa jr ir\n" );
( unsigned long )ir_nt );
printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n",
( unsigned long )jc_way,
( unsigned long )pc_way,
@@ -187,15 +143,15 @@ void bli_l3_thrinfo_print_paths
( unsigned long )pa_way,
( unsigned long )jr_way,
( unsigned long )ir_way );
printf( "=================================================\n" );
printf( "============================================\n" );
dim_t gl_comm_id;
dim_t jc_comm_id;
dim_t pc_comm_id;
dim_t pb_comm_id;
dim_t ic_comm_id;
dim_t pa_comm_id;
dim_t jr_comm_id;
dim_t ir_comm_id;
dim_t jc_work_id;
dim_t pc_work_id;
@@ -216,78 +172,78 @@ void bli_l3_thrinfo_print_paths
// width, MR or NR).
if ( !jc_info )
{
gl_comm_id = jc_comm_id = pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = -1;
jc_comm_id = pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1;
jc_work_id = pc_work_id = pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1;
}
else
{
gl_comm_id = bli_thread_ocomm_id( jc_info );
jc_comm_id = bli_thread_ocomm_id( jc_info );
jc_work_id = bli_thread_work_id( jc_info );
pc_info = bli_thrinfo_sub_node( jc_info );
if ( !pc_info )
{
jc_comm_id = pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = -1;
pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1;
pc_work_id = pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1;
}
else
{
jc_comm_id = bli_thread_ocomm_id( pc_info );
pc_comm_id = bli_thread_ocomm_id( pc_info );
pc_work_id = bli_thread_work_id( pc_info );
pb_info = bli_thrinfo_sub_node( pc_info );
if ( !pb_info )
{
pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = -1;
pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1;
pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1;
}
else
{
pc_comm_id = bli_thread_ocomm_id( pb_info );
pb_comm_id = bli_thread_ocomm_id( pb_info );
pb_work_id = bli_thread_work_id( pb_info );
ic_info = bli_thrinfo_sub_node( pb_info );
if ( !ic_info )
{
pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = -1;
ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1;
ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1;
}
else
{
pb_comm_id = bli_thread_ocomm_id( ic_info );
ic_comm_id = bli_thread_ocomm_id( ic_info );
ic_work_id = bli_thread_work_id( ic_info );
pa_info = bli_thrinfo_sub_node( ic_info );
if ( !pa_info )
{
ic_comm_id = pa_comm_id = jr_comm_id = -1;
pa_comm_id = jr_comm_id = ir_comm_id = -1;
pa_work_id = jr_work_id = ir_work_id = -1;
}
else
{
ic_comm_id = bli_thread_ocomm_id( pa_info );
pa_comm_id = bli_thread_ocomm_id( pa_info );
pa_work_id = bli_thread_work_id( pa_info );
jr_info = bli_thrinfo_sub_node( pa_info );
if ( !jr_info )
{
pa_comm_id = jr_comm_id = -1;
jr_comm_id = ir_comm_id = -1;
jr_work_id = ir_work_id = -1;
}
else
{
pa_comm_id = bli_thread_ocomm_id( jr_info );
jr_comm_id = bli_thread_ocomm_id( jr_info );
jr_work_id = bli_thread_work_id( jr_info );
ir_info = bli_thrinfo_sub_node( jr_info );
if ( !ir_info )
{
jr_comm_id = -1;
ir_comm_id = -1;
ir_work_id = -1;
}
else
{
jr_comm_id = bli_thread_ocomm_id( ir_info );
ir_comm_id = bli_thread_ocomm_id( ir_info );
ir_work_id = bli_thread_work_id( ir_info );
}
}
@@ -297,15 +253,16 @@ void bli_l3_thrinfo_print_paths
}
}
printf( " gl jc pb kc pa ic jr \n" );
//printf( " gl jc pb kc pa ic jr \n" );
//printf( " gl jc kc pb ic pa jr \n" );
printf( "comm ids: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n",
( long )gl_comm_id,
( long )jc_comm_id,
( long )pc_comm_id,
( long )pb_comm_id,
( long )ic_comm_id,
( long )pa_comm_id,
( long )jr_comm_id );
( long )jr_comm_id,
( long )ir_comm_id );
printf( "work ids: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n",
( long )jc_work_id,
( long )pc_work_id,
@@ -314,7 +271,241 @@ void bli_l3_thrinfo_print_paths
( long )pa_work_id,
( long )jr_work_id,
( long )ir_work_id );
printf( "---------------------------------------\n" );
printf( "--------------------------------------------\n" );
}
}
// -----------------------------------------------------------------------------
// -----------------------------------------------------------------------------
// -----------------------------------------------------------------------------
// Prints, for debugging purposes, the shape of the thrinfo_t trees used by
// the trsm operation. A header reports the number of threads (xx_nt) and
// the n_way value (xx_way) of each level as read from thread 0's tree; it
// is followed by one record per thread listing that thread's communicator
// id and work id at each level.
//
// For the levels below the IC node (pa, jr, ir) two values are printed as
// "prenode|main": the first is read from the branch reached through
// bli_thrinfo_sub_prenode() and the second from the branch reached through
// bli_thrinfo_sub_node(). Any node that does not exist is reported as -1.
//
// threads: array of per-thread root (JC) thrinfo_t pointers. threads[0]
//   must be non-NULL because it is queried for the total thread count.
//   Every other entry, and any descendant node of any entry, may be NULL.
void bli_l3_thrinfo_print_trsm_paths
     (
       thrinfo_t** threads
     )
{
	// Thread 0's root node supplies the total number of threads to report.
	thrinfo_t* jc_info   = threads[0];
	dim_t      n_threads = bli_thread_num_threads( jc_info );

	// Walk thread 0's tree. Neither the main branch nor the prenode branch
	// is guaranteed to be fully built (e.g. when the problem is too small),
	// so every step below the root is guarded against NULL.
	thrinfo_t* pc_info  = bli_thrinfo_sub_node( jc_info );
	thrinfo_t* pb_info  = ( pc_info  ? bli_thrinfo_sub_node( pc_info )    : NULL );
	thrinfo_t* ic_info  = ( pb_info  ? bli_thrinfo_sub_node( pb_info )    : NULL );

	thrinfo_t* pa_info  = ( ic_info  ? bli_thrinfo_sub_node( ic_info )    : NULL );
	thrinfo_t* jr_info  = ( pa_info  ? bli_thrinfo_sub_node( pa_info )    : NULL );
	thrinfo_t* ir_info  = ( jr_info  ? bli_thrinfo_sub_node( jr_info )    : NULL );

	thrinfo_t* pa_info0 = ( ic_info  ? bli_thrinfo_sub_prenode( ic_info ) : NULL );
	thrinfo_t* jr_info0 = ( pa_info0 ? bli_thrinfo_sub_node( pa_info0 )   : NULL );
	thrinfo_t* ir_info0 = ( jr_info0 ? bli_thrinfo_sub_node( jr_info0 )   : NULL );

	// n_way value at each level (-1 if the node is absent).
	dim_t jc_way  = bli_thread_n_way( jc_info );
	dim_t pc_way  = ( pc_info  ? bli_thread_n_way( pc_info )  : -1 );
	dim_t pb_way  = ( pb_info  ? bli_thread_n_way( pb_info )  : -1 );
	dim_t ic_way  = ( ic_info  ? bli_thread_n_way( ic_info )  : -1 );
	dim_t pa_way  = ( pa_info  ? bli_thread_n_way( pa_info )  : -1 );
	dim_t jr_way  = ( jr_info  ? bli_thread_n_way( jr_info )  : -1 );
	dim_t ir_way  = ( ir_info  ? bli_thread_n_way( ir_info )  : -1 );
	dim_t pa_way0 = ( pa_info0 ? bli_thread_n_way( pa_info0 ) : -1 );
	dim_t jr_way0 = ( jr_info0 ? bli_thread_n_way( jr_info0 ) : -1 );
	dim_t ir_way0 = ( ir_info0 ? bli_thread_n_way( ir_info0 ) : -1 );

	// Number of threads at each level (-1 if the node is absent).
	dim_t jc_nt  = n_threads;
	dim_t pc_nt  = ( pc_info  ? bli_thread_num_threads( pc_info )  : -1 );
	dim_t pb_nt  = ( pb_info  ? bli_thread_num_threads( pb_info )  : -1 );
	dim_t ic_nt  = ( ic_info  ? bli_thread_num_threads( ic_info )  : -1 );
	dim_t pa_nt  = ( pa_info  ? bli_thread_num_threads( pa_info )  : -1 );
	dim_t jr_nt  = ( jr_info  ? bli_thread_num_threads( jr_info )  : -1 );
	dim_t ir_nt  = ( ir_info  ? bli_thread_num_threads( ir_info )  : -1 );
	dim_t pa_nt0 = ( pa_info0 ? bli_thread_num_threads( pa_info0 ) : -1 );
	dim_t jr_nt0 = ( jr_info0 ? bli_thread_num_threads( jr_info0 ) : -1 );
	dim_t ir_nt0 = ( ir_info0 ? bli_thread_num_threads( ir_info0 ) : -1 );

	printf( " jc kc pb ic pa jr ir\n" );
	printf( "xx_nt: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n",
	( long )jc_nt,
	( long )pc_nt,
	( long )pb_nt,
	( long )ic_nt,
	( long )pa_nt0, ( long )pa_nt,
	( long )jr_nt0, ( long )jr_nt,
	( long )ir_nt0, ( long )ir_nt );
	printf( "xx_way: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n",
	( long )jc_way,
	( long )pc_way,
	( long )pb_way,
	( long )ic_way,
	( long )pa_way0, ( long )pa_way,
	( long )jr_way0, ( long )jr_way,
	( long )ir_way0, ( long )ir_way );
	printf( "==================================================\n" );

	for ( dim_t gl_id = 0; gl_id < n_threads; ++gl_id )
	{
		// NOTE: We must check each thrinfo_t pointer for NULLness. Certain
		// threads may not fully build their thrinfo_t structures--specifically
		// when the dimension being parallelized is not large enough for each
		// thread to have even one unit of work (where a unit is usually a
		// single micropanel's width, MR or NR).
		//
		// Every node is resolved up front, and a NULL parent yields NULL
		// children. This way all ids -- including the prenode ids -- are
		// always assigned before they are printed, instead of being left
		// uninitialized (or stale from the previous thread) whenever an
		// ancestor node is missing.
		thrinfo_t* jc  = threads[gl_id];
		thrinfo_t* pc  = ( jc  ? bli_thrinfo_sub_node( jc )    : NULL );
		thrinfo_t* pb  = ( pc  ? bli_thrinfo_sub_node( pc )    : NULL );
		thrinfo_t* ic  = ( pb  ? bli_thrinfo_sub_node( pb )    : NULL );

		// Prenode branch.
		thrinfo_t* pa0 = ( ic  ? bli_thrinfo_sub_prenode( ic ) : NULL );
		thrinfo_t* jr0 = ( pa0 ? bli_thrinfo_sub_node( pa0 )   : NULL );
		thrinfo_t* ir0 = ( jr0 ? bli_thrinfo_sub_node( jr0 )   : NULL );

		// Main branch.
		thrinfo_t* pa  = ( ic  ? bli_thrinfo_sub_node( ic )    : NULL );
		thrinfo_t* jr  = ( pa  ? bli_thrinfo_sub_node( pa )    : NULL );
		thrinfo_t* ir  = ( jr  ? bli_thrinfo_sub_node( jr )    : NULL );

		dim_t jc_comm_id  = ( jc  ? bli_thread_ocomm_id( jc )  : -1 );
		dim_t pc_comm_id  = ( pc  ? bli_thread_ocomm_id( pc )  : -1 );
		dim_t pb_comm_id  = ( pb  ? bli_thread_ocomm_id( pb )  : -1 );
		dim_t ic_comm_id  = ( ic  ? bli_thread_ocomm_id( ic )  : -1 );
		dim_t pa_comm_id0 = ( pa0 ? bli_thread_ocomm_id( pa0 ) : -1 );
		dim_t jr_comm_id0 = ( jr0 ? bli_thread_ocomm_id( jr0 ) : -1 );
		dim_t ir_comm_id0 = ( ir0 ? bli_thread_ocomm_id( ir0 ) : -1 );
		dim_t pa_comm_id  = ( pa  ? bli_thread_ocomm_id( pa )  : -1 );
		dim_t jr_comm_id  = ( jr  ? bli_thread_ocomm_id( jr )  : -1 );
		dim_t ir_comm_id  = ( ir  ? bli_thread_ocomm_id( ir )  : -1 );

		dim_t jc_work_id  = ( jc  ? bli_thread_work_id( jc )  : -1 );
		dim_t pc_work_id  = ( pc  ? bli_thread_work_id( pc )  : -1 );
		dim_t pb_work_id  = ( pb  ? bli_thread_work_id( pb )  : -1 );
		dim_t ic_work_id  = ( ic  ? bli_thread_work_id( ic )  : -1 );
		dim_t pa_work_id0 = ( pa0 ? bli_thread_work_id( pa0 ) : -1 );
		dim_t jr_work_id0 = ( jr0 ? bli_thread_work_id( jr0 ) : -1 );
		dim_t ir_work_id0 = ( ir0 ? bli_thread_work_id( ir0 ) : -1 );
		dim_t pa_work_id  = ( pa  ? bli_thread_work_id( pa )  : -1 );
		dim_t jr_work_id  = ( jr  ? bli_thread_work_id( jr )  : -1 );
		dim_t ir_work_id  = ( ir  ? bli_thread_work_id( ir )  : -1 );

		printf( "comm ids: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n",
		( long )jc_comm_id,
		( long )pc_comm_id,
		( long )pb_comm_id,
		( long )ic_comm_id,
		( long )pa_comm_id0, ( long )pa_comm_id,
		( long )jr_comm_id0, ( long )jr_comm_id,
		( long )ir_comm_id0, ( long )ir_comm_id );
		printf( "work ids: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n",
		( long )jc_work_id,
		( long )pc_work_id,
		( long )pb_work_id,
		( long )ic_work_id,
		( long )pa_work_id0, ( long )pa_work_id,
		( long )jr_work_id0, ( long )jr_work_id,
		( long )ir_work_id0, ( long )ir_work_id );
		printf( "--------------------------------------------------\n" );
	}
}

View File

@@ -104,7 +104,12 @@ void bli_l3_thrinfo_create_root
thrinfo_t** thread
);
void bli_l3_thrinfo_print_paths
void bli_l3_thrinfo_print_gemm_paths
(
thrinfo_t** threads
);
void bli_l3_thrinfo_print_trsm_paths
(
thrinfo_t** threads
);

View File

@@ -47,15 +47,11 @@ void bli_gemm_blk_var1
)
{
obj_t a1, c1;
dir_t direct;
dim_t i;
dim_t b_alg;
dim_t my_start, my_end;
dim_t b_alg;
// Determine the direction in which to partition (forwards or backwards).
direct = bli_l3_direct( a, b, c, cntl );
dir_t direct = bli_l3_direct( a, b, c, cntl );
// Prune any zero region that exists along the partitioning dimension.
bli_l3_prune_unref_mparts_m( a, b, c, cntl );
@@ -68,7 +64,7 @@ void bli_gemm_blk_var1
);
// Partition along the m dimension.
for ( i = my_start; i < my_end; i += b_alg )
for ( dim_t i = my_start; i < my_end; i += b_alg )
{
// Determine the current algorithmic blocksize.
b_alg = bli_determine_blocksize( direct, i, my_end, a,

View File

@@ -47,15 +47,11 @@ void bli_gemm_blk_var2
)
{
obj_t b1, c1;
dir_t direct;
dim_t i;
dim_t b_alg;
dim_t my_start, my_end;
dim_t b_alg;
// Determine the direction in which to partition (forwards or backwards).
direct = bli_l3_direct( a, b, c, cntl );
dir_t direct = bli_l3_direct( a, b, c, cntl );
// Prune any zero region that exists along the partitioning dimension.
bli_l3_prune_unref_mparts_n( a, b, c, cntl );
@@ -68,7 +64,7 @@ void bli_gemm_blk_var2
);
// Partition along the n dimension.
for ( i = my_start; i < my_end; i += b_alg )
for ( dim_t i = my_start; i < my_end; i += b_alg )
{
// Determine the current algorithmic blocksize.
b_alg = bli_determine_blocksize( direct, i, my_end, b,

View File

@@ -46,24 +46,19 @@ void bli_gemm_blk_var3
)
{
obj_t a1, b1;
dir_t direct;
dim_t i;
dim_t b_alg;
dim_t k_trans;
// Determine the direction in which to partition (forwards or backwards).
direct = bli_l3_direct( a, b, c, cntl );
dir_t direct = bli_l3_direct( a, b, c, cntl );
// Prune any zero region that exists along the partitioning dimension.
bli_l3_prune_unref_mparts_k( a, b, c, cntl );
// Query dimension in partitioning direction.
k_trans = bli_obj_width_after_trans( a );
dim_t k_trans = bli_obj_width_after_trans( a );
// Partition along the k dimension.
for ( i = 0; i < k_trans; i += b_alg )
for ( dim_t i = 0; i < k_trans; i += b_alg )
{
// Determine the current algorithmic blocksize.
b_alg = bli_l3_determine_kc( direct, i, k_trans, a, b,

View File

@@ -35,6 +35,8 @@
#include "blis.h"
//#define PRINT
void bli_trsm_blk_var1
(
obj_t* a,
@@ -46,45 +48,131 @@ void bli_trsm_blk_var1
thrinfo_t* thread
)
{
obj_t a1, c1;
dir_t direct;
dim_t i;
dim_t b_alg;
dim_t my_start, my_end;
dim_t b_alg;
// Determine the direction in which to partition (forwards or backwards).
direct = bli_l3_direct( a, b, c, cntl );
dir_t direct = bli_l3_direct( a, b, c, cntl );
// Prune any zero region that exists along the partitioning dimension.
bli_l3_prune_unref_mparts_m( a, b, c, cntl );
// Determine the current thread's subpartition range.
bli_thread_range_mdim
(
direct, thread, a, b, c, cntl, cntx,
&my_start, &my_end
);
// Isolate the diagonal block A11 and its corresponding row panel C1.
const dim_t kc = bli_obj_width( a );
obj_t a11, c1;
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1,
0, kc, a, &a11 );
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1,
0, kc, c, &c1 );
// Partition along the m dimension.
for ( i = my_start; i < my_end; i += b_alg )
// All threads iterate over the entire diagonal block A11.
my_start = 0; my_end = kc;
#ifdef PRINT
printf( "bli_trsm_blk_var1(): a11 is %d x %d at offsets (%3d, %3d)\n",
(int)bli_obj_length( &a11 ), (int)bli_obj_width( &a11 ),
(int)bli_obj_row_off( &a11 ), (int)bli_obj_col_off( &a11 ) );
printf( "bli_trsm_blk_var1(): entering trsm subproblem loop.\n" );
#endif
// Partition along the m dimension for the trsm subproblem.
for ( dim_t i = my_start; i < my_end; i += b_alg )
{
// Determine the current algorithmic blocksize.
b_alg = bli_determine_blocksize( direct, i, my_end, a,
obj_t a11_1, c1_1;
b_alg = bli_determine_blocksize( direct, i, my_end, &a11,
bli_cntl_bszid( cntl ), cntx );
// Acquire partitions for A1 and C1.
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1,
i, b_alg, a, &a1 );
i, b_alg, &a11, &a11_1 );
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1,
i, b_alg, c, &c1 );
i, b_alg, &c1, &c1_1 );
#ifdef PRINT
printf( "bli_trsm_blk_var1(): a11_1 is %d x %d at offsets (%3d, %3d)\n",
(int)bli_obj_length( &a11_1 ), (int)bli_obj_width( &a11_1 ),
(int)bli_obj_row_off( &a11_1 ), (int)bli_obj_col_off( &a11_1 ) );
#endif
// Perform trsm subproblem.
bli_trsm_int
(
&BLIS_ONE,
&a1,
&a11_1,
b,
&BLIS_ONE,
&c1_1,
cntx,
rntm,
bli_cntl_sub_prenode( cntl ),
bli_thrinfo_sub_prenode( thread )
);
}
#ifdef PRINT
printf( "bli_trsm_blk_var1(): finishing trsm subproblem loop.\n" );
#endif
// We must execute a barrier here because the upcoming rank-k update
// requires the packed matrix B to be fully updated by the trsm
// subproblem.
bli_thread_obarrier( thread );
// Isolate the remaining part of the column panel matrix A, which we do by
// acquiring the subpartition ahead of A11 (that is, A21 or A01, depending
// on whether we are moving forwards or backwards, respectively).
obj_t ax1, cx1;
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1A,
0, kc, a, &ax1 );
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1A,
0, kc, c, &cx1 );
#ifdef PRINT
printf( "bli_trsm_blk_var1(): ax1 is %d x %d at offsets (%3d, %3d)\n",
(int)bli_obj_length( &ax1 ), (int)bli_obj_width( &ax1 ),
(int)bli_obj_row_off( &ax1 ), (int)bli_obj_col_off( &ax1 ) );
#endif
// Determine the current thread's subpartition range for the gemm
// subproblem over Ax1.
bli_thread_range_mdim
(
direct, thread, &ax1, b, &cx1, cntl, cntx,
&my_start, &my_end
);
#ifdef PRINT
printf( "bli_trsm_blk_var1(): entering gemm subproblem loop (%d->%d).\n", (int)my_start, (int)my_end );
#endif
// Partition along the m dimension for the gemm subproblem.
for ( dim_t i = my_start; i < my_end; i += b_alg )
{
obj_t a11, c1;
// Determine the current algorithmic blocksize.
b_alg = bli_determine_blocksize( direct, i, my_end, &ax1,
bli_cntl_bszid( cntl ), cntx );
// Acquire partitions for A1 and C1.
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1,
i, b_alg, &ax1, &a11 );
bli_acquire_mpart_mdim( direct, BLIS_SUBPART1,
i, b_alg, &cx1, &c1 );
#ifdef PRINT
printf( "bli_trsm_blk_var1(): a11 is %d x %d at offsets (%3d, %3d)\n",
(int)bli_obj_length( &a11 ), (int)bli_obj_width( &a11 ),
(int)bli_obj_row_off( &a11 ), (int)bli_obj_col_off( &a11 ) );
#endif
// Perform gemm subproblem. (Note that we use the same backend
// function as before, since we're calling the same macrokernel.)
bli_trsm_int
(
&BLIS_ONE,
&a11,
b,
&BLIS_ONE,
&c1,
@@ -94,5 +182,8 @@ void bli_trsm_blk_var1
bli_thrinfo_sub_node( thread )
);
}
#ifdef PRINT
printf( "bli_trsm_blk_var1(): finishing gemm subproblem loop.\n" );
#endif
}

View File

@@ -47,15 +47,11 @@ void bli_trsm_blk_var2
)
{
obj_t b1, c1;
dir_t direct;
dim_t i;
dim_t b_alg;
dim_t my_start, my_end;
dim_t b_alg;
// Determine the direction in which to partition (forwards or backwards).
direct = bli_l3_direct( a, b, c, cntl );
dir_t direct = bli_l3_direct( a, b, c, cntl );
// Prune any zero region that exists along the partitioning dimension.
bli_l3_prune_unref_mparts_n( a, b, c, cntl );
@@ -68,7 +64,7 @@ void bli_trsm_blk_var2
);
// Partition along the n dimension.
for ( i = my_start; i < my_end; i += b_alg )
for ( dim_t i = my_start; i < my_end; i += b_alg )
{
// Determine the current algorithmic blocksize.
b_alg = bli_determine_blocksize( direct, i, my_end, b,

View File

@@ -46,24 +46,19 @@ void bli_trsm_blk_var3
)
{
obj_t a1, b1;
dir_t direct;
dim_t i;
dim_t b_alg;
dim_t k_trans;
// Determine the direction in which to partition (forwards or backwards).
direct = bli_l3_direct( a, b, c, cntl );
dir_t direct = bli_l3_direct( a, b, c, cntl );
// Prune any zero region that exists along the partitioning dimension.
bli_l3_prune_unref_mparts_k( a, b, c, cntl );
// Query dimension in partitioning direction.
k_trans = bli_obj_width_after_trans( a );
dim_t k_trans = bli_obj_width_after_trans( a );
// Partition along the k dimension.
for ( i = 0; i < k_trans; i += b_alg )
for ( dim_t i = 0; i < k_trans; i += b_alg )
{
// Determine the current algorithmic blocksize.
b_alg = bli_trsm_determine_kc( direct, i, k_trans, a, b,

View File

@@ -69,7 +69,48 @@ cntl_t* bli_trsm_l_cntl_create
const opid_t family = BLIS_TRSM;
// Create two nodes for the macro-kernel.
//
// Create nodes for packing A and the macro-kernel (gemm branch).
//
cntl_t* gemm_cntl_bu_ke = bli_trsm_cntl_create_node
(
rntm, // the thread's runtime structure
family, // the operation family
BLIS_MR, // needed for bli_thrinfo_rgrow()
NULL, // variant function pointer not used
NULL // no sub-node; this is the leaf of the tree.
);
cntl_t* gemm_cntl_bp_bu = bli_trsm_cntl_create_node
(
rntm,
family,
BLIS_NR, // not used by macro-kernel, but needed for bli_thrinfo_rgrow()
macro_kernel_p,
gemm_cntl_bu_ke
);
// Create a node for packing matrix A (gemm branch). This branch is attached
// as the main sub-node of the MC-partitioning node and is used for the gemm
// subproblem, which operates on the part of A outside the diagonal block.
// Accordingly, diagonal inversion is NOT requested here; the flag previously
// read TRUE, contradicting its own comment.
cntl_t* gemm_cntl_packa = bli_packm_cntl_create_node
(
rntm,
bli_trsm_packa, // trsm operation's packm function for A.
packa_fp,
BLIS_MR,
BLIS_MR,
FALSE, // do NOT invert diagonal
TRUE, // reverse iteration if upper?
FALSE, // reverse iteration if lower?
schema_a, // normally BLIS_PACKED_ROW_PANELS
BLIS_BUFFER_FOR_A_BLOCK,
gemm_cntl_bp_bu
);
//
// Create nodes for packing A and the macro-kernel (trsm branch).
//
cntl_t* trsm_cntl_bu_ke = bli_trsm_cntl_create_node
(
rntm, // the thread's runtime structure
@@ -92,7 +133,7 @@ cntl_t* bli_trsm_l_cntl_create
cntl_t* trsm_cntl_packa = bli_packm_cntl_create_node
(
rntm,
bli_trsm_packa,
bli_trsm_packa, // trsm operation's packm function for A.
packa_fp,
BLIS_MR,
BLIS_MR,
@@ -104,16 +145,24 @@ cntl_t* bli_trsm_l_cntl_create
trsm_cntl_bp_bu
);
// -------------------------------------------------------------------------
// Create a node for partitioning the m dimension by MC.
// NOTE: We attach the gemm sub-tree as the main branch.
cntl_t* trsm_cntl_op_bp = bli_trsm_cntl_create_node
(
rntm,
family,
BLIS_MC,
bli_trsm_blk_var1,
trsm_cntl_packa
gemm_cntl_packa
);
// Attach the trsm sub-tree as the auxiliary "prenode" branch.
bli_cntl_set_sub_prenode( trsm_cntl_packa, trsm_cntl_op_bp );
// -------------------------------------------------------------------------
// Create a node for packing matrix B.
cntl_t* trsm_cntl_packb = bli_packm_cntl_create_node
(

View File

@@ -52,6 +52,9 @@ void bli_trsm_int
obj_t c_local;
trsm_var_oft f;
// Return early if the current control tree node is NULL.
if ( bli_cntl_is_null( cntl ) ) return;
// Check parameters.
if ( bli_error_checking_is_enabled() )
bli_gemm_basic_check( alpha, a, b, beta, c, cntx );

View File

@@ -684,10 +684,12 @@ err_t bli_check_valid_3x1_subpart( subpart_t part )
err_t e_val = BLIS_SUCCESS;
if ( part != BLIS_SUBPART0 &&
part != BLIS_SUBPART1T &&
part != BLIS_SUBPART1AND0 &&
part != BLIS_SUBPART1 &&
part != BLIS_SUBPART1B &&
part != BLIS_SUBPART2 )
part != BLIS_SUBPART1AND2 &&
part != BLIS_SUBPART2 &&
part != BLIS_SUBPART1A &&
part != BLIS_SUBPART1B )
e_val = BLIS_INVALID_3x1_SUBPART;
return e_val;
@@ -698,10 +700,12 @@ err_t bli_check_valid_1x3_subpart( subpart_t part )
err_t e_val = BLIS_SUCCESS;
if ( part != BLIS_SUBPART0 &&
part != BLIS_SUBPART1L &&
part != BLIS_SUBPART1AND0 &&
part != BLIS_SUBPART1 &&
part != BLIS_SUBPART1R &&
part != BLIS_SUBPART2 )
part != BLIS_SUBPART1AND2 &&
part != BLIS_SUBPART2 &&
part != BLIS_SUBPART1A &&
part != BLIS_SUBPART1B )
e_val = BLIS_INVALID_1x3_SUBPART;
return e_val;

View File

@@ -59,6 +59,7 @@ cntl_t* bli_cntl_create_node
bli_cntl_set_bszid( bszid, cntl );
bli_cntl_set_var_func( var_func, cntl );
bli_cntl_set_params( params, cntl );
bli_cntl_set_sub_prenode( NULL, cntl );
bli_cntl_set_sub_node( sub_node, cntl );
// Query the address of the node's packed mem_t entry so we can initialize
@@ -95,6 +96,7 @@ void bli_cntl_clear_node
// actually is not needed, but we do it for debugging/completeness.
bli_cntl_set_var_func( NULL, cntl );
bli_cntl_set_params( NULL, cntl );
bli_cntl_set_sub_prenode( NULL, cntl );
bli_cntl_set_sub_node( NULL, cntl );
// Clearing these fields is potentially more important if the control
@@ -126,14 +128,40 @@ void bli_cntl_free_w_thrinfo
// Base case: simply return when asked to free NULL nodes.
if ( cntl == NULL ) return;
cntl_t* cntl_sub_node = bli_cntl_sub_node( cntl );
void* cntl_params = bli_cntl_params( cntl );
mem_t* cntl_pack_mem = bli_cntl_pack_mem( cntl );
cntl_t* cntl_sub_prenode = bli_cntl_sub_prenode( cntl );
cntl_t* cntl_sub_node = bli_cntl_sub_node( cntl );
void* cntl_params = bli_cntl_params( cntl );
mem_t* cntl_pack_mem = bli_cntl_pack_mem( cntl );
thrinfo_t* thread_sub_node = bli_thrinfo_sub_node( thread );
// Don't immediately dereference the prenode and subnode of the thrinfo_t
// node. In some cases, the thrinfo_t tree is not built out all the way,
// perhaps because there are more ways of parallelization than micropanels
// of data in this dimension, or because the problem is small enough that
// there is no gemm subproblem in bli_trsm_blk_var1(). Thus, we start with
// NULL values for these variables and only dereference the fields of the
// thrinfo_t struct if the thrinfo_t exists (ie: is non-NULL). We will also
// have to check the thrinfo_t pointer for NULLness before using it below,
// when checking if we need to free the pack_mem field of the cntl_t node
// (see below).
thrinfo_t* thread_sub_prenode = NULL;
thrinfo_t* thread_sub_node = NULL;
// Only recurse if the current thrinfo_t node has a child.
if ( thread_sub_node != NULL )
if ( thread != NULL )
{
thread_sub_prenode = bli_thrinfo_sub_prenode( thread );
thread_sub_node = bli_thrinfo_sub_node( thread );
}
// Only recurse into prenode branch if it exists.
if ( cntl_sub_prenode != NULL )
{
// Recursively free all memory associated with the sub-prenode and its
// children.
bli_cntl_free_w_thrinfo( rntm, cntl_sub_prenode, thread_sub_prenode );
}
// Only recurse into the child node if it exists.
if ( cntl_sub_node != NULL )
{
// Recursively free all memory associated with the sub-node and its
// children.
@@ -153,6 +181,10 @@ void bli_cntl_free_w_thrinfo
// Release the current node's pack mem_t entry back to the memory
// broker from which it originated, but only if the mem_t entry is
// allocated, and only if the current thread is chief for its group.
// Also note that we don't proceed with either of the above tests if
// the thrinfo_t pointer is NULL. (See above for background on when
// this can happen.)
if ( thread != NULL )
if ( bli_thread_am_ochief( thread ) )
if ( bli_mem_is_alloc( cntl_pack_mem ) )
{
@@ -176,9 +208,16 @@ void bli_cntl_free_wo_thrinfo
// Base case: simply return when asked to free NULL nodes.
if ( cntl == NULL ) return;
cntl_t* cntl_sub_node = bli_cntl_sub_node( cntl );
void* cntl_params = bli_cntl_params( cntl );
mem_t* cntl_pack_mem = bli_cntl_pack_mem( cntl );
cntl_t* cntl_sub_prenode = bli_cntl_sub_prenode( cntl );
cntl_t* cntl_sub_node = bli_cntl_sub_node( cntl );
void* cntl_params = bli_cntl_params( cntl );
mem_t* cntl_pack_mem = bli_cntl_pack_mem( cntl );
{
// Recursively free all memory associated with the sub-prenode and its
// children.
bli_cntl_free_wo_thrinfo( rntm, cntl_sub_prenode );
}
{
// Recursively free all memory associated with the sub-node and its
@@ -244,6 +283,20 @@ cntl_t* bli_cntl_copy
bli_cntl_set_params( params_copy, cntl_copy );
}
// If the sub-prenode exists, copy it recursively.
if ( bli_cntl_sub_prenode( cntl ) != NULL )
{
cntl_t* sub_prenode_copy = bli_cntl_copy
(
rntm,
bli_cntl_sub_prenode( cntl )
);
// Save the address of the new sub-prenode (sub-tree) to the copied
// node.
bli_cntl_set_sub_prenode( sub_prenode_copy, cntl_copy );
}
// If the sub-node exists, copy it recursively.
if ( bli_cntl_sub_node( cntl ) != NULL )
{
@@ -277,14 +330,18 @@ void bli_cntl_mark_family
// Set the family of the root node.
bli_cntl_set_family( family, cntl );
// Continue as long as the current node has a valid child.
while ( bli_cntl_sub_node( cntl ) != NULL )
// Recursively set the family field of the sub-tree rooted at the prenode,
// if it exists.
if ( bli_cntl_sub_prenode( cntl ) != NULL )
{
// Move down the tree to the child node.
cntl = bli_cntl_sub_node( cntl );
bli_cntl_mark_family( family, bli_cntl_sub_prenode( cntl ) );
}
// Set the family of the current node.
bli_cntl_set_family( family, cntl );
// Recursively set the family field of the sub-tree rooted at the sub-node,
// if it exists.
if ( bli_cntl_sub_node( cntl ) != NULL )
{
bli_cntl_mark_family( family, bli_cntl_sub_node( cntl ) );
}
}

View File

@@ -43,6 +43,7 @@ struct cntl_s
opid_t family;
bszid_t bszid;
void* var_func;
struct cntl_s* sub_prenode;
struct cntl_s* sub_node;
// Optional fields (needed only by some operations such as packm).
@@ -141,6 +142,11 @@ static void* bli_cntl_var_func( cntl_t* cntl )
return cntl->var_func;
}
// Return the sub_prenode pointer stored in the given control tree node.
// The field is read directly, so cntl must not be NULL.
static cntl_t* bli_cntl_sub_prenode( cntl_t* cntl )
{
return cntl->sub_prenode;
}
static cntl_t* bli_cntl_sub_node( cntl_t* cntl )
{
return cntl->sub_node;
@@ -164,6 +170,12 @@ static mem_t* bli_cntl_pack_mem( cntl_t* cntl )
// cntl_t query (complex)
// cntl_t query: report whether the given control tree pointer is NULL.
// The pointer is only compared, never dereferenced.
static bool_t bli_cntl_is_null( cntl_t* cntl )
{
	const bool_t is_null = ( bool_t )( cntl == NULL );

	return is_null;
}
static bool_t bli_cntl_is_leaf( cntl_t* cntl )
{
return ( bool_t )
@@ -193,6 +205,11 @@ static void bli_cntl_set_var_func( void* var_func, cntl_t* cntl )
cntl->var_func = var_func;
}
// Store sub_prenode in the sub_prenode field of the given control tree node.
// Note the argument order: the value to store comes first, the node to
// modify comes second. cntl is written through directly and must not be NULL.
static void bli_cntl_set_sub_prenode( cntl_t* sub_prenode, cntl_t* cntl )
{
cntl->sub_prenode = sub_prenode;
}
static void bli_cntl_set_sub_node( cntl_t* sub_node, cntl_t* cntl )
{
cntl->sub_node = sub_node;

View File

@@ -81,9 +81,8 @@ void bli_acquire_mpart
}
void bli_acquire_mpart_mdim
void bli_acquire_mpart_t2b
(
dir_t direct,
subpart_t req_part,
dim_t i,
dim_t b,
@@ -91,14 +90,11 @@ void bli_acquire_mpart_mdim
obj_t* sub_obj
)
{
if ( direct == BLIS_FWD )
bli_acquire_mpart_t2b( req_part, i, b, obj, sub_obj );
else
bli_acquire_mpart_b2t( req_part, i, b, obj, sub_obj );
bli_acquire_mpart_mdim( BLIS_FWD, req_part, i, b, obj, sub_obj );
}
void bli_acquire_mpart_t2b
void bli_acquire_mpart_b2t
(
subpart_t req_part,
dim_t i,
@@ -106,6 +102,20 @@ void bli_acquire_mpart_t2b
obj_t* obj,
obj_t* sub_obj
)
{
bli_acquire_mpart_mdim( BLIS_BWD, req_part, i, b, obj, sub_obj );
}
void bli_acquire_mpart_mdim
(
dir_t direct,
subpart_t req_part,
dim_t i,
dim_t b,
obj_t* obj,
obj_t* sub_obj
)
{
dim_t m;
dim_t n;
@@ -116,6 +126,18 @@ void bli_acquire_mpart_t2b
doff_t diag_off_inc;
// NOTE: Most of this function implicitly assumes moving forward.
// When moving backward, we have to relocate i.
if ( direct == BLIS_BWD )
{
// Query the dimension in the partitioning direction.
dim_t m = bli_obj_length_after_trans( obj );
// Modify i to account for the fact that we are moving backwards.
i = m - i - b;
}
// Call a special function for partitioning packed objects. (By only
// catching those objects packed to panels, we omit cases where the
// object is packed to row or column storage, as such objects can be
@@ -151,9 +173,22 @@ void bli_acquire_mpart_t2b
if ( b > m - i ) b = m - i;
// Support SUBPART1B (behind SUBPART1) and SUBPART1A (ahead of SUBPART1),
// to refer to subpartitions 0 and 2 when moving forward, and 2 and 0 when
// moving backward.
subpart_t subpart0_alias;
subpart_t subpart2_alias;
if ( direct == BLIS_FWD ) { subpart0_alias = BLIS_SUBPART1B;
subpart2_alias = BLIS_SUBPART1A; }
else { subpart0_alias = BLIS_SUBPART1A;
subpart2_alias = BLIS_SUBPART1B; }
// Compute offset increments and dimensions based on which
// subpartition is being requested, assuming no transposition.
if ( req_part == BLIS_SUBPART0 )
if ( req_part == BLIS_SUBPART0 ||
req_part == subpart0_alias )
{
// A0 (offm,offn) unchanged.
// A0 is i x n.
@@ -162,10 +197,10 @@ void bli_acquire_mpart_t2b
m_part = i;
n_part = n;
}
else if ( req_part == BLIS_SUBPART1T )
else if ( req_part == BLIS_SUBPART1AND0 )
{
// A1T (offm,offn) unchanged.
// A1T is (i+b) x n.
// A1+A0 (offm,offn) unchanged.
// A1+A0 is (i+b) x n.
offm_inc = 0;
offn_inc = 0;
m_part = i + b;
@@ -180,16 +215,17 @@ void bli_acquire_mpart_t2b
m_part = b;
n_part = n;
}
else if ( req_part == BLIS_SUBPART1B )
else if ( req_part == BLIS_SUBPART1AND2 )
{
// A1B (offm,offn) += (i,0).
// A1B is (m-i) x n.
// A1+A2 (offm,offn) += (i,0).
// A1+A2 is (m-i) x n.
offm_inc = i;
offn_inc = 0;
m_part = m - i;
n_part = n;
}
else // if ( req_part == BLIS_SUBPART2 )
else if ( req_part == BLIS_SUBPART2 ||
req_part == subpart2_alias )
{
// A2 (offm,offn) += (i+b,0).
// A2 is (m-i-b) x n.
@@ -271,7 +307,7 @@ void bli_acquire_mpart_t2b
}
void bli_acquire_mpart_b2t
void bli_acquire_mpart_l2r
(
subpart_t req_part,
dim_t i,
@@ -280,37 +316,26 @@ void bli_acquire_mpart_b2t
obj_t* sub_obj
)
{
dim_t m;
bli_acquire_mpart_ndim( BLIS_FWD, req_part, i, b, obj, sub_obj );
}
// Query the dimension in the partitioning direction.
m = bli_obj_length_after_trans( obj );
// Modify i to account for the fact that we are moving backwards.
i = m - i - b;
bli_acquire_mpart_t2b( req_part, i, b, obj, sub_obj );
// Right-to-left wrapper: forwards all arguments unchanged to
// bli_acquire_mpart_ndim() with the direction fixed to BLIS_BWD.
void bli_acquire_mpart_r2l
(
subpart_t req_part,
dim_t j,
dim_t b,
obj_t* obj,
obj_t* sub_obj
)
{
bli_acquire_mpart_ndim( BLIS_BWD, req_part, j, b, obj, sub_obj );
}
void bli_acquire_mpart_ndim
(
dir_t direct,
subpart_t req_part,
dim_t i,
dim_t b,
obj_t* obj,
obj_t* sub_obj
)
{
if ( direct == BLIS_FWD )
bli_acquire_mpart_l2r( req_part, i, b, obj, sub_obj );
else
bli_acquire_mpart_r2l( req_part, i, b, obj, sub_obj );
}
void bli_acquire_mpart_l2r
(
subpart_t req_part,
dim_t j,
dim_t b,
@@ -327,6 +352,18 @@ void bli_acquire_mpart_l2r
doff_t diag_off_inc;
// NOTE: Most of this function implicitly assumes moving forward.
// When moving backward, we have to relocate j.
if ( direct == BLIS_BWD )
{
// Query the dimension in the partitioning direction.
dim_t n = bli_obj_width_after_trans( obj );
// Modify j to account for the fact that we are moving backwards.
j = n - j - b;
}
// Call a special function for partitioning packed objects. (By only
// catching those objects packed to panels, we omit cases where the
// object is packed to row or column storage, as such objects can be
@@ -362,9 +399,22 @@ void bli_acquire_mpart_l2r
if ( b > n - j ) b = n - j;
// Support SUBPART1B (behind SUBPART1) and SUBPART1A (ahead of SUBPART1),
// to refer to subpartitions 0 and 2 when moving forward, and 2 and 0 when
// moving backward.
subpart_t subpart0_alias;
subpart_t subpart2_alias;
if ( direct == BLIS_FWD ) { subpart0_alias = BLIS_SUBPART1B;
subpart2_alias = BLIS_SUBPART1A; }
else { subpart0_alias = BLIS_SUBPART1A;
subpart2_alias = BLIS_SUBPART1B; }
// Compute offset increments and dimensions based on which
// subpartition is being requested, assuming no transposition.
if ( req_part == BLIS_SUBPART0 )
if ( req_part == BLIS_SUBPART0 ||
req_part == subpart0_alias )
{
// A0 (offm,offn) unchanged.
// A0 is m x j.
@@ -373,10 +423,10 @@ void bli_acquire_mpart_l2r
m_part = m;
n_part = j;
}
else if ( req_part == BLIS_SUBPART1L )
else if ( req_part == BLIS_SUBPART1AND0 )
{
// A1L (offm,offn) unchanged.
// A1L is m x (j+b).
// A1+A0 (offm,offn) unchanged.
// A1+A0 is m x (j+b).
offm_inc = 0;
offn_inc = 0;
m_part = m;
@@ -391,16 +441,17 @@ void bli_acquire_mpart_l2r
m_part = m;
n_part = b;
}
else if ( req_part == BLIS_SUBPART1R )
else if ( req_part == BLIS_SUBPART1AND2 )
{
// A1R (offm,offn) += (0,j).
// A1R is m x (n-j).
// A1+A2 (offm,offn) += (0,j).
// A1+A2 is m x (n-j).
offm_inc = 0;
offn_inc = j;
m_part = m;
n_part = n - j;
}
else // if ( req_part == BLIS_SUBPART2 )
else if ( req_part == BLIS_SUBPART2 ||
req_part == subpart2_alias )
{
// A2 (offm,offn) += (0,j+b).
// A2 is m x (n-j-b).
@@ -481,7 +532,20 @@ void bli_acquire_mpart_l2r
}
void bli_acquire_mpart_r2l
// Top-left-to-bottom-right wrapper: forwards all arguments unchanged to
// bli_acquire_mpart_mndim() with the direction fixed to BLIS_FWD.
void bli_acquire_mpart_tl2br
(
subpart_t req_part,
dim_t i,
dim_t b,
obj_t* obj,
obj_t* sub_obj
)
{
bli_acquire_mpart_mndim( BLIS_FWD, req_part, i, b, obj, sub_obj );
}
void bli_acquire_mpart_br2tl
(
subpart_t req_part,
dim_t j,
@@ -490,20 +554,13 @@ void bli_acquire_mpart_r2l
obj_t* sub_obj
)
{
dim_t n;
// Query the dimension in the partitioning direction.
n = bli_obj_width_after_trans( obj );
// Modify i to account for the fact that we are moving backwards.
j = n - j - b;
bli_acquire_mpart_l2r( req_part, j, b, obj, sub_obj );
bli_acquire_mpart_mndim( BLIS_BWD, req_part, j, b, obj, sub_obj );
}
void bli_acquire_mpart_tl2br
void bli_acquire_mpart_mndim
(
dir_t direct,
subpart_t req_part,
dim_t ij,
dim_t b,
@@ -521,6 +578,18 @@ void bli_acquire_mpart_tl2br
doff_t diag_off_inc;
// NOTE: Most of this function implicitly assumes moving forward.
// When moving backward, we have to relocate ij.
if ( direct == BLIS_BWD )
{
// Query the dimension of the object.
dim_t mn = bli_obj_length( obj );
// Modify ij to account for the fact that we are moving backwards.
ij = mn - ij - b;
}
// Call a special function for partitioning packed objects. (By only
// catching those objects packed to panels, we omit cases where the
// object is packed to row or column storage, as such objects can be
@@ -730,25 +799,6 @@ void bli_acquire_mpart_tl2br
}
void bli_acquire_mpart_br2tl
(
subpart_t req_part,
dim_t ij,
dim_t b,
obj_t* obj,
obj_t* sub_obj
)
{
// Query the dimension of the object.
dim_t mn = bli_obj_length( obj );
// Modify ij to account for the fact that we are moving backwards.
ij = mn - ij - b;
bli_acquire_mpart_tl2br( req_part, ij, b, obj, sub_obj );
}
// -- Vector partitioning ------------------------------------------------------
@@ -762,9 +812,9 @@ void bli_acquire_vpart_f2b
)
{
if ( bli_obj_is_col_vector( obj ) )
bli_acquire_mpart_t2b( req_part, i, b, obj, sub_obj );
bli_acquire_mpart_mdim( BLIS_FWD, req_part, i, b, obj, sub_obj );
else // if ( bli_obj_is_row_vector( obj ) )
bli_acquire_mpart_l2r( req_part, i, b, obj, sub_obj );
bli_acquire_mpart_ndim( BLIS_FWD, req_part, i, b, obj, sub_obj );
}
@@ -778,9 +828,9 @@ void bli_acquire_vpart_b2f
)
{
if ( bli_obj_is_col_vector( obj ) )
bli_acquire_mpart_b2t( req_part, i, b, obj, sub_obj );
bli_acquire_mpart_mdim( BLIS_BWD, req_part, i, b, obj, sub_obj );
else // if ( bli_obj_is_row_vector( obj ) )
bli_acquire_mpart_r2l( req_part, i, b, obj, sub_obj );
bli_acquire_mpart_ndim( BLIS_BWD, req_part, i, b, obj, sub_obj );
}
@@ -797,8 +847,8 @@ void bli_acquire_mij
{
obj_t tmp_obj;
bli_acquire_mpart_l2r( BLIS_SUBPART1, j, 1, obj, &tmp_obj );
bli_acquire_mpart_t2b( BLIS_SUBPART1, i, 1, &tmp_obj, sub_obj );
bli_acquire_mpart_ndim( BLIS_FWD, BLIS_SUBPART1, j, 1, obj, &tmp_obj );
bli_acquire_mpart_mdim( BLIS_FWD, BLIS_SUBPART1, i, 1, &tmp_obj, sub_obj );
}
@@ -810,8 +860,8 @@ void bli_acquire_vi
)
{
if ( bli_obj_is_col_vector( obj ) )
bli_acquire_mpart_t2b( BLIS_SUBPART1, i, 1, obj, sub_obj );
bli_acquire_mpart_mdim( BLIS_FWD, BLIS_SUBPART1, i, 1, obj, sub_obj );
else // if ( bli_obj_is_row_vector( obj ) )
bli_acquire_mpart_l2r( BLIS_SUBPART1, i, 1, obj, sub_obj );
bli_acquire_mpart_ndim( BLIS_FWD, BLIS_SUBPART1, i, 1, obj, sub_obj );
}

View File

@@ -46,22 +46,6 @@ void bli_acquire_mpart
obj_t* sub_obj
);
#undef GENPROT
#define GENPROT( opname ) \
\
void PASTEMAC0( opname ) \
( \
dir_t direct, \
subpart_t req_part, \
dim_t i, \
dim_t b, \
obj_t* obj, \
obj_t* sub_obj \
);
GENPROT( acquire_mpart_mdim )
GENPROT( acquire_mpart_ndim )
#undef GENPROT
#define GENPROT( opname ) \
\
@@ -81,8 +65,39 @@ GENPROT( acquire_mpart_r2l )
GENPROT( acquire_mpart_tl2br )
GENPROT( acquire_mpart_br2tl )
// Prototypes for the direction-parameterized partitioning functions. Each
// GENPROT() instance below expands to a prototype, named via
// PASTEMAC0( opname ), that takes a dir_t argument ahead of the
// (req_part, i, b, obj, sub_obj) arguments.
#undef GENPROT
#define GENPROT( opname ) \
\
void PASTEMAC0( opname ) \
( \
dir_t direct, \
subpart_t req_part, \
dim_t i, \
dim_t b, \
obj_t* obj, \
obj_t* sub_obj \
);
GENPROT( acquire_mpart_mdim )
GENPROT( acquire_mpart_ndim )
GENPROT( acquire_mpart_mndim )
// -- Vector partitioning ------------------------------------------------------
// Prototypes for the vector partitioning functions. Each GENPROT() instance
// below expands to a prototype, named via PASTEMAC0( opname ), that takes
// (req_part, i, b, obj, sub_obj) with no direction argument.
#undef GENPROT
#define GENPROT( opname ) \
\
void PASTEMAC0( opname ) \
( \
subpart_t req_part, \
dim_t i, \
dim_t b, \
obj_t* obj, \
obj_t* sub_obj \
);
GENPROT( acquire_vpart_f2b )
GENPROT( acquire_vpart_b2f )

View File

@@ -123,7 +123,13 @@ void bli_pool_finalize
const siz_t top_index = bli_pool_top_index( pool );
// Sanity check: The top_index should be zero.
if ( top_index != 0 ) bli_abort();
if ( top_index != 0 )
{
printf( "bli_pool_finalize(): final top_index == %d (expected 0); block_size: %d.\n",
( int )top_index, ( int )bli_pool_block_size( pool ) );
printf( "bli_pool_finalize(): Implication: not all blocks were checked back in!\n" );
bli_abort();
}
// Query the free() function pointer for the pool.
free_ft free_fp = bli_pool_free_fp( pool );

View File

@@ -101,16 +101,15 @@ bli_rntm_print( rntm );
}
else if ( l3_op == BLIS_TRSM )
{
// For trsm_l, we extract all parallelism from the jc and jr loops.
// For trsm_r, we extract all parallelism from the ic loop.
//printf( "bli_rntm_set_ways_for_op(): jc%d ic%d jr%d\n", (int)jc, (int)ic, (int)jr );
if ( bli_is_left( side ) )
{
bli_rntm_set_ways_only
(
jc,
1,
1,
ic * pc * jr * ir,
ic,
jr,
1,
rntm
);

View File

@@ -594,10 +594,10 @@ typedef enum
BLIS_SUBPART0,
BLIS_SUBPART1,
BLIS_SUBPART2,
BLIS_SUBPART1T,
BLIS_SUBPART1AND0,
BLIS_SUBPART1AND2,
BLIS_SUBPART1A,
BLIS_SUBPART1B,
BLIS_SUBPART1L,
BLIS_SUBPART1R,
BLIS_SUBPART00,
BLIS_SUBPART10,
BLIS_SUBPART20,
@@ -1015,6 +1015,7 @@ struct cntl_s
opid_t family;
bszid_t bszid;
void* var_func;
struct cntl_s* sub_prenode;
struct cntl_s* sub_node;
// Optional fields (needed only by some operations such as packm).

View File

@@ -317,6 +317,7 @@ void bli_l3_thread_decorator
// Create the root node of the current thread's thrinfo_t structure.
bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread );
#if 1
func
(
alpha,
@@ -329,6 +330,14 @@ void bli_l3_thread_decorator
cntl_use,
thread
);
#else
bli_thrinfo_grow_tree
(
rntm_p,
cntl_use,
thread
);
#endif
// Free the thread's local control tree.
bli_l3_cntl_free( rntm_p, cntl_use, thread );
@@ -346,9 +355,9 @@ void bli_l3_thread_decorator
// (called above).
#ifdef PRINT_THRINFO
bli_l3_thrinfo_print_paths( threads );
if ( family != BLIS_TRSM ) bli_l3_thrinfo_print_gemm_paths( threads );
else bli_l3_thrinfo_print_trsm_paths( threads );
exit(1);
//bli_l3_thrinfo_free_paths( rntm_p, threads );
#endif
// Check the array_t back into the small block allocator. Similar to the
@@ -414,4 +423,3 @@ void bli_l3_thread_decorator_thread_check
}
#endif

View File

@@ -43,6 +43,7 @@ thrinfo_t* bli_thrinfo_create
dim_t n_way,
dim_t work_id,
bool_t free_comm,
bszid_t bszid,
thrinfo_t* sub_node
)
{
@@ -58,6 +59,7 @@ thrinfo_t* bli_thrinfo_create
ocomm, ocomm_id,
n_way, work_id,
free_comm,
bszid,
sub_node
);
@@ -72,6 +74,7 @@ void bli_thrinfo_init
dim_t n_way,
dim_t work_id,
bool_t free_comm,
bszid_t bszid,
thrinfo_t* sub_node
)
{
@@ -80,8 +83,10 @@ void bli_thrinfo_init
thread->n_way = n_way;
thread->work_id = work_id;
thread->free_comm = free_comm;
thread->bszid = bszid;
thread->sub_node = sub_node;
thread->sub_prenode = NULL;
thread->sub_node = sub_node;
}
void bli_thrinfo_init_single
@@ -96,6 +101,7 @@ void bli_thrinfo_init_single
1,
0,
FALSE,
BLIS_NO_PART,
thread
);
}
@@ -111,7 +117,20 @@ void bli_thrinfo_free
thread == &BLIS_GEMM_SINGLE_THREADED
) return;
thrinfo_t* thrinfo_sub_node = bli_thrinfo_sub_node( thread );
thrinfo_t* thrinfo_sub_prenode = bli_thrinfo_sub_prenode( thread );
thrinfo_t* thrinfo_sub_node = bli_thrinfo_sub_node( thread );
// Recursively free all children of the current thrinfo_t.
if ( thrinfo_sub_prenode != NULL )
{
bli_thrinfo_free( rntm, thrinfo_sub_prenode );
}
// Recursively free all children of the current thrinfo_t.
if ( thrinfo_sub_node != NULL )
{
bli_thrinfo_free( rntm, thrinfo_sub_node );
}
// Free the communicators, but only if the current thrinfo_t struct
// is marked as needing them to be freed. The most common example of
@@ -119,15 +138,11 @@ void bli_thrinfo_free
// associated with packm thrinfo_t nodes.
if ( bli_thrinfo_needs_free_comm( thread ) )
{
// The ochief always frees his communicator, and the ichief free its
// communicator if we are at the leaf node.
// The ochief always frees his communicator.
if ( bli_thread_am_ochief( thread ) )
bli_thrcomm_free( rntm, bli_thrinfo_ocomm( thread ) );
}
// Recursively free all children of the current thrinfo_t.
bli_thrinfo_free( rntm, thrinfo_sub_node );
#ifdef BLIS_ENABLE_MEM_TRACING
printf( "bli_thrinfo_free(): " );
#endif
@@ -138,97 +153,6 @@ void bli_thrinfo_free
// -----------------------------------------------------------------------------
#include "assert.h"
#define BLIS_NUM_STATIC_COMMS 80
thrinfo_t* bli_thrinfo_create_for_cntl
(
rntm_t* rntm,
cntl_t* cntl_par,
cntl_t* cntl_chl,
thrinfo_t* thread_par
)
{
thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ];
thrcomm_t** new_comms = NULL;
thrinfo_t* thread_chl;
const bszid_t bszid_chl = bli_cntl_bszid( cntl_chl );
const dim_t parent_nt_in = bli_thread_num_threads( thread_par );
const dim_t parent_n_way = bli_thread_n_way( thread_par );
const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par );
const dim_t parent_work_id = bli_thread_work_id( thread_par );
dim_t child_nt_in;
dim_t child_comm_id;
dim_t child_n_way;
dim_t child_work_id;
// Sanity check: make sure the number of threads in the parent's
// communicator is divisible by the number of new sub-groups.
assert( parent_nt_in % parent_n_way == 0 );
// Compute:
// - the number of threads inside the new child comm,
// - the current thread's id within the new communicator,
// - the current thread's work id, given the ways of parallelism
// to be obtained within the next loop.
child_nt_in = bli_cntl_calc_num_threads_in( rntm, cntl_chl );
child_n_way = bli_rntm_ways_for( bszid_chl, rntm );
child_comm_id = parent_comm_id % child_nt_in;
child_work_id = child_comm_id / ( child_nt_in / child_n_way );
// The parent's chief thread creates a temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) );
else
new_comms = static_comms;
}
// Broadcast the temporary array to all threads in the parent's
// communicator.
new_comms = bli_thread_obroadcast( thread_par, new_comms );
// Chiefs in the child communicator allocate the communicator
// object and store it in the array element corresponding to the
// parent's work id.
if ( child_comm_id == 0 )
new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in );
bli_thread_obarrier( thread_par );
// All threads create a new thrinfo_t node using the communicator
// that was created by their chief, as identified by parent_work_id.
thread_chl = bli_thrinfo_create
(
rntm,
new_comms[ parent_work_id ],
child_comm_id,
child_n_way,
child_work_id,
TRUE,
NULL
);
bli_thread_obarrier( thread_par );
// The parent's chief thread frees the temporary array of thrcomm_t
// pointers.
if ( bli_thread_am_ochief( thread_par ) )
{
if ( parent_n_way > BLIS_NUM_STATIC_COMMS )
bli_free_intl( new_comms );
}
return thread_chl;
}
void bli_thrinfo_grow
(
rntm_t* rntm,
@@ -236,24 +160,72 @@ void bli_thrinfo_grow
thrinfo_t* thread
)
{
// If the sub-node of the thrinfo_t object is non-NULL, we don't
// need to create it, and will just use the existing sub-node as-is.
if ( bli_thrinfo_sub_node( thread ) != NULL ) return;
// First, consider the prenode branch of the thrinfo_t tree, which should be
// expanded only if there exists a prenode branch in the cntl_t tree.
// Create a new node (or, if needed, multiple nodes) and return the
// pointer to the (eldest) child.
thrinfo_t* thread_child = bli_thrinfo_rgrow
(
rntm,
cntl,
bli_cntl_sub_node( cntl ),
thread
);
if ( bli_cntl_sub_prenode( cntl ) != NULL )
{
// We only need to take action if the thrinfo_t sub-prenode is NULL; if
// it is non-NULL, then it has already been created and we'll use it as-is.
if ( bli_thrinfo_sub_prenode( thread ) == NULL )
{
// Assertion / sanity check.
if ( bli_cntl_bszid( cntl ) != BLIS_MC )
{
printf( "Assertion failed: Expanding prenode for non-IC loop?\n" );
bli_abort();
}
// Attach the child thrinfo_t node to its parent structure.
bli_thrinfo_set_sub_node( thread_child, thread );
// Now we must create the packa, jr, and ir nodes that make up
// the prenode branch of current cntl_t node.
// Create a new node (or, if needed, multiple nodes) along the
// prenode branch of the tree and return the pointer to the
// (highest) child.
thrinfo_t* thread_prenode = bli_thrinfo_rgrow_prenode
(
rntm,
cntl,
bli_cntl_sub_prenode( cntl ),
thread
);
// Attach the child thrinfo_t node for the secondary branch to its
// parent structure.
bli_thrinfo_set_sub_prenode( thread_prenode, thread );
}
}
// Now, grow the primary branch of the thrinfo_t tree.
// NOTE: If bli_thrinfo_rgrow() is being called, the sub_node field will
// always be non-NULL, and so there's no need to check it.
//if ( bli_cntl_sub_node( cntl ) != NULL )
{
// We only need to take action if the thrinfo_t sub-node is NULL; if it
// is non-NULL, then it has already been created and we'll use it as-is.
if ( bli_thrinfo_sub_node( thread ) == NULL )
{
// Create a new node (or, if needed, multiple nodes) along the
// main sub-node branch of the tree and return the pointer to the
// (highest) child.
thrinfo_t* thread_child = bli_thrinfo_rgrow
(
rntm,
cntl,
bli_cntl_sub_node( cntl ),
thread
);
// Attach the child thrinfo_t node for the primary branch to its
// parent structure.
bli_thrinfo_set_sub_node( thread_child, thread );
}
}
}
// -----------------------------------------------------------------------------
thrinfo_t* bli_thrinfo_rgrow
(
rntm_t* rntm,
@@ -291,25 +263,368 @@ thrinfo_t* bli_thrinfo_rgrow
thread_par
);
// Create a thrinfo_t node corresponding to cntl_cur. Notice that
// the free_comm field is set to FALSE, since cntl_cur is a
// non-partitioning node. The communicator used here will be
// freed when thread_seg, or one of its descendents, is freed.
// Create a thrinfo_t node corresponding to cntl_cur. Since the
// corresponding cntl node, cntl_cur, is a non-partitioning node
// (bszid = BLIS_NO_PART), this means it's a packing node. Packing
// thrinfo_t nodes are formed differently than those corresponding to
// partitioning nodes; specifically, their work_id's are set equal to
// the their comm_id's. Also, notice that the free_comm field is set
// to FALSE since cntl_cur is a non-partitioning node. The reason:
// the communicator used here will be freed when thread_seg, or one
// of its descendents, is freed.
thread_cur = bli_thrinfo_create
(
rntm,
bli_thrinfo_ocomm( thread_seg ),
bli_thread_ocomm_id( thread_seg ),
bli_cntl_calc_num_threads_in( rntm, cntl_cur ),
bli_thread_ocomm_id( thread_seg ),
FALSE,
thread_seg
rntm, // rntm
bli_thrinfo_ocomm( thread_seg ), // ocomm
bli_thread_ocomm_id( thread_seg ), // ocomm_id
bli_cntl_calc_num_threads_in( rntm, cntl_cur ), // n_way
bli_thread_ocomm_id( thread_seg ), // work_id
FALSE, // free_comm
BLIS_NO_PART, // bszid
thread_seg // sub_node
);
// Attach the child thrinfo_t node to its parent structure.
bli_thrinfo_set_sub_node( thread_cur, thread_par );
}
return thread_cur;
}
#define BLIS_NUM_STATIC_COMMS 80
// Create the thrinfo_t node that corresponds to the partitioning control
// tree node cntl_chl, as a child of the node described by thread_par.
//
// The chief of the parent's communicator provides a temporary array of
// thrcomm_t pointers (on its stack when parent_n_way fits within
// BLIS_NUM_STATIC_COMMS, otherwise from bli_malloc_intl()) and broadcasts it
// to the parent's group. Every thread whose id within the child group is 0
// creates that group's communicator and stores it in the slot indexed by the
// parent's work id. After a barrier, every thread builds its own node from
// the slot for its work id; a second barrier keeps the array alive until all
// threads have read it.
//
// The returned node is created with free_comm = TRUE, the child's bszid, and
// a NULL sub-node. cntl_par is not referenced by this function.
thrinfo_t* bli_thrinfo_create_for_cntl
(
	rntm_t*    rntm,
	cntl_t*    cntl_par,
	cntl_t*    cntl_chl,
	thrinfo_t* thread_par
)
{
	thrcomm_t*  stack_slots[ BLIS_NUM_STATIC_COMMS ];
	thrcomm_t** slots = NULL;

	const bszid_t chl_bszid   = bli_cntl_bszid( cntl_chl );

	const dim_t   par_nt      = bli_thread_num_threads( thread_par );
	const dim_t   par_n_way   = bli_thread_n_way( thread_par );
	const dim_t   par_comm_id = bli_thread_ocomm_id( thread_par );
	const dim_t   par_work_id = bli_thread_work_id( thread_par );

	// The parent's thread count must split evenly into its sub-groups.
	if ( par_nt % par_n_way != 0 )
	{
		printf( "Assertion failed: parent_nt_in <mod> parent_n_way != 0\n" );
		bli_abort();
	}

	// Derive the child's group size, its ways of parallelism, and this
	// thread's id and work id within the child group.
	const dim_t chl_nt      = bli_cntl_calc_num_threads_in( rntm, cntl_chl );
	const dim_t chl_n_way   = bli_rntm_ways_for( chl_bszid, rntm );
	const dim_t chl_comm_id = par_comm_id % chl_nt;
	const dim_t chl_work_id = chl_comm_id / ( chl_nt / chl_n_way );

	// Only the parent's chief supplies the slot array; everyone else
	// receives its address via the broadcast below.
	if ( bli_thread_am_ochief( thread_par ) )
	{
		if ( par_n_way > BLIS_NUM_STATIC_COMMS )
		{
			slots = bli_malloc_intl( par_n_way * sizeof( thrcomm_t* ) );
		}
		else
		{
			slots = stack_slots;
		}
	}

	slots = bli_thread_obroadcast( thread_par, slots );

	// The chief of each child group creates that group's communicator.
	if ( chl_comm_id == 0 )
	{
		slots[ par_work_id ] = bli_thrcomm_create( rntm, chl_nt );
	}

	// Wait until every slot has been filled in before reading any of them.
	bli_thread_obarrier( thread_par );

	thrinfo_t* node = bli_thrinfo_create
	(
		rntm,                  // rntm
		slots[ par_work_id ],  // ocomm
		chl_comm_id,           // ocomm_id
		chl_n_way,             // n_way
		chl_work_id,           // work_id
		TRUE,                  // free_comm
		chl_bszid,             // bszid
		NULL                   // sub_node
	);

	// Wait until every thread has read its slot before the array goes away.
	bli_thread_obarrier( thread_par );

	// A heap-allocated slot array is released by the parent's chief.
	if ( bli_thread_am_ochief( thread_par ) &&
	     par_n_way > BLIS_NUM_STATIC_COMMS )
	{
		bli_free_intl( slots );
	}

	return node;
}
// -----------------------------------------------------------------------------
// Grow the prenode branch of the thrinfo_t tree beneath thread_par and
// return the top-most node that was created.
//
// If cntl_cur is a partitioning node (bszid != BLIS_NO_PART), a single node
// is created for it via bli_thrinfo_create_for_cntl_prenode(). Otherwise
// cntl_cur is a non-partitioning (ie: packing) node: the branch below it is
// grown first by recursing on its sub-node, and a node for cntl_cur is then
// stacked on top of that segment.
thrinfo_t* bli_thrinfo_rgrow_prenode
(
	rntm_t*    rntm,
	cntl_t*    cntl_par,
	cntl_t*    cntl_cur,
	thrinfo_t* thread_par
)
{
	// Partitioning node: create the child directly, with cntl_par as the
	// parent.
	if ( bli_cntl_bszid( cntl_cur ) != BLIS_NO_PART )
	{
		return bli_thrinfo_create_for_cntl_prenode
		(
			rntm,
			cntl_par,
			cntl_cur,
			thread_par
		);
	}

	// Non-partitioning node: first build everything underneath it.
	thrinfo_t* segment = bli_thrinfo_rgrow_prenode
	(
		rntm,
		cntl_par,
		bli_cntl_sub_node( cntl_cur ),
		thread_par
	);

	// The node for cntl_cur reuses the segment's communicator and comm id,
	// and its work_id is set equal to that comm id. n_way is computed from
	// cntl_par. free_comm is FALSE because the communicator belongs to the
	// segment (or one of its descendents) and is freed along with it.
	return bli_thrinfo_create
	(
		rntm,                                           // rntm
		bli_thrinfo_ocomm( segment ),                   // ocomm
		bli_thread_ocomm_id( segment ),                 // ocomm_id
		bli_cntl_calc_num_threads_in( rntm, cntl_par ), // n_way
		bli_thread_ocomm_id( segment ),                 // work_id
		FALSE,                                          // free_comm
		BLIS_NO_PART,                                   // bszid
		segment                                         // sub_node
	);
}
// Create the thrinfo_t node for the partitioning node cntl_chl on the
// prenode branch, as a child of the node described by thread_par.
//
// NOTE: This function only has to work for the ic -> (pa -> jr) branch
// extension; bli_thrinfo_create_for_cntl() handles the final jr -> ir
// extension.
//
// The child's group size and ways of parallelism are both set to the
// parent's thread count, so the whole parent group forms one child group.
// The parent's chief creates the single new communicator and broadcasts its
// address to the rest of the group. The returned node is created with
// free_comm = TRUE, the child's bszid, and a NULL sub-node. cntl_par is not
// referenced by this function.
thrinfo_t* bli_thrinfo_create_for_cntl_prenode
(
	rntm_t*    rntm,
	cntl_t*    cntl_par,
	cntl_t*    cntl_chl,
	thrinfo_t* thread_par
)
{
	const bszid_t chl_bszid   = bli_cntl_bszid( cntl_chl );

	const dim_t   par_nt      = bli_thread_num_threads( thread_par );
	const dim_t   par_n_way   = bli_thread_n_way( thread_par );
	const dim_t   par_comm_id = bli_thread_ocomm_id( thread_par );

	// The parent's thread count must split evenly into its sub-groups.
	if ( par_nt % par_n_way != 0 )
	{
		printf( "Assertion failed: parent_nt_in (%d) <mod> parent_n_way (%d) != 0\n",
		        ( int )par_nt, ( int )par_n_way );
		bli_abort();
	}

	// Every thread of the parent group joins the child group, and the child
	// is given as many ways as it has threads.
	const dim_t chl_nt      = par_nt;
	const dim_t chl_n_way   = par_nt;
	const dim_t chl_comm_id = par_comm_id % chl_nt;
	const dim_t chl_work_id = chl_comm_id / ( chl_nt / chl_n_way );

	bli_thread_obarrier( thread_par );

	// Since the child comm id equals the parent comm id, the parent's chief
	// is also the chief of the communicator about to be created.
	thrcomm_t* shared_comm = NULL;

	if ( bli_thread_am_ochief( thread_par ) )
	{
		shared_comm = bli_thrcomm_create( rntm, chl_nt );
	}

	// Hand the new communicator's address to the rest of the parent's group.
	shared_comm = bli_thread_obroadcast( thread_par, shared_comm );

	// Every thread builds its own node around the shared communicator.
	thrinfo_t* node = bli_thrinfo_create
	(
		rntm,         // rntm
		shared_comm,  // ocomm
		chl_comm_id,  // ocomm_id
		chl_n_way,    // n_way
		chl_work_id,  // work_id
		TRUE,         // free_comm
		chl_bszid,    // bszid
		NULL          // sub_node
	);

	bli_thread_obarrier( thread_par );

	return node;
}
// -----------------------------------------------------------------------------
#if 0
// Grow the thrinfo_t tree rooted at thread by walking it in lock-step with
// the control tree rooted at cntl, calling bli_thrinfo_grow() on each pair
// of nodes in turn: jc, pc, pb, ic, then the main branch below ic (pa, jr)
// and finally the prenode ("trsm") branch below ic (pa0, jr0).
// Each step reads the sub-node (or sub-prenode) that the previous
// bli_thrinfo_grow() call is relied upon to have attached.
// NOTE: This function sits inside an "#if 0" region and is not compiled.
void bli_thrinfo_grow_tree
(
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
cntl_t* cntl_jc = cntl;
thrinfo_t* thrinfo_jc = thread;
bli_thrinfo_grow( rntm, cntl_jc, thrinfo_jc );
// inside jc loop:
cntl_t* cntl_pc = bli_cntl_sub_node( cntl_jc );
thrinfo_t* thrinfo_pc = bli_thrinfo_sub_node( thrinfo_jc );
bli_thrinfo_grow( rntm, cntl_pc, thrinfo_pc );
// inside pc loop:
cntl_t* cntl_pb = bli_cntl_sub_node( cntl_pc );
thrinfo_t* thrinfo_pb = bli_thrinfo_sub_node( thrinfo_pc );
bli_thrinfo_grow( rntm, cntl_pb, thrinfo_pb );
// after pb packing:
cntl_t* cntl_ic = bli_cntl_sub_node( cntl_pb );
thrinfo_t* thrinfo_ic = bli_thrinfo_sub_node( thrinfo_pb );
bli_thrinfo_grow( rntm, cntl_ic, thrinfo_ic );
// -- main branch --
// inside ic loop:
cntl_t* cntl_pa = bli_cntl_sub_node( cntl_ic );
thrinfo_t* thrinfo_pa = bli_thrinfo_sub_node( thrinfo_ic );
bli_thrinfo_grow( rntm, cntl_pa, thrinfo_pa );
// after pa packing:
cntl_t* cntl_jr = bli_cntl_sub_node( cntl_pa );
thrinfo_t* thrinfo_jr = bli_thrinfo_sub_node( thrinfo_pa );
bli_thrinfo_grow( rntm, cntl_jr, thrinfo_jr );
// inside jr loop:
//cntl_t* cntl_ir = bli_cntl_sub_node( cntl_jr );
//thrinfo_t* thrinfo_ir = bli_thrinfo_sub_node( thrinfo_jr );
// -- trsm branch --
// inside ic loop:
cntl_t* cntl_pa0 = bli_cntl_sub_prenode( cntl_ic );
thrinfo_t* thrinfo_pa0 = bli_thrinfo_sub_prenode( thrinfo_ic );
bli_thrinfo_grow( rntm, cntl_pa0, thrinfo_pa0 );
// after pa packing:
cntl_t* cntl_jr0 = bli_cntl_sub_node( cntl_pa0 );
thrinfo_t* thrinfo_jr0 = bli_thrinfo_sub_node( thrinfo_pa0 );
bli_thrinfo_grow( rntm, cntl_jr0, thrinfo_jr0 );
// inside jr loop:
//cntl_t* cntl_ir0 = bli_cntl_sub_node( cntl_jr0 );
//thrinfo_t* thrinfo_ir0= bli_thrinfo_sub_node( thrinfo_jr0 );
}
// Grows the portion of the thrinfo_t tree that starts at the ic level. The
// cntl and thread arguments are treated as the ic-level nodes: they are
// passed to bli_thrinfo_grow() first, and then the two branches beneath them
// are processed in turn (sub_node: pa -> jr, and sub_prenode: pa0 -> jr0).
// NOTE(review): presumably bli_thrinfo_grow() is what attaches the child
// that each following lookup returns -- confirm against its definition.
void bli_thrinfo_grow_tree_ic
     (
       rntm_t*    rntm,
       cntl_t*    cntl,
       thrinfo_t* thread
     )
{
	// ic level: the nodes that were passed in.
	bli_thrinfo_grow( rntm, cntl, thread );

	// -- main branch --
	// Inside the ic loop (pa).
	cntl_t*    branch_cntl   = bli_cntl_sub_node( cntl );
	thrinfo_t* branch_thread = bli_thrinfo_sub_node( thread );
	bli_thrinfo_grow( rntm, branch_cntl, branch_thread );

	// After pa packing (jr). No node is grown for the ir level beneath jr.
	branch_cntl   = bli_cntl_sub_node( branch_cntl );
	branch_thread = bli_thrinfo_sub_node( branch_thread );
	bli_thrinfo_grow( rntm, branch_cntl, branch_thread );

	// -- trsm branch --
	// Inside the ic loop, but following the prenode links (pa0).
	branch_cntl   = bli_cntl_sub_prenode( cntl );
	branch_thread = bli_thrinfo_sub_prenode( thread );
	bli_thrinfo_grow( rntm, branch_cntl, branch_thread );

	// After pa packing (jr0). Again, nothing is grown for the ir level.
	branch_cntl   = bli_cntl_sub_node( branch_cntl );
	branch_thread = bli_thrinfo_sub_node( branch_thread );
	bli_thrinfo_grow( rntm, branch_cntl, branch_thread );
}
#endif

View File

@@ -58,6 +58,11 @@ struct thrinfo_s
// to false.
bool_t free_comm;
// The bszid_t to help identify the node. This is mainly useful when
// debugging or tracing the allocation and release of thrinfo_t nodes.
bszid_t bszid;
struct thrinfo_s* sub_prenode;
struct thrinfo_s* sub_node;
};
typedef struct thrinfo_s thrinfo_t;
@@ -100,11 +105,21 @@ static bool_t bli_thrinfo_needs_free_comm( thrinfo_t* t )
return t->free_comm;
}
// Returns the bszid field of thrinfo_t node t, converted to the dim_t return
// type of this accessor.
static dim_t bli_thread_bszid( thrinfo_t* t )
{
	const dim_t node_bszid = t->bszid;

	return node_bszid;
}
// Returns the sub_node link stored in thrinfo_t node t.
static thrinfo_t* bli_thrinfo_sub_node( thrinfo_t* t )
{
	thrinfo_t* const child = t->sub_node;

	return child;
}
// Returns the sub_prenode link stored in thrinfo_t node t.
static thrinfo_t* bli_thrinfo_sub_prenode( thrinfo_t* t )
{
	thrinfo_t* const prechild = t->sub_prenode;

	return prechild;
}
// thrinfo_t query (complex)
static bool_t bli_thread_am_ochief( thrinfo_t* t )
@@ -119,6 +134,11 @@ static void bli_thrinfo_set_sub_node( thrinfo_t* sub_node, thrinfo_t* t )
t->sub_node = sub_node;
}
// Stores sub_prenode as the sub_prenode link of thrinfo_t node t. Note the
// argument order: the value to store comes first, the node being modified
// comes second.
static void bli_thrinfo_set_sub_prenode( thrinfo_t* sub_prenode, thrinfo_t* t )
{
	( *t ).sub_prenode = sub_prenode;
}
// other thrinfo_t-related functions
static void* bli_thread_obroadcast( thrinfo_t* t, void* p )
@@ -144,6 +164,7 @@ thrinfo_t* bli_thrinfo_create
dim_t n_way,
dim_t work_id,
bool_t free_comm,
bszid_t bszid,
thrinfo_t* sub_node
);
@@ -155,6 +176,7 @@ void bli_thrinfo_init
dim_t n_way,
dim_t work_id,
bool_t free_comm,
bszid_t bszid,
thrinfo_t* sub_node
);
@@ -171,14 +193,6 @@ void bli_thrinfo_free
// -----------------------------------------------------------------------------
// Prototype: takes a runtime object, a parent and a child cntl_t node, and a
// parent thrinfo_t node; returns a thrinfo_t pointer.
// NOTE(review): judging by the name, presumably creates the thrinfo_t child
// of thread_par that corresponds to cntl_chl -- confirm against the
// definition.
// NOTE(review): an identical declaration appears again further down in this
// listing; one of the two is likely redundant.
thrinfo_t* bli_thrinfo_create_for_cntl
(
rntm_t* rntm,
cntl_t* cntl_par,
cntl_t* cntl_chl,
thrinfo_t* thread_par
);
void bli_thrinfo_grow
(
rntm_t* rntm,
@@ -194,4 +208,46 @@ thrinfo_t* bli_thrinfo_rgrow
thrinfo_t* thread_par
);
// Prototype: takes a runtime object, a parent and a child cntl_t node, and a
// parent thrinfo_t node; returns a thrinfo_t pointer.
// NOTE(review): judging by the name, presumably creates the thrinfo_t child
// of thread_par that corresponds to cntl_chl -- confirm against the
// definition.
thrinfo_t* bli_thrinfo_create_for_cntl
(
rntm_t* rntm,
cntl_t* cntl_par,
cntl_t* cntl_chl,
thrinfo_t* thread_par
);
// Prototype: takes a runtime object, a parent and a current cntl_t node, and
// a parent thrinfo_t node; returns a thrinfo_t pointer.
// NOTE(review): judging by the name, presumably the prenode-branch
// counterpart of bli_thrinfo_rgrow() -- confirm against the definition.
thrinfo_t* bli_thrinfo_rgrow_prenode
(
rntm_t* rntm,
cntl_t* cntl_par,
cntl_t* cntl_cur,
thrinfo_t* thread_par
);
// Prototype: same parameter list and return type as
// bli_thrinfo_create_for_cntl().
// NOTE(review): judging by the name, presumably the variant used when the
// child being created is a sub_prenode rather than a sub_node -- confirm
// against the definition.
thrinfo_t* bli_thrinfo_create_for_cntl_prenode
(
rntm_t* rntm,
cntl_t* cntl_par,
cntl_t* cntl_chl,
thrinfo_t* thread_par
);
// -----------------------------------------------------------------------------
#if 0
// Prototype for bli_thrinfo_grow_tree(): takes a runtime object, a cntl_t
// node, and a thrinfo_t node, and returns nothing. Currently compiled out by
// the enclosing #if 0.
void bli_thrinfo_grow_tree
(
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);
// Prototype for bli_thrinfo_grow_tree_ic(): takes a runtime object, a cntl_t
// node, and a thrinfo_t node, and returns nothing. Currently compiled out by
// the enclosing #if 0.
void bli_thrinfo_grow_tree_ic
(
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);
#endif
#endif