Added support for parallelism in the loops around the gemm micro-kernel (gemm_ker_var2)

This commit is contained in:
Tyler Smith
2014-02-27 16:29:46 -06:00
parent bfe214b633
commit e4738c48e0
2 changed files with 21 additions and 17 deletions

View File

@@ -45,7 +45,8 @@ typedef void (*FUNCPTR_T)(
void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b,
void* beta,
void* c, inc_t rs_c, inc_t cs_c,
void* gemm_ukr
void* gemm_ukr,
thrinfo_t* thread
);
static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2);
@@ -118,7 +119,8 @@ void bli_gemm_ker_var2( obj_t* a,
buf_b, rs_b, pd_b, ps_b,
buf_beta,
buf_c, rs_c, cs_c,
gemm_ukr );
gemm_ukr,
thread );
}
@@ -134,7 +136,8 @@ void PASTEMAC(ch,varname)( \
void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
void* beta, \
void* c, inc_t rs_c, inc_t cs_c, \
void* gemm_ukr \
void* gemm_ukr, \
thrinfo_t* thread \
) \
{ \
/* Cast the micro-kernel address to its function pointer type. */ \
@@ -214,18 +217,21 @@ void PASTEMAC(ch,varname)( \
bli_auxinfo_set_ps_a( ps_a, aux ); \
bli_auxinfo_set_ps_b( ps_b, aux ); \
\
b1 = b_cast; \
c1 = c_cast; \
thrinfo_t* caucus = thread_sub_caucus( thread ); \
dim_t l2_num_threads = thread_num_caucuses( thread ); \
dim_t l2_thread_id = thread_caucus_id( thread ); \
dim_t l1_num_threads = thread_num_caucuses( caucus ); \
dim_t l1_thread_id = thread_caucus_id( caucus ); \
\
/* Loop over the n dimension (NR columns at a time). */ \
for ( j = 0; j < n_iter; ++j ) \
for ( j = l2_thread_id; j < n_iter; j += l2_num_threads ) \
{ \
ctype* restrict a1; \
ctype* restrict c11; \
ctype* restrict b2; \
\
a1 = a_cast; \
c11 = c1; \
\
b1 = b_cast + j * cstep_b; \
c1 = c_cast + j * cstep_c; \
\
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
\
@@ -233,9 +239,12 @@ void PASTEMAC(ch,varname)( \
b2 = b1; \
\
/* Loop over the m dimension (MR rows at a time). */ \
for ( i = 0; i < m_iter; ++i ) \
for ( i = l1_thread_id; i < m_iter; i += l1_num_threads ) \
{ \
ctype* restrict a2; \
\
a1 = a_cast + i * rstep_a; \
c11 = c1 + i * rstep_c; \
\
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
\
@@ -283,13 +292,7 @@ void PASTEMAC(ch,varname)( \
beta_cast, \
c11, rs_c, cs_c ); \
} \
\
a1 += rstep_a; \
c11 += rstep_c; \
} \
\
b1 += cstep_b; \
c1 += cstep_c; \
} \
\
/*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \

View File

@@ -58,7 +58,8 @@ void PASTEMAC(ch,varname)( \
void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
void* beta, \
void* c, inc_t rs_c, inc_t cs_c, \
void* gemm_ukr \
void* gemm_ukr, \
thrinfo_t* thread \
);
INSERT_GENTPROT_BASIC( gemm_ker_var2 )