Make local copy of user's rntm_t in level-3 ops.

Details:
- In the case that the caller passes in a non-NULL rntm_t pointer into
  one of the expert APIs for a level-3 operation (e.g. bli_gemm_ex()),
  make a local copy of the rntm_t and use the address of that local copy
  in all subsequent execution (which may change the contents of the
  rntm_t). This prevents a potentially confusing situation whereby a
  user-initialized rntm_t is used once (in, say, gemm), and then found
  by the user to be in a different state before it is used a second
  time.
This commit is contained in:
Field G. Van Zee
2018-12-20 19:38:11 -06:00
parent 0476f706b9
commit 61441b24f3
3 changed files with 56 additions and 28 deletions

View File

@@ -94,9 +94,11 @@ void PASTEMAC(opname,imeth) \
cntx_t cntx_l; \
if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Some induced methods execute in multiple "stages". */ \
for ( i = 0; i < nstage; ++i ) \
@@ -185,9 +187,11 @@ void PASTEMAC(opname,imeth) \
cntx_t cntx_l; \
if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Some induced methods execute in multiple "stages". */ \
for ( i = 0; i < nstage; ++i ) \
@@ -274,9 +278,11 @@ void PASTEMAC(opname,imeth) \
cntx_t cntx_l; \
if ( ind == BLIS_3MH || ind == BLIS_4MH ) { cntx_l = *cntx; cntx = &cntx_l; } \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Some induced methods execute in multiple "stages". */ \
for ( i = 0; i < nstage; ++i ) \
@@ -348,9 +354,11 @@ void PASTEMAC(opname,imeth) \
_cntx_init() function. */ \
cntx = bli_gks_query_ind_cntx( ind, dt ); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Some induced methods execute in multiple "stages". */ \
for ( i = 0; i < nstage; ++i ) \
@@ -408,9 +416,11 @@ void PASTEMAC(opname,imeth) \
_cntx_init() function. */ \
cntx = bli_gks_query_ind_cntx( ind, dt ); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
{ \
/* NOTE: trsm cannot be implemented via any induced method that

View File

@@ -56,9 +56,11 @@ void PASTEMAC(opname,imeth) \
num_t dt = bli_obj_dt( c ); \
PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
func( alpha, a, b, beta, c, cntx, rntm ); \
}
@@ -90,9 +92,11 @@ void PASTEMAC(opname,imeth) \
num_t dt = bli_obj_dt( c ); \
PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
func( side, alpha, a, b, beta, c, cntx, rntm ); \
}
@@ -122,9 +126,11 @@ void PASTEMAC(opname,imeth) \
num_t dt = bli_obj_dt( c ); \
PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
func( alpha, a, beta, c, cntx, rntm ); \
}
@@ -153,9 +159,11 @@ void PASTEMAC(opname,imeth) \
num_t dt = bli_obj_dt( b ); \
PASTECH(opname,_oft) func = PASTEMAC(opname,ind_get_avail)( dt ); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
func( side, alpha, a, b, cntx, rntm ); \
}

View File

@@ -61,9 +61,11 @@ void PASTEMAC(opname,imeth) \
/* Obtain a valid (native) context from the gks if necessary. */ \
if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Invoke the operation's front end. */ \
PASTEMAC(opname,_front) \
@@ -103,9 +105,11 @@ void PASTEMAC(opname,imeth) \
/* Obtain a valid (native) context from the gks if necessary. */ \
if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Invoke the operation's front end. */ \
PASTEMAC(opname,_front) \
@@ -139,9 +143,11 @@ void PASTEMAC(opname,imeth) \
/* Obtain a valid (native) context from the gks if necessary. */ \
if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Invoke the operation's front end. */ \
PASTEMAC(opname,_front) \
@@ -174,9 +180,11 @@ void PASTEMAC(opname,imeth) \
/* Obtain a valid (native) context from the gks if necessary. */ \
if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Invoke the operation's front end. */ \
PASTEMAC(opname,_front) \
@@ -208,9 +216,11 @@ void PASTEMAC(opname,imeth) \
/* Obtain a valid (native) context from the gks if necessary. */ \
if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
/* Initialize a local runtime with global settings if necessary. */ \
/* Initialize a local runtime with global settings if necessary. Note
that in the case that a runtime is passed in, we make a local copy. */ \
rntm_t rntm_l; \
if ( rntm == NULL ) { rntm = &rntm_l; bli_thread_init_rntm( rntm ); } \
if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } \
else { rntm_l = *rntm; rntm = &rntm_l; } \
\
/* Invoke the operation's front end. */ \
PASTEMAC(opname,_front) \