mirror of
https://github.com/amd/blis.git
synced 2026-05-11 09:39:59 +00:00
Added support for parallelism in gemm micro-kernel
This commit is contained in:
@@ -45,7 +45,8 @@ typedef void (*FUNCPTR_T)(
|
||||
void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b,
|
||||
void* beta,
|
||||
void* c, inc_t rs_c, inc_t cs_c,
|
||||
void* gemm_ukr
|
||||
void* gemm_ukr,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2);
|
||||
@@ -118,7 +119,8 @@ void bli_gemm_ker_var2( obj_t* a,
|
||||
buf_b, rs_b, pd_b, ps_b,
|
||||
buf_beta,
|
||||
buf_c, rs_c, cs_c,
|
||||
gemm_ukr );
|
||||
gemm_ukr,
|
||||
thread );
|
||||
}
|
||||
|
||||
|
||||
@@ -134,7 +136,8 @@ void PASTEMAC(ch,varname)( \
|
||||
void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
|
||||
void* beta, \
|
||||
void* c, inc_t rs_c, inc_t cs_c, \
|
||||
void* gemm_ukr \
|
||||
void* gemm_ukr, \
|
||||
thrinfo_t* thread \
|
||||
) \
|
||||
{ \
|
||||
/* Cast the micro-kernel address to its function pointer type. */ \
|
||||
@@ -214,18 +217,21 @@ void PASTEMAC(ch,varname)( \
|
||||
bli_auxinfo_set_ps_a( ps_a, aux ); \
|
||||
bli_auxinfo_set_ps_b( ps_b, aux ); \
|
||||
\
|
||||
b1 = b_cast; \
|
||||
c1 = c_cast; \
|
||||
thrinfo_t* caucus = thread_sub_caucus( thread ); \
|
||||
dim_t l2_num_threads = thread_num_caucuses( thread ); \
|
||||
dim_t l2_thread_id = thread_caucus_id( thread ); \
|
||||
dim_t l1_num_threads = thread_num_caucuses( caucus ); \
|
||||
dim_t l1_thread_id = thread_caucus_id( caucus ); \
|
||||
\
|
||||
/* Loop over the n dimension (NR columns at a time). */ \
|
||||
for ( j = 0; j < n_iter; ++j ) \
|
||||
for ( j = l2_thread_id; j < n_iter; j += l2_num_threads ) \
|
||||
{ \
|
||||
ctype* restrict a1; \
|
||||
ctype* restrict c11; \
|
||||
ctype* restrict b2; \
|
||||
\
|
||||
a1 = a_cast; \
|
||||
c11 = c1; \
|
||||
\
|
||||
b1 = b_cast + j * cstep_b; \
|
||||
c1 = c_cast + j * cstep_c; \
|
||||
\
|
||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
||||
\
|
||||
@@ -233,9 +239,12 @@ void PASTEMAC(ch,varname)( \
|
||||
b2 = b1; \
|
||||
\
|
||||
/* Loop over the m dimension (MR rows at a time). */ \
|
||||
for ( i = 0; i < m_iter; ++i ) \
|
||||
for ( i = l1_thread_id; i < m_iter; i += l1_num_threads ) \
|
||||
{ \
|
||||
ctype* restrict a2; \
|
||||
\
|
||||
a1 = a_cast + i * rstep_a; \
|
||||
c11 = c1 + i * rstep_c; \
|
||||
\
|
||||
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
|
||||
\
|
||||
@@ -283,13 +292,7 @@ void PASTEMAC(ch,varname)( \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
c11 += rstep_c; \
|
||||
} \
|
||||
\
|
||||
b1 += cstep_b; \
|
||||
c1 += cstep_c; \
|
||||
} \
|
||||
\
|
||||
/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
|
||||
|
||||
@@ -58,7 +58,8 @@ void PASTEMAC(ch,varname)( \
|
||||
void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
|
||||
void* beta, \
|
||||
void* c, inc_t rs_c, inc_t cs_c, \
|
||||
void* gemm_ukr \
|
||||
void* gemm_ukr, \
|
||||
thrinfo_t* thread \
|
||||
);
|
||||
|
||||
INSERT_GENTPROT_BASIC( gemm_ker_var2 )
|
||||
|
||||
Reference in New Issue
Block a user