diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 673987bfd..29529924c 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -756,10 +756,10 @@ void bli_cntx_set_thrloop_from_env( opid_t l3_op, side_t side, cntx_t* cntx, #ifdef BLIS_ENABLE_MULTITHREADING - int nthread = bli_env_read_nway( "BLIS_NUM_THREADS", -1 ); + int nthread = bli_thread_get_env( "BLIS_NUM_THREADS", -1 ); if ( nthread == -1 ) - nthread = bli_env_read_nway( "OMP_NUM_THREADS", -1 ); + nthread = bli_thread_get_env( "OMP_NUM_THREADS", -1 ); if ( nthread < 1 ) nthread = 1; @@ -786,10 +786,10 @@ void bli_cntx_set_thrloop_from_env( opid_t l3_op, side_t side, cntx_t* cntx, pc = 1; - dim_t jc_env = bli_env_read_nway( "BLIS_JC_NT", -1 ); - dim_t ic_env = bli_env_read_nway( "BLIS_IC_NT", -1 ); - dim_t jr_env = bli_env_read_nway( "BLIS_JR_NT", -1 ); - dim_t ir_env = bli_env_read_nway( "BLIS_IR_NT", -1 ); + dim_t jc_env = bli_thread_get_env( "BLIS_JC_NT", -1 ); + dim_t ic_env = bli_thread_get_env( "BLIS_IC_NT", -1 ); + dim_t jr_env = bli_thread_get_env( "BLIS_JR_NT", -1 ); + dim_t ir_env = bli_thread_get_env( "BLIS_IR_NT", -1 ); if (jc_env != -1 || ic_env != -1 || jr_env != -1 || ir_env != -1) { diff --git a/frame/include/bli_system.h b/frame/include/bli_system.h index 99a63d550..5f54605d8 100644 --- a/frame/include/bli_system.h +++ b/frame/include/bli_system.h @@ -41,6 +41,7 @@ #include #include #include +#include // Determine if we are on a 64-bit or 32-bit architecture #if defined(_M_X64) || defined(__x86_64) || defined(__aarch64__) || \ diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 37ec94292..1dde88206 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1156,19 +1156,112 @@ void bli_partition_2x2( dim_t nthread, dim_t work1, dim_t work2, // ----------------------------------------------------------------------------- -// Some utilities -dim_t bli_env_read_nway( const char* env, dim_t fallback ) +dim_t bli_thread_get_env( const char* env, dim_t fallback ) { - dim_t num = fallback; - char* str = getenv( env ); + dim_t r_val; + char* str; + // Query the environment variable and store the result in str. + str = getenv( env ); + + // Set the return value based on the string obtained from getenv(). if ( str != NULL ) - { - num = strtol( str, NULL, 10 ); - } - return num; + { + // If there was no error, convert the string to an integer and + // prepare to return that integer. + r_val = strtol( str, NULL, 10 ); + } + else + { + // If there was an error, use the "fallback" as the return value. + r_val = fallback; + } + + return r_val; } +dim_t bli_thread_get_jc_nt( void ) +{ + return bli_thread_get_env( "BLIS_JC_NT", 1 ); +} + +dim_t bli_thread_get_ic_nt( void ) +{ + return bli_thread_get_env( "BLIS_IC_NT", 1 ); +} + +dim_t bli_thread_get_jr_nt( void ) +{ + return bli_thread_get_env( "BLIS_JR_NT", 1 ); +} + +dim_t bli_thread_get_ir_nt( void ) +{ + return bli_thread_get_env( "BLIS_IR_NT", 1 ); +} + +dim_t bli_thread_get_num_threads( void ) +{ + return bli_thread_get_env( "BLIS_NUM_THREADS", 1 ); +} + +void bli_thread_set_env( const char* env, dim_t value ) +{ + dim_t r_val; + char value_str[32]; + const char* fs_32 = "%u"; + const char* fs_64 = "%lu"; + + // Convert the string to an integer, but vary the format specifier + // depending on the integer type size. + if ( bli_info_get_int_type_size() == 32 ) sprintf( value_str, fs_32, value ); + else sprintf( value_str, fs_64, value ); + + // Set the environment variable using the string we just wrote to via + // sprintf(). (The 'TRUE' argument means we want to overwrite the current + // value if the environment variable already exists.) + r_val = setenv( env, value_str, TRUE ); + + // Check the return value in case something went horribly wrong. + if ( r_val == -1 ) + { + char err_str[128]; + + // Query the human-readable error string corresponding to errno. + strerror_r( errno, err_str, 128 ); + + // Print the error message. + bli_print_msg( err_str, __FILE__, __LINE__ ); + } +} + +void bli_thread_set_jc_nt( dim_t value ) +{ + bli_thread_set_env( "BLIS_JC_NT", value ); +} + +void bli_thread_set_ic_nt( dim_t value ) +{ + bli_thread_set_env( "BLIS_IC_NT", value ); +} + +void bli_thread_set_jr_nt( dim_t value ) +{ + bli_thread_set_env( "BLIS_JR_NT", value ); +} + +void bli_thread_set_ir_nt( dim_t value ) +{ + bli_thread_set_env( "BLIS_IR_NT", value ); +} + +void bli_thread_set_num_threads( dim_t value ) +{ + bli_thread_set_env( "BLIS_NUM_THREADS", value ); +} + +// ----------------------------------------------------------------------------- + dim_t bli_gcd( dim_t x, dim_t y ) { while ( y != 0 ) diff --git a/frame/thread/bli_thread.h b/frame/thread/bli_thread.h index 1998253cf..9092bc84d 100644 --- a/frame/thread/bli_thread.h +++ b/frame/thread/bli_thread.h @@ -164,6 +164,8 @@ void bli_l3_thread_decorator cntl_t* cntl ); +// ----------------------------------------------------------------------------- + // Factorization and partitioning prototypes typedef struct { @@ -178,8 +180,26 @@ dim_t bli_next_prime_factor(bli_prime_factors_t* factors); void bli_partition_2x2(dim_t nthread, dim_t work1, dim_t work2, dim_t* nt1, dim_t* nt2); -// Miscellaneous prototypes -dim_t bli_env_read_nway( const char* env, dim_t fallback ); +// ----------------------------------------------------------------------------- + +dim_t bli_thread_get_env( const char* env, dim_t fallback ); + +dim_t bli_thread_get_jc_nt( void ); +dim_t bli_thread_get_ic_nt( void ); +dim_t bli_thread_get_jr_nt( void ); +dim_t bli_thread_get_ir_nt( void ); +dim_t bli_thread_get_num_threads( void ); + +void bli_thread_set_env( const char* env, dim_t value ); + +void bli_thread_set_jc_nt( dim_t value ); +void bli_thread_set_ic_nt( dim_t value ); +void bli_thread_set_jr_nt( dim_t value ); +void bli_thread_set_ir_nt( dim_t value ); +void bli_thread_set_num_threads( dim_t value ); + +// ----------------------------------------------------------------------------- + dim_t bli_gcd( dim_t x, dim_t y ); dim_t bli_lcm( dim_t x, dim_t y ); dim_t bli_ipow( dim_t base, dim_t power );