diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c
index f6a64c40a..0e24d7d4a 100644
--- a/frame/compat/bla_axpy_amd.c
+++ b/frame/compat/bla_axpy_amd.c
@@ -398,11 +398,17 @@ void daxpy_blis_impl
             // Get the thread ID
             dim_t thread_id = omp_get_thread_num();
 
-            // Calculate the compute range for the current thread
+            // Get the actual number of threads spawned
+            dim_t nt_use = omp_get_num_threads();
+
+            /*
+              Calculate the compute range for the current thread
+              based on the actual number of threads spawned
+            */
             bli_thread_vector_partition
             (
               n_elem,
-              nt,
+              nt_use,
               &start, &length,
               thread_id
             );
diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c
index 3d0648d27..213fd14a4 100644
--- a/frame/compat/bla_dot_amd.c
+++ b/frame/compat/bla_dot_amd.c
@@ -466,11 +466,17 @@ double ddot_blis_impl
             // Get the thread ID
             dim_t thread_id = omp_get_thread_num();
 
-            // Calculate the compute range for the current thread
+            // Get the actual number of threads spawned
+            dim_t nt_use = omp_get_num_threads();
+
+            /*
+              Calculate the compute range for the current thread
+              based on the actual number of threads spawned
+            */
             bli_thread_vector_partition
             (
               n_elem,
-              nt,
+              nt_use,
               &start, &length,
               thread_id
             );
@@ -492,13 +498,18 @@ double ddot_blis_impl
             );
         }
 
-        // Accumulating the nt thread outputs to rho
-        for ( dim_t i = 0; i < nt; i++ )
-            rho += rho_temp[i];
-
-        // Releasing the allocated memory if it was allocated
-        if( bli_mem_is_alloc(&mem_buf_rho))
+        /*
+          Accumulate the values in rho_temp only when mem is allocated.
+          When the memory cannot be allocated rho_temp will point to
+          rho
+        */
+        if (bli_mem_is_alloc(&mem_buf_rho))
         {
+            // Accumulating the nt thread outputs to rho
+            for (dim_t i = 0; i < nt; i++)
+                rho += rho_temp[i];
+
+            // Releasing the allocated memory if it was allocated
             bli_membrk_release(&rntm, &mem_buf_rho);
         }
 #endif
diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c
index bec3515a0..041c1b6a8 100644
--- a/frame/compat/bla_scal_amd.c
+++ b/frame/compat/bla_scal_amd.c
@@ -383,11 +383,17 @@ void dscal_blis_impl
             // Get the thread ID
             dim_t thread_id = omp_get_thread_num();
 
-            // Calculate the compute range for the current thread
+            // Get the actual number of threads spawned
+            dim_t nt_use = omp_get_num_threads();
+
+            /*
+              Calculate the compute range for the current thread
+              based on the actual number of threads spawned
+            */
             bli_thread_vector_partition
             (
               n_elem,
-              nt,
+              nt_use,
               &start, &length,
               thread_id
             );
@@ -563,11 +569,17 @@ void zdscal_blis_impl
             // Get the thread ID
             dim_t thread_id = omp_get_thread_num();
 
-            // Calculate the compute range for the current thread
+            // Get the actual number of threads spawned
+            dim_t nt_use = omp_get_num_threads();
+
+            /*
+              Calculate the compute range for the current thread
+              based on the actual number of threads spawned
+            */
             bli_thread_vector_partition
             (
               n_elem,
-              nt,
+              nt_use,
               &start, &length,
               thread_id
             );