Fix for DSCAL Multi-thread implementation

- In the existing implementation, when the actual number of threads
  spawned is different from the indicated number of threads for the
  parallel region, partial job is being executed.
- Fix added to identify actual number of threads spawned and
  allocate the work load to single thread in case of discrepancy
  in the number of threads spawned vs indicated.

AMD-Internal: [CPUPL-2761]
Change-Id: Ife36e6e4993bdcc5a506349b54b2177173866e32
This commit is contained in:
Arnav Sharma
2022-11-21 13:16:44 +05:30
committed by Arnav Sharma
parent 69be3b0557
commit 2bc2d11e8a

View File

@@ -283,37 +283,64 @@ void dscal_blis_impl
#pragma omp parallel num_threads( nt )
{
// The following conditions handle the optimal distribution of
// load among the threads.
// Say we have n0 = 50 & nt = 4.
// So we get 12 ( n0 / nt ) elements per thread along with 2
// remaining elements. Each of these remaining elements is given
// to the last threads, respectively.
// So, t0, t1, t2 and t3 gets 12, 12, 13 and 13 elements,
// respectively.
// Getting the actual number of threads that are spawned.
dim_t nt_real = omp_get_num_threads();
dim_t t_id = omp_get_thread_num();
dim_t npt, offset;
if ( t_id < ( nt - n_elem_rem ) )
// The actual number of threads spawned might be different
// from the predicted number of threads for which this parallel
// region is being generated. Thus, in such a case we are
// falling back to the Single-Threaded call.
if ( nt_real != nt )
{
npt = n_elem_per_thrd;
offset = t_id * npt * incx0;
// More than one thread can still be spawned but since we
// are falling back to the ST call, we are
// calling the kernel from thread 0 only.
if ( t_id == 0 )
{
bli_dscalv_zen_int10
(
BLIS_NO_CONJUGATE,
n0,
(double*) alpha,
x0, incx0,
NULL
);
}
}
else
{
npt = n_elem_per_thrd + 1;
offset = ( ( t_id * n_elem_per_thrd ) +
( t_id - ( nt - n_elem_rem ) ) ) * incx0;
}
// The following conditions handle the optimal distribution of
// load among the threads.
// Say we have n0 = 50 & nt = 4.
// So we get 12 ( n0 / nt ) elements per thread along with 2
// remaining elements. Each of these remaining elements is given
// to the last threads, respectively.
// So, t0, t1, t2 and t3 gets 12, 12, 13 and 13 elements,
// respectively.
dim_t npt, offset;
bli_dscalv_zen_int10
(
BLIS_NO_CONJUGATE,
npt,
(double*) alpha,
x0 + offset, incx0,
NULL
);
if ( t_id < ( nt - n_elem_rem ) )
{
npt = n_elem_per_thrd;
offset = t_id * npt * incx0;
}
else
{
npt = n_elem_per_thrd + 1;
offset = ( ( t_id * n_elem_per_thrd ) +
( t_id - ( nt - n_elem_rem ) ) ) * incx0;
}
bli_dscalv_zen_int10
(
BLIS_NO_CONJUGATE,
npt,
(double*) alpha,
x0 + offset, incx0,
NULL
);
}
}
}
#else