diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index 0d31f656b..137ffaf07 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -301,6 +301,8 @@ void PASTEMAC(ch,varname)( \ /* Loop over the n dimension (NR columns at a time). */ \ for ( j = 0; j < n_iter; ++j ) \ { \ + if( trsm_my_iter( j, thread ) ) { \ +\ ctype* restrict a1; \ ctype* restrict c11; \ ctype* restrict b2; \ @@ -355,8 +357,9 @@ void PASTEMAC(ch,varname)( \ if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ - b2 = b1 + cstep_b; \ - if ( bli_is_last_iter( j, n_iter ) ) \ + b2 = b1; \ + /*if ( bli_is_last_iter( j, n_iter ) ) */\ + if ( j + thread_num_threads(thread) >= n_iter ) \ b2 = b_cast; \ } \ \ @@ -411,8 +414,9 @@ void PASTEMAC(ch,varname)( \ if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ - b2 = b1 + cstep_b; \ - if ( bli_is_last_iter( j, n_iter ) ) \ + b2 = b1; \ + /*if ( bli_is_last_iter( j, n_iter ) ) */\ + if ( j + thread_num_threads(thread) >= n_iter ) \ b2 = b_cast; \ } \ \ @@ -460,7 +464,7 @@ void PASTEMAC(ch,varname)( \ \ c11 += rstep_c; \ } \ -\ + } \ b1 += cstep_b; \ c1 += cstep_c; \ } \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 6d0efe5e8..8d09567cf 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -309,6 +309,8 @@ void PASTEMAC(ch,varname)( \ /* Loop over the n dimension (NR columns at a time). */ \ for ( j = 0; j < n_iter; ++j ) \ { \ + if( trsm_my_iter( j, thread ) ) { \ +\ ctype* restrict a1; \ ctype* restrict c11; \ ctype* restrict b2; \ @@ -365,8 +367,9 @@ void PASTEMAC(ch,varname)( \ if ( bli_is_last_iter( ib, m_iter ) ) \ { \ a2 = a_cast; \ - b2 = b1 + cstep_b; \ - if ( bli_is_last_iter( j, n_iter ) ) \ + b2 = b1; \ + /*if ( bli_is_last_iter( j, n_iter ) ) */\ + if ( j + thread_num_threads(thread) >= n_iter ) \ b2 = b_cast; \ } \ \ @@ -421,8 +424,9 @@ void PASTEMAC(ch,varname)( \ if ( bli_is_last_iter( ib, m_iter ) ) \ { \ a2 = a_cast; \ - b2 = b1 + cstep_b; \ - if ( bli_is_last_iter( j, n_iter ) ) \ + b2 = b1; \ + /*if ( bli_is_last_iter( j, n_iter ) ) */\ + if ( j + thread_num_threads(thread) >= n_iter ) \ b2 = b_cast; \ } \ \ @@ -470,7 +474,7 @@ void PASTEMAC(ch,varname)( \ \ c11 -= rstep_c; \ } \ -\ + } \ b1 += cstep_b; \ c1 += cstep_c; \ } \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 3bc951bd5..341aae3aa 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -372,6 +372,8 @@ void PASTEMAC(ch,varname)( \ ctype* restrict a11; \ ctype* restrict a12; \ ctype* restrict a2; \ +\ + if( trsm_my_iter( i, thread ) ){ \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ @@ -380,8 +382,9 @@ void PASTEMAC(ch,varname)( \ a12 = a1 + off_b21 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = a1 + rstep_a; \ - if ( bli_is_last_iter( i, m_iter ) ) \ + a2 = a1; \ + /*if ( bli_is_last_iter( i, m_iter ) ) */\ + if ( i + thread_num_threads(thread) >= m_iter ) \ { \ a2 = a_cast; \ b2 = b1 + k_b1121 * ss_b; \ @@ -425,7 +428,7 @@ void PASTEMAC(ch,varname)( \ ct, rs_ct, cs_ct, \ c11, rs_c, cs_c ); \ } \ -\ + } \ a1 += rstep_a; \ c11 += rstep_c; \ } \ @@ -436,12 +439,15 @@ void PASTEMAC(ch,varname)( \ for ( i = 0; i < m_iter; ++i ) \ { \ ctype* restrict a2; \ +\ + if( trsm_my_iter( i, thread ) ){ \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = a1 + rstep_a; \ - if ( bli_is_last_iter( i, m_iter ) ) \ + a2 = a1; \ + /*if ( bli_is_last_iter( i, m_iter ) ) */\ + if ( i + thread_num_threads(thread) >= m_iter ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ @@ -484,7 +490,7 @@ void PASTEMAC(ch,varname)( \ alpha2_cast, \ c11, rs_c, cs_c ); \ } \ -\ + } \ a1 += rstep_a; \ c11 += rstep_c; \ } \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 6711ba423..50a19672e 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -365,6 +365,8 @@ void PASTEMAC(ch,varname)( \ ctype* restrict a10; \ ctype* restrict a11; \ ctype* restrict a2; \ +\ + if( trsm_my_iter( i, thread ) ){ \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ @@ -373,8 +375,9 @@ void PASTEMAC(ch,varname)( \ a11 = a1 + off_b11 * PACKMR; \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = a1 + rstep_a; \ - if ( bli_is_last_iter( i, m_iter ) ) \ + a2 = a1; \ + /*if ( bli_is_last_iter( i, m_iter ) ) */\ + if ( i + thread_num_threads(thread) >= m_iter ) \ { \ a2 = a_cast; \ b2 = b1 + k_b0111 * ss_b; \ @@ -418,7 +421,7 @@ void PASTEMAC(ch,varname)( \ ct, rs_ct, cs_ct, \ c11, rs_c, cs_c ); \ } \ -\ + } \ a1 += rstep_a; \ c11 += rstep_c; \ } \ @@ -429,12 +432,15 @@ void PASTEMAC(ch,varname)( \ for ( i = 0; i < m_iter; ++i ) \ { \ ctype* restrict a2; \ +\ + if( trsm_my_iter( i, thread ) ){ \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ /* Compute the addresses of the next panels of A and B. */ \ - a2 = a1 + rstep_a; \ - if ( bli_is_last_iter( i, m_iter ) ) \ + a2 = a1; \ + /*if ( bli_is_last_iter( i, m_iter ) ) */\ + if ( i + thread_num_threads(thread) >= m_iter ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ @@ -477,7 +483,7 @@ void PASTEMAC(ch,varname)( \ alpha2_cast, \ c11, rs_c, cs_c ); \ } \ -\ + } \ a1 += rstep_a; \ c11 += rstep_c; \ } \ diff --git a/frame/3/trsm/bli_trsm_threading.c b/frame/3/trsm/bli_trsm_threading.c index 8d62a737b..73832f4f8 100644 --- a/frame/3/trsm/bli_trsm_threading.c +++ b/frame/3/trsm/bli_trsm_threading.c @@ -109,17 +109,20 @@ void bli_trsm_thrinfo_free_paths( trsm_thrinfo_t** threads, dim_t num ) trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( ) { - /* - dim_t jc_way = bli_read_nway_from_env( "BLIS_JC_NT" ); - dim_t kc_way = bli_read_nway_from_env( "BLIS_KC_NT" ); - dim_t ic_way = bli_read_nway_from_env( "BLIS_IC_NT" ); - dim_t jr_way = bli_read_nway_from_env( "BLIS_JR_NT" ); - dim_t ir_way = bli_read_nway_from_env( "BLIS_IR_NT" ); - */ +#ifdef BLIS_ENABLE_MULTITHREADING + dim_t jc_in = bli_read_nway_from_env( "BLIS_JC_NT" ); + /*dim_t kc_in = bli_read_nway_from_env( "BLIS_KC_NT" );*/ + dim_t ic_in = bli_read_nway_from_env( "BLIS_IC_NT" ); + dim_t jr_in = bli_read_nway_from_env( "BLIS_JR_NT" ); + dim_t ir_in = bli_read_nway_from_env( "BLIS_IR_NT" ); + + dim_t jr_way = jc_in * ic_in * jr_in * ir_in; +#else + dim_t jr_way = 1; +#endif dim_t jc_way = 1; dim_t kc_way = 1; dim_t ic_way = 1; - dim_t jr_way = 1; dim_t ir_way = 1; dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way; diff --git a/frame/3/trsm/bli_trsm_threading.h b/frame/3/trsm/bli_trsm_threading.h index ad841331e..8dab87d90 100644 --- a/frame/3/trsm/bli_trsm_threading.h +++ b/frame/3/trsm/bli_trsm_threading.h @@ -53,10 +53,7 @@ typedef struct trsm_thrinfo_s trsm_thrinfo_t; #define trsm_thread_sub_opackm( thread ) thread->opackm #define trsm_thread_sub_ipackm( thread ) thread->ipackm -#define trsm_r_ir_my_iter( index, thread ) ( index % thread->n_way == thread->work_id % thread->n_way ) -#define trsm_r_jr_my_iter( index, thread ) ( index % thread->n_way == thread->work_id % thread->n_way ) -#define trsm_l_ir_my_iter( index, thread ) ( index % thread->n_way == thread->work_id % thread->n_way ) -#define trsm_l_jr_my_iter( index, thread ) ( index % thread->n_way == thread->work_id % thread->n_way ) +#define trsm_my_iter( index, thread ) ( index % thread->n_way == thread->work_id % thread->n_way ) trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( ); void bli_trsm_thrinfo_free_paths( trsm_thrinfo_t** info, dim_t n_threads );