mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
Fixed bugs in _get_range(), _get_range_weighted().
Details: - Fixed some bugs that only manifested in multithreaded instances of some (non-gemm) level-3 operations. The bugs were related to invalid allocation of "edge" cases to thread subpartitions. (Here, we define an "edge" case to be one where the dimension being partitioned for parallelism is not a whole multiple of whatever register blocksize is needed in that dimension.) In BLIS, we always require edge cases to be part of the bottom, right, or bottom-right subpartitions. (This is so that zero-padding only has to happen at the bottom, right, or bottom-right edges of micro-panels.) The previous implementations of bli_get_range() and _get_range_weighted() did not adhere to this implicit policy and thus produced bad ranges for some combinations of operation, parameter cases, problem sizes, and n-way parallelism. - As part of the above fix, the functions bli_get_range() and _get_range_weighted() have been renamed to use _l2r, _r2l, _t2b, and _b2t suffixes, similar to the partitioning functions. This is an easy way to make sure that the variants are calling the right version of each function. The function signatures have also been changed slightly. - Comment/whitespace updates. - Removed unnecessary '/' from macros in bli_obj_macro_defs.h.
This commit is contained in:
@@ -44,7 +44,7 @@ void bli_gemm_blk_var1f( obj_t* a,
|
||||
obj_t b_pack_s;
|
||||
obj_t a1_pack_s, c1_pack_s;
|
||||
|
||||
obj_t a1, c1;
|
||||
obj_t a1, c1;
|
||||
obj_t* a1_pack = NULL;
|
||||
obj_t* b_pack = NULL;
|
||||
obj_t* c1_pack = NULL;
|
||||
@@ -83,9 +83,9 @@ void bli_gemm_blk_var1f( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
m_trans = bli_obj_length_after_trans( *a );
|
||||
dim_t start, end;
|
||||
bli_get_range( thread, 0, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
&start, &end );
|
||||
bli_get_range_t2b( thread, 0, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the m dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
@@ -130,7 +130,7 @@ void bli_gemm_blk_var1f( obj_t* a,
|
||||
c1_pack,
|
||||
cntl_sub_gemm( cntl ),
|
||||
gemm_thread_sub_gemm( thread ) );
|
||||
|
||||
|
||||
thread_ibarrier( thread );
|
||||
|
||||
// Unpack C1 (if C1 was packed).
|
||||
|
||||
@@ -42,7 +42,7 @@ void bli_gemm_blk_var2f( obj_t* a,
|
||||
{
|
||||
obj_t a_pack_s;
|
||||
obj_t b1_pack_s, c1_pack_s;
|
||||
|
||||
|
||||
obj_t b1, c1;
|
||||
obj_t* a_pack = NULL;
|
||||
obj_t* b1_pack = NULL;
|
||||
@@ -82,9 +82,9 @@ void bli_gemm_blk_var2f( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
bli_get_range( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( b, cntl_blocksize( cntl ) ),
|
||||
&start, &end );
|
||||
bli_get_range_l2r( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( b, cntl_blocksize( cntl ) ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
@@ -129,7 +129,7 @@ void bli_gemm_blk_var2f( obj_t* a,
|
||||
c1_pack,
|
||||
cntl_sub_gemm( cntl ),
|
||||
gemm_thread_sub_gemm( thread ) );
|
||||
|
||||
|
||||
thread_ibarrier( thread );
|
||||
|
||||
// Unpack C1 (if C1 was packed).
|
||||
|
||||
@@ -52,7 +52,7 @@ void bli_gemm_blk_var4f( obj_t* a,
|
||||
obj_t b_pack_s;
|
||||
obj_t a1_pack_s, c1_pack_s;
|
||||
|
||||
obj_t a1, c1;
|
||||
obj_t a1, c1;
|
||||
obj_t* a1_pack = NULL;
|
||||
obj_t* b_pack = NULL;
|
||||
obj_t* c1_pack = NULL;
|
||||
@@ -91,9 +91,9 @@ void bli_gemm_blk_var4f( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
m_trans = bli_obj_length_after_trans( *a );
|
||||
dim_t start, end;
|
||||
bli_get_range( thread, 0, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
&start, &end );
|
||||
bli_get_range_t2b( thread, 0, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the m dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
@@ -140,7 +140,7 @@ void bli_gemm_blk_var4f( obj_t* a,
|
||||
c1_pack,
|
||||
cntl_sub_gemm( cntl ),
|
||||
gemm_thread_sub_gemm( thread ) );
|
||||
|
||||
|
||||
thread_ibarrier( thread );
|
||||
|
||||
// Only apply beta within the first of three subproblems.
|
||||
@@ -167,7 +167,7 @@ void bli_gemm_blk_var4f( obj_t* a,
|
||||
c1_pack,
|
||||
cntl_sub_gemm( cntl ),
|
||||
gemm_thread_sub_gemm( thread ) );
|
||||
|
||||
|
||||
thread_ibarrier( thread );
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ void bli_gemm_blk_var4f( obj_t* a,
|
||||
c1_pack,
|
||||
cntl_sub_gemm( cntl ),
|
||||
gemm_thread_sub_gemm( thread ) );
|
||||
|
||||
|
||||
thread_ibarrier( thread );
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ void bli_herk_blk_var1f( obj_t* a,
|
||||
dim_t b_alg;
|
||||
dim_t m_trans;
|
||||
|
||||
if( thread_am_ochief( thread ) ) {
|
||||
if( thread_am_ochief( thread ) ) {
|
||||
// Initialize object for packing A'.
|
||||
bli_obj_init_pack( &ah_pack_s );
|
||||
bli_packm_init( ah, &ah_pack_s,
|
||||
@@ -61,9 +61,9 @@ void bli_herk_blk_var1f( obj_t* a,
|
||||
// Scale C by beta (if instructed).
|
||||
// Since scalm doesn't support multithreading yet, must be done by chief thread (ew)
|
||||
bli_scalm_int( &BLIS_ONE,
|
||||
c,
|
||||
c,
|
||||
cntl_sub_scalm( cntl ) );
|
||||
}
|
||||
}
|
||||
ah_pack = thread_obroadcast( thread, &ah_pack_s );
|
||||
|
||||
// Initialize pack objects that are passed into packm_init() for A and C.
|
||||
@@ -82,9 +82,9 @@ void bli_herk_blk_var1f( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
m_trans = bli_obj_length_after_trans( *c );
|
||||
dim_t start, end;
|
||||
bli_get_range_weighted( thread, 0, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_upper( *c ), &start, &end );
|
||||
bli_get_range_weighted_t2b( thread, 0, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
bli_obj_root_uplo( *c ), &start, &end );
|
||||
|
||||
// Partition along the m dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -90,9 +90,9 @@ void bli_herk_blk_var2f( obj_t* a,
|
||||
dim_t start, end;
|
||||
|
||||
// Needs to be replaced with a weighted range because triangle
|
||||
bli_get_range_weighted( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_lower( *c ), &start, &end );
|
||||
bli_get_range_weighted_l2r( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
bli_obj_root_uplo( *c ), &start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -94,9 +94,9 @@ void bli_trmm_blk_var1f( obj_t* a,
|
||||
bli_obj_width_after_trans( *a );
|
||||
|
||||
dim_t start, end;
|
||||
bli_get_range_weighted( thread, offA, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_upper( *c ), &start, &end );
|
||||
bli_get_range_weighted_t2b( thread, offA, m_trans,
|
||||
bli_blksz_get_mult_for_obj( a, cntl_blocksize( cntl ) ),
|
||||
bli_obj_root_uplo( *a ), &start, &end );
|
||||
|
||||
// Partition along the m dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -82,9 +82,9 @@ void bli_trmm_blk_var2b( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
bli_get_range_weighted( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( b, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_upper( *c ), &start, &end );
|
||||
bli_get_range_weighted_r2l( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( b, cntl_blocksize( cntl ) ),
|
||||
bli_obj_root_uplo( *b ), &start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -82,9 +82,9 @@ void bli_trmm_blk_var2f( obj_t* a,
|
||||
// Query dimension in partitioning direction.
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
bli_get_range_weighted( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( b, cntl_blocksize( cntl ) ),
|
||||
bli_obj_is_lower( *c ), &start, &end );
|
||||
bli_get_range_weighted_l2r( thread, 0, n_trans,
|
||||
bli_blksz_get_mult_for_obj( b, cntl_blocksize( cntl ) ),
|
||||
bli_obj_root_uplo( *b ), &start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -129,8 +129,9 @@ void bli_trmm3_front( side_t side,
|
||||
bli_obj_set_as_root( b_local );
|
||||
bli_obj_set_as_root( c_local );
|
||||
|
||||
|
||||
trmm_thrinfo_t** infos = bli_create_trmm_thrinfo_paths( FALSE );
|
||||
// Notice that, unlike trmm_r, there is no dependency in the jc loop
|
||||
// for trmm3_r, so we can pass in FALSE for jc_dependency.
|
||||
trmm_thrinfo_t** infos = bli_create_trmm_thrinfo_paths( FALSE );
|
||||
dim_t n_threads = thread_num_threads( infos[0] );
|
||||
|
||||
// Invoke the internal back-end.
|
||||
|
||||
@@ -83,10 +83,10 @@ void bli_trsm_blk_var1b( obj_t* a,
|
||||
|
||||
dim_t start, end;
|
||||
num_t dt = bli_obj_execution_datatype( *a );
|
||||
bli_get_range( thread, offA, m_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ), bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_info_get_default_mc( BLIS_TRSM, dt ),
|
||||
&start, &end );
|
||||
bli_get_range_b2t( thread, offA, m_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ), bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_info_get_default_mc( BLIS_TRSM, dt ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the remaining portion of the m dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -82,10 +82,10 @@ void bli_trsm_blk_var1f( obj_t* a,
|
||||
|
||||
dim_t start, end;
|
||||
num_t dt = bli_obj_execution_datatype( *a );
|
||||
bli_get_range( thread, offA, m_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ), bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_info_get_default_mc( BLIS_TRSM, dt ),
|
||||
&start, &end );
|
||||
bli_get_range_t2b( thread, offA, m_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ), bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_info_get_default_mc( BLIS_TRSM, dt ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the remaining portion of the m dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -84,12 +84,12 @@ void bli_trsm_blk_var2b( obj_t* a,
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
num_t dt = bli_obj_execution_datatype( *a );
|
||||
bli_get_range( thread, 0, n_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ),
|
||||
// bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_lcm( bli_blksz_get_nr( dt, cntl_blocksize( cntl ) ),
|
||||
bli_blksz_get_mr( dt, cntl_blocksize( cntl ) ) ),
|
||||
&start, &end );
|
||||
bli_get_range_r2l( thread, 0, n_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ),
|
||||
// bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_lcm( bli_blksz_get_nr( dt, cntl_blocksize( cntl ) ),
|
||||
bli_blksz_get_mr( dt, cntl_blocksize( cntl ) ) ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -84,12 +84,12 @@ void bli_trsm_blk_var2f( obj_t* a,
|
||||
n_trans = bli_obj_width_after_trans( *b );
|
||||
dim_t start, end;
|
||||
num_t dt = bli_obj_execution_datatype( *a );
|
||||
bli_get_range( thread, 0, n_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ),
|
||||
// bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_lcm( bli_blksz_get_nr( dt, cntl_blocksize( cntl ) ),
|
||||
bli_blksz_get_mr( dt, cntl_blocksize( cntl ) ) ),
|
||||
&start, &end );
|
||||
bli_get_range_l2r( thread, 0, n_trans,
|
||||
//bli_lcm( bli_info_get_default_nr( BLIS_TRSM, dt ),
|
||||
// bli_info_get_default_mr( BLIS_TRSM, dt ) ),
|
||||
bli_lcm( bli_blksz_get_nr( dt, cntl_blocksize( cntl ) ),
|
||||
bli_blksz_get_mr( dt, cntl_blocksize( cntl ) ) ),
|
||||
&start, &end );
|
||||
|
||||
// Partition along the n dimension.
|
||||
for ( i = start; i < end; i += b_alg )
|
||||
|
||||
@@ -130,10 +130,10 @@ void bli_trsm_blk_var3b( obj_t* a,
|
||||
// internal alpha scalars on A/B and C are non-zero, we must ensure
|
||||
// that they are only used in the first iteration.
|
||||
thread_ibarrier( thread );
|
||||
if ( i == 0 && thread_am_ichief( thread ) ) {
|
||||
if ( i == 0 && thread_am_ichief( thread ) ) {
|
||||
bli_obj_scalar_reset( a );
|
||||
bli_obj_scalar_reset( b );
|
||||
bli_obj_scalar_reset( c_pack );
|
||||
bli_obj_scalar_reset( c_pack );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -130,10 +130,10 @@ void bli_trsm_blk_var3f( obj_t* a,
|
||||
// internal alpha scalars on A/B and C are non-zero, we must ensure
|
||||
// that they are only used in the first iteration.
|
||||
thread_ibarrier( thread );
|
||||
if ( i == 0 && thread_am_ichief( thread ) ) {
|
||||
if ( i == 0 && thread_am_ichief( thread ) ) {
|
||||
bli_obj_scalar_reset( a );
|
||||
bli_obj_scalar_reset( b );
|
||||
bli_obj_scalar_reset( c_pack );
|
||||
bli_obj_scalar_reset( c_pack );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -142,65 +142,319 @@ void* bli_broadcast_structure( thread_comm_t* communicator, dim_t id, void* to_s
|
||||
}
|
||||
|
||||
// Code for work assignments
|
||||
void bli_get_range( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, dim_t* start, dim_t* end )
|
||||
void bli_get_range( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, bool_t handle_edge_low, dim_t* start, dim_t* end )
|
||||
{
|
||||
thrinfo_t* thread = (thrinfo_t*) thr;
|
||||
dim_t n_way = thread->n_way;
|
||||
dim_t work_id = thread->work_id;
|
||||
thrinfo_t* thread = ( thrinfo_t* )thr;
|
||||
dim_t n_way = thread->n_way;
|
||||
dim_t work_id = thread->work_id;
|
||||
|
||||
dim_t size = all_end - all_start;
|
||||
dim_t n_pt = size / n_way;
|
||||
n_pt = (n_pt * n_way < size) ? n_pt + 1 : n_pt;
|
||||
n_pt = (n_pt % block_factor == 0) ? n_pt : n_pt + block_factor - (n_pt % block_factor);
|
||||
*start = work_id * n_pt + all_start;
|
||||
*end = bli_min( *start + n_pt, size + all_start );
|
||||
dim_t size = all_end - all_start;
|
||||
|
||||
dim_t n_bf_whole = size / block_factor;
|
||||
dim_t n_bf_left = size % block_factor;
|
||||
|
||||
dim_t n_bf_lo = n_bf_whole / n_way;
|
||||
dim_t n_bf_hi = n_bf_whole / n_way;
|
||||
|
||||
// In this function, we partition the space between all_start and
|
||||
// all_end into n_way partitions, each a multiple of block_factor
|
||||
// with the exception of the one partition that recieves the
|
||||
// "edge" case (if applicable).
|
||||
//
|
||||
// Here are examples of various thread partitionings, in units of
|
||||
// the block_factor, when n_way = 4. (A '+' indicates the thread
|
||||
// that receives the leftover edge case (ie: n_bf_left extra
|
||||
// rows/columns in its sub-range).
|
||||
// (all_start ... all_end)
|
||||
// n_bf_whole _left hel n_th_lo _hi thr0 thr1 thr2 thr3
|
||||
// 12 =0 f 0 4 3 3 3 3
|
||||
// 12 >0 f 0 4 3 3 3 3+
|
||||
// 13 >0 f 1 3 4 3 3 3+
|
||||
// 14 >0 f 2 2 4 4 3 3+
|
||||
// 15 >0 f 3 1 4 4 4 3+
|
||||
// 15 =0 f 3 1 4 4 4 3
|
||||
//
|
||||
// 12 =0 t 4 0 3 3 3 3
|
||||
// 12 >0 t 4 0 3+ 3 3 3
|
||||
// 13 >0 t 3 1 3+ 3 3 4
|
||||
// 14 >0 t 2 2 3+ 3 4 4
|
||||
// 15 >0 t 1 3 3+ 4 4 4
|
||||
// 15 =0 t 1 3 3 4 4 4
|
||||
|
||||
// As indicated by the table above, load is balanced as equally
|
||||
// as possible, even in the presence of an edge case.
|
||||
|
||||
// First, we must differentiate between cases where the leftover
|
||||
// "edge" case (n_bf_left) should be allocated to a thread partition
|
||||
// at the low end of the index range or the high end.
|
||||
|
||||
if ( handle_edge_low == FALSE )
|
||||
{
|
||||
// Notice that if all threads receive the same number of
|
||||
// block_factors, those threads are considered "high" and
|
||||
// the "low" thread group is empty.
|
||||
dim_t n_th_lo = n_bf_whole % n_way;
|
||||
//dim_t n_th_hi = n_way - n_th_lo;
|
||||
|
||||
// If some partitions must have more block_factors than others
|
||||
// assign the slightly larger partitions to lower index threads.
|
||||
if ( n_th_lo != 0 ) n_bf_lo += 1;
|
||||
|
||||
// Compute the actual widths (in units of rows/columns) of
|
||||
// individual threads in the low and high groups.
|
||||
dim_t size_lo = n_bf_lo * block_factor;
|
||||
dim_t size_hi = n_bf_hi * block_factor;
|
||||
|
||||
// Precompute the starting indices of the low and high groups.
|
||||
dim_t lo_start = all_start;
|
||||
dim_t hi_start = all_start + n_th_lo * size_lo;
|
||||
|
||||
// Compute the start and end of individual threads' ranges
|
||||
// as a function of their work_ids and also the group to which
|
||||
// they belong (low or high).
|
||||
if ( work_id < n_th_lo )
|
||||
{
|
||||
*start = lo_start + (work_id ) * size_lo;
|
||||
*end = lo_start + (work_id+1) * size_lo;
|
||||
}
|
||||
else // if ( n_th_lo <= work_id )
|
||||
{
|
||||
*start = hi_start + (work_id-n_th_lo ) * size_hi;
|
||||
*end = hi_start + (work_id-n_th_lo+1) * size_hi;
|
||||
|
||||
// Since the edge case is being allocated to the high
|
||||
// end of the index range, we have to advance the last
|
||||
// thread's end.
|
||||
if ( work_id == n_way - 1 ) *end += n_bf_left;
|
||||
}
|
||||
}
|
||||
else // if ( handle_edge_low == TRUE )
|
||||
{
|
||||
// Notice that if all threads receive the same number of
|
||||
// block_factors, those threads are considered "low" and
|
||||
// the "high" thread group is empty.
|
||||
dim_t n_th_hi = n_bf_whole % n_way;
|
||||
dim_t n_th_lo = n_way - n_th_hi;
|
||||
|
||||
// If some partitions must have more block_factors than others
|
||||
// assign the slightly larger partitions to higher index threads.
|
||||
if ( n_th_hi != 0 ) n_bf_hi += 1;
|
||||
|
||||
// Compute the actual widths (in units of rows/columns) of
|
||||
// individual threads in the low and high groups.
|
||||
dim_t size_lo = n_bf_lo * block_factor;
|
||||
dim_t size_hi = n_bf_hi * block_factor;
|
||||
|
||||
// Precompute the starting indices of the low and high groups.
|
||||
dim_t lo_start = all_start;
|
||||
dim_t hi_start = all_start + n_th_lo * size_lo
|
||||
+ n_bf_left;
|
||||
|
||||
// Compute the start and end of individual threads' ranges
|
||||
// as a function of their work_ids and also the group to which
|
||||
// they belong (low or high).
|
||||
if ( work_id < n_th_lo )
|
||||
{
|
||||
*start = lo_start + (work_id ) * size_lo;
|
||||
*end = lo_start + (work_id+1) * size_lo;
|
||||
|
||||
// Since the edge case is being allocated to the low
|
||||
// end of the index range, we have to advance the
|
||||
// starts/ends accordingly.
|
||||
if ( work_id == 0 ) *end += n_bf_left;
|
||||
else { *start += n_bf_left;
|
||||
*end += n_bf_left; }
|
||||
}
|
||||
else // if ( n_th_lo <= work_id )
|
||||
{
|
||||
*start = hi_start + (work_id-n_th_lo ) * size_hi;
|
||||
*end = hi_start + (work_id-n_th_lo+1) * size_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void bli_get_range_weighted( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, bool_t forward, dim_t* start, dim_t* end)
|
||||
void bli_get_range_l2r( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, dim_t* start, dim_t* end )
|
||||
{
|
||||
thrinfo_t* thread = (thrinfo_t*) thr;
|
||||
dim_t n_way = thread->n_way;
|
||||
dim_t work_id = thread->work_id;
|
||||
dim_t size = all_end - all_start;
|
||||
bli_get_range( thr, all_start, all_end, block_factor,
|
||||
FALSE, start, end );
|
||||
}
|
||||
|
||||
*start = 0;
|
||||
*end = all_end - all_start;
|
||||
double num = size*size / (double) n_way;
|
||||
void bli_get_range_r2l( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, dim_t* start, dim_t* end )
|
||||
{
|
||||
bli_get_range( thr, all_start, all_end, block_factor,
|
||||
TRUE, start, end );
|
||||
}
|
||||
|
||||
if( forward ) {
|
||||
dim_t curr_caucus = n_way - 1;
|
||||
dim_t len = 0;
|
||||
while(1){
|
||||
dim_t width = ceil(sqrt( len*len + num )) - len; // The width of the current caucus
|
||||
width = (width % block_factor == 0) ? width : width + block_factor - (width % block_factor);
|
||||
if( curr_caucus == work_id ) {
|
||||
*start = bli_max( 0 , *end - width ) + all_start;
|
||||
*end = *end + all_start;
|
||||
return;
|
||||
}
|
||||
else{
|
||||
*end -= width;
|
||||
len += width;
|
||||
curr_caucus--;
|
||||
}
|
||||
}
|
||||
}
|
||||
else{
|
||||
while(1){
|
||||
dim_t width = ceil(sqrt(*start * *start + num)) - *start;
|
||||
width = (width % block_factor == 0) ? width : width + block_factor - (width % block_factor);
|
||||
void bli_get_range_t2b( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, dim_t* start, dim_t* end )
|
||||
{
|
||||
bli_get_range( thr, all_start, all_end, block_factor,
|
||||
FALSE, start, end );
|
||||
}
|
||||
|
||||
if( work_id == 0 ) {
|
||||
*start = *start + all_start;
|
||||
*end = bli_min( *start + width, all_end );
|
||||
return;
|
||||
}
|
||||
else{
|
||||
*start = *start + width;
|
||||
}
|
||||
work_id--;
|
||||
}
|
||||
}
|
||||
void bli_get_range_b2t( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, dim_t* start, dim_t* end )
|
||||
{
|
||||
bli_get_range( thr, all_start, all_end, block_factor,
|
||||
TRUE, start, end );
|
||||
}
|
||||
|
||||
void bli_get_range_weighted( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, uplo_t uplo, bool_t handle_edge_low, dim_t* start, dim_t* end )
|
||||
{
|
||||
thrinfo_t* thread = ( thrinfo_t* )thr;
|
||||
dim_t n_way = thread->n_way;
|
||||
dim_t work_id = thread->work_id;
|
||||
dim_t size = all_end - all_start;
|
||||
dim_t width;
|
||||
dim_t block_fac_leftover = size % block_factor;
|
||||
dim_t i;
|
||||
double num;
|
||||
|
||||
*start = 0;
|
||||
*end = all_end - all_start;
|
||||
num = size * size / ( double )n_way;
|
||||
|
||||
if ( bli_is_lower( uplo ) )
|
||||
{
|
||||
dim_t cur_caucus = n_way - 1;
|
||||
dim_t len = 0;
|
||||
|
||||
// This loop computes subpartitions backwards, from the high end
|
||||
// of the index range to the low end. If the low end is assumed
|
||||
// to be on the left and the high end the right, this assignment
|
||||
// of widths is appropriate for n dimension partitioning of a
|
||||
// lower triangular matrix.
|
||||
for ( i = 0; TRUE; ++i )
|
||||
{
|
||||
width = ceil( sqrt( len*len + num ) ) - len;
|
||||
|
||||
// If we need to allocate the edge case (assuming it exists)
|
||||
// to the high thread subpartition, adjust width so that it
|
||||
// contains the exact amount of leftover edge dimension so that
|
||||
// all remaining subpartitions can be multiples of block_factor.
|
||||
// If the edge case is to be allocated to the low subpartition,
|
||||
// or if there is no edge case, it is implicitly allocated to
|
||||
// the low subpartition by virtue of the fact that all other
|
||||
// subpartitions already assigned will be multiples of
|
||||
// block_factor.
|
||||
if ( i == 0 && !handle_edge_low )
|
||||
{
|
||||
if ( width % block_factor != block_fac_leftover )
|
||||
width += block_fac_leftover - ( width % block_factor );
|
||||
}
|
||||
else
|
||||
{
|
||||
if ( width % block_factor != 0 )
|
||||
width += block_factor - ( width % block_factor );
|
||||
}
|
||||
|
||||
if ( cur_caucus == work_id )
|
||||
{
|
||||
*start = bli_max( 0, *end - width ) + all_start;
|
||||
*end = *end + all_start;
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
*end -= width;
|
||||
len += width;
|
||||
cur_caucus--;
|
||||
}
|
||||
}
|
||||
}
|
||||
else // if ( bli_is_upper( uplo ) )
|
||||
{
|
||||
// This loop computes subpartitions forwards, from the low end
|
||||
// of the index range to the high end. If the low end is assumed
|
||||
// to be on the left and the high end the right, this assignment
|
||||
// of widths is appropriate for n dimension partitioning of an
|
||||
// upper triangular matrix.
|
||||
for ( i = 0; TRUE; ++i )
|
||||
{
|
||||
width = ceil( sqrt( *start * *start + num ) ) - *start;
|
||||
|
||||
if ( i == 0 && handle_edge_low )
|
||||
{
|
||||
if ( width % block_factor != block_fac_leftover )
|
||||
width += block_fac_leftover - ( width % block_factor );
|
||||
}
|
||||
else
|
||||
{
|
||||
if ( width % block_factor != 0 )
|
||||
width += block_factor - ( width % block_factor );
|
||||
}
|
||||
|
||||
if ( work_id == 0 )
|
||||
{
|
||||
*start = *start + all_start;
|
||||
*end = bli_min( *start + width, all_end );
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
*start = *start + width;
|
||||
work_id--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void bli_get_range_weighted_l2r( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, uplo_t uplo, dim_t* start, dim_t* end )
|
||||
{
|
||||
if ( bli_is_upper_or_lower( uplo ) )
|
||||
{
|
||||
bli_get_range_weighted( thr, all_start, all_end, block_factor,
|
||||
uplo, FALSE, start, end );
|
||||
}
|
||||
else // if dense or zeros
|
||||
{
|
||||
bli_get_range_l2r( thr, all_start, all_end, block_factor,
|
||||
start, end );
|
||||
}
|
||||
}
|
||||
|
||||
void bli_get_range_weighted_r2l( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, uplo_t uplo, dim_t* start, dim_t* end )
|
||||
{
|
||||
if ( bli_is_upper_or_lower( uplo ) )
|
||||
{
|
||||
//printf( "bli_get_range_weighted_r2l: is upper or lower\n" );
|
||||
bli_toggle_uplo( uplo );
|
||||
bli_get_range_weighted( thr, all_start, all_end, block_factor,
|
||||
uplo, TRUE, start, end );
|
||||
}
|
||||
else // if dense or zeros
|
||||
{
|
||||
//printf( "bli_get_range_weighted_r2l: is dense or zeros\n" );
|
||||
bli_get_range_r2l( thr, all_start, all_end, block_factor,
|
||||
start, end );
|
||||
}
|
||||
}
|
||||
|
||||
void bli_get_range_weighted_t2b( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, uplo_t uplo, dim_t* start, dim_t* end )
|
||||
{
|
||||
if ( bli_is_upper_or_lower( uplo ) )
|
||||
{
|
||||
bli_toggle_uplo( uplo );
|
||||
bli_get_range_weighted( thr, all_start, all_end, block_factor,
|
||||
uplo, FALSE, start, end );
|
||||
}
|
||||
else // if dense or zeros
|
||||
{
|
||||
bli_get_range_t2b( thr, all_start, all_end, block_factor,
|
||||
start, end );
|
||||
}
|
||||
}
|
||||
|
||||
void bli_get_range_weighted_b2t( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, uplo_t uplo, dim_t* start, dim_t* end )
|
||||
{
|
||||
if ( bli_is_upper_or_lower( uplo ) )
|
||||
{
|
||||
bli_get_range_weighted( thr, all_start, all_end, block_factor,
|
||||
uplo, TRUE, start, end );
|
||||
}
|
||||
else // if dense or zeros
|
||||
{
|
||||
bli_get_range_b2t( thr, all_start, all_end, block_factor,
|
||||
start, end );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -127,8 +127,40 @@ typedef struct thrinfo_s thrinfo_t;
|
||||
#define thread_obarrier( thread ) bli_barrier( thread->ocomm, thread->ocomm_id )
|
||||
#define thread_ibarrier( thread ) bli_barrier( thread->icomm, thread->icomm_id )
|
||||
|
||||
void bli_get_range( void* thread, dim_t all_start, dim_t all_end, dim_t block_factor, dim_t* start, dim_t* end );
|
||||
void bli_get_range_weighted( void* thr, dim_t all_start, dim_t all_end, dim_t block_factor, bool_t forward, dim_t* start, dim_t* end);
|
||||
void bli_get_range( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor,
|
||||
bool_t handle_edge_low,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_l2r( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_r2l( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_t2b( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_b2t( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor,
|
||||
dim_t* start, dim_t* end );
|
||||
|
||||
void bli_get_range_weighted( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor, uplo_t uplo,
|
||||
bool_t handle_edge_low,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_weighted_l2r( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor, uplo_t uplo,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_weighted_r2l( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor, uplo_t uplo,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_weighted_t2b( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor, uplo_t uplo,
|
||||
dim_t* start, dim_t* end );
|
||||
void bli_get_range_weighted_b2t( void* thr, dim_t all_start, dim_t all_end,
|
||||
dim_t block_factor, uplo_t uplo,
|
||||
dim_t* start, dim_t* end );
|
||||
|
||||
thrinfo_t* bli_create_thread_info( thread_comm_t* ocomm, dim_t ocomm_id,
|
||||
thread_comm_t* icomm, dim_t icomm_id,
|
||||
dim_t n_way, dim_t work_id );
|
||||
|
||||
@@ -378,12 +378,12 @@
|
||||
} \
|
||||
}
|
||||
|
||||
#define bli_obj_apply_trans( trans, obj )\
|
||||
#define bli_obj_apply_trans( trans, obj ) \
|
||||
{ \
|
||||
(obj).info = ( (obj).info ^ (trans) ); \
|
||||
}
|
||||
|
||||
#define bli_obj_apply_conj( conjval, obj )\
|
||||
#define bli_obj_apply_conj( conjval, obj ) \
|
||||
{ \
|
||||
(obj).info = ( (obj).info ^ (conjval) ); \
|
||||
}
|
||||
@@ -395,21 +395,25 @@
|
||||
\
|
||||
((obj).root)
|
||||
|
||||
#define bli_obj_root_uplo( obj ) \
|
||||
\
|
||||
bli_obj_uplo( *bli_obj_root( obj ) )
|
||||
|
||||
#define bli_obj_root_is_general( obj ) \
|
||||
\
|
||||
bli_obj_is_general( *bli_obj_root( obj ) ) \
|
||||
bli_obj_is_general( *bli_obj_root( obj ) )
|
||||
|
||||
#define bli_obj_root_is_hermitian( obj ) \
|
||||
\
|
||||
bli_obj_is_hermitian( *bli_obj_root( obj ) ) \
|
||||
bli_obj_is_hermitian( *bli_obj_root( obj ) )
|
||||
|
||||
#define bli_obj_root_is_symmetric( obj ) \
|
||||
\
|
||||
bli_obj_is_symmetric( *bli_obj_root( obj ) ) \
|
||||
bli_obj_is_symmetric( *bli_obj_root( obj ) )
|
||||
|
||||
#define bli_obj_root_is_triangular( obj ) \
|
||||
\
|
||||
bli_obj_is_triangular( *bli_obj_root( obj ) ) \
|
||||
bli_obj_is_triangular( *bli_obj_root( obj ) )
|
||||
|
||||
#define bli_obj_root_is_herm_or_symm( obj ) \
|
||||
\
|
||||
@@ -418,11 +422,11 @@
|
||||
|
||||
#define bli_obj_root_is_upper( obj ) \
|
||||
\
|
||||
bli_obj_is_upper( *bli_obj_root( obj ) ) \
|
||||
bli_obj_is_upper( *bli_obj_root( obj ) )
|
||||
|
||||
#define bli_obj_root_is_lower( obj ) \
|
||||
\
|
||||
bli_obj_is_lower( *bli_obj_root( obj ) ) \
|
||||
bli_obj_is_lower( *bli_obj_root( obj ) )
|
||||
|
||||
|
||||
// Root matrix modification
|
||||
|
||||
Reference in New Issue
Block a user