diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index b3964856b..dd73d91b4 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1700,8 +1700,7 @@ void bli_thread_init_rntm_from_env // by bli_init_once(). bool auto_factor = FALSE; - dim_t active_level, max_levels, nt; - dim_t jc, pc, ic, jr, ir; + dim_t jc, pc, ic, jr, ir, nt; #ifdef BLIS_ENABLE_MULTITHREADING @@ -1772,8 +1771,8 @@ void bli_thread_init_rntm_from_env bli_rntm_set_blis_mt_only(FALSE, rntm); #ifdef BLIS_ENABLE_OPENMP - active_level = omp_get_active_level(); - max_levels = omp_get_max_active_levels(); + dim_t active_level = omp_get_active_level(); + dim_t max_levels = omp_get_max_active_levels(); if ( active_level < max_levels ) { nt = omp_get_max_threads(); @@ -1847,34 +1846,26 @@ void bli_thread_update_rntm_from_env rntm_t* rntm ) { + // Update tl_rntm for this user thread from runtime environment and + // current status of global_rntm. Must do this every time, in case + // global_rntm has been updated by blis-specific threading function calls. + // NOTE: We don't need to acquire the global_rntm_mutex here because this // function is updating the thread local tl_rntm (not global_rntm). - bool auto_factor = FALSE; - dim_t active_level, max_levels, nt; - dim_t jc, pc, ic, jr, ir; - bool blis_mt; - - // Set threading related parts of tl_rntm from global_rntm. - // Must do this every time, in case global_rntm has been updated by - // blis-specific threading function calls. + bool auto_factor = FALSE; + dim_t jc, pc, ic, jr, ir, nt; + bool blis_mt; + // Extract threading data from global_rntm. nt = bli_rntm_num_threads( &global_rntm ); - jc = bli_rntm_jc_ways( &global_rntm ); pc = bli_rntm_pc_ways( &global_rntm ); ic = bli_rntm_ic_ways( &global_rntm ); jr = bli_rntm_jr_ways( &global_rntm ); ir = bli_rntm_ir_ways( &global_rntm ); - blis_mt = bli_rntm_blis_mt( &global_rntm ); - bli_rntm_set_num_threads_only( nt, rntm ); - bli_rntm_set_ways_only( jc, pc, ic, jr, ir, rntm ); - bli_rntm_set_blis_mt_only( blis_mt, rntm ); - - // Update tl_rntm from runtime environment for this user thread. - #ifdef BLIS_ENABLE_MULTITHREADING // Environment variables BLIS_NUM_THREADS and BLIS_*_NT have been read @@ -1931,48 +1922,59 @@ void bli_thread_update_rntm_from_env // // OMP_NUM_THREADS environment variable is applicable only when OpenMP is enabled. - // If any BLIS_*_NT environment variable was set, then we ignore the - // value of BLIS_NUM_THREADS or OMP_NUM_THREADS and use the - // BLIS_*_NT values instead (with unset variables being treated as if - // they contained 1). - if ( jc != -1 || pc != -1 || ic != -1 || jr != -1 || ir != -1 ) + if(blis_mt) { - if ( jc == -1 ) jc = 1; - if ( pc == -1 ) pc = 1; - if ( ic == -1 ) ic = 1; - if ( jr == -1 ) jr = 1; - if ( ir == -1 ) ir = 1; + // BLIS threading env vars and/or APIs have been used. - // Unset the value for nt. - nt = -1; + // If any BLIS_*_NT environment variable was set, then we ignore the + // value of BLIS_NUM_THREADS or OMP_NUM_THREADS and use the + // BLIS_*_NT values instead (with unset variables being treated as if + // they contained 1). + if ( jc != -1 || pc != -1 || ic != -1 || jr != -1 || ir != -1 ) + { + if ( jc == -1 ) jc = 1; + if ( pc == -1 ) pc = 1; + if ( ic == -1 ) ic = 1; + if ( jr == -1 ) jr = 1; + if ( ir == -1 ) ir = 1; + + // Unset the value for nt. + nt = -1; + } + +#ifdef BLIS_ENABLE_OPENMP + // If call is not from an active OpenMP level, then it will be + // serial irrespective of BLIS threading settings. + // Reminder that we are setting values here for tl_rntm, thus + // BLIS threading settings remain unchanged in global_rntm for + // consideration in future calls. + dim_t active_level = omp_get_active_level(); + dim_t max_levels = omp_get_max_active_levels(); + if ( active_level >= max_levels ) + { + nt = -1; + jc = pc = ic = jr = ir = 1; + } +#endif - // Ensure blis_mt is set to TRUE. - // (It should already be from sync with global_rntm above). - bli_rntm_set_blis_mt_only(TRUE, rntm); } else { - // Check if blis_mt is already true, e.g. from prior call to bli_thread_set_num_threads() - // or from BLIS_NUM_THREADS being set. - if(blis_mt) - { - nt = bli_rntm_num_threads( rntm ); - } - else - { + + // BLIS threading env vars and/or APIs have not been used. + #ifdef BLIS_ENABLE_OPENMP - active_level = omp_get_active_level(); - max_levels = omp_get_max_active_levels(); - if ( active_level < max_levels ) - { - nt = omp_get_max_threads(); - } else { - nt = 1; - } -#else - nt = 1; -#endif + dim_t active_level = omp_get_active_level(); + dim_t max_levels = omp_get_max_active_levels(); + if ( active_level < max_levels ) + { + nt = omp_get_max_threads(); + } else { + nt = 1; } +#else + nt = 1; +#endif } // By this time, one of the following conditions holds: @@ -1998,6 +2000,7 @@ void bli_thread_update_rntm_from_env bli_rntm_set_auto_factor_only( auto_factor, rntm ); bli_rntm_set_num_threads_only( nt, rntm ); bli_rntm_set_ways_only( jc, pc, ic, jr, ir, rntm ); + bli_rntm_set_blis_mt_only( blis_mt, rntm ); #ifdef PRINT_THREADING printf( "bli_thread_update_rntm_from_env(): tl_rntm\n" );