diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index ab5585d7c..4c01c9841 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -45,7 +45,8 @@ typedef void (*FUNCPTR_T)( void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, void* beta, void* c, inc_t rs_c, inc_t cs_c, - void* gemm_ukr + void* gemm_ukr, + thrinfo_t* thread ); static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2); @@ -118,7 +119,8 @@ void bli_gemm_ker_var2( obj_t* a, buf_b, rs_b, pd_b, ps_b, buf_beta, buf_c, rs_c, cs_c, - gemm_ukr ); + gemm_ukr, + thread ); } @@ -134,7 +136,8 @@ void PASTEMAC(ch,varname)( \ void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \ void* beta, \ void* c, inc_t rs_c, inc_t cs_c, \ - void* gemm_ukr \ + void* gemm_ukr, \ + thrinfo_t* thread \ ) \ { \ /* Cast the micro-kernel address to its function pointer type. */ \ @@ -214,18 +217,21 @@ void PASTEMAC(ch,varname)( \ bli_auxinfo_set_ps_a( ps_a, aux ); \ bli_auxinfo_set_ps_b( ps_b, aux ); \ \ - b1 = b_cast; \ - c1 = c_cast; \ + thrinfo_t* caucus = thread_sub_caucus( thread ); \ + dim_t l2_num_threads = thread_num_caucuses( thread ); \ + dim_t l2_thread_id = thread_caucus_id( thread ); \ + dim_t l1_num_threads = thread_num_caucuses( caucus ); \ + dim_t l1_thread_id = thread_caucus_id( caucus ); \ \ /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = 0; j < n_iter; ++j ) \ + for ( j = l2_thread_id; j < n_iter; j += l2_num_threads ) \ { \ ctype* restrict a1; \ ctype* restrict c11; \ ctype* restrict b2; \ -\ - a1 = a_cast; \ - c11 = c1; \ + \ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ \ n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ \ @@ -233,9 +239,12 @@ void PASTEMAC(ch,varname)( \ b2 = b1; \ \ /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = 0; i < m_iter; ++i ) \ + for ( i = l1_thread_id; i < m_iter; i += l1_num_threads ) \ { \ ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ \ m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ \ @@ -283,13 +292,7 @@ void PASTEMAC(ch,varname)( \ beta_cast, \ c11, rs_c, cs_c ); \ } \ -\ - a1 += rstep_a; \ - c11 += rstep_c; \ } \ -\ - b1 += cstep_b; \ - c1 += cstep_c; \ } \ \ /*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ diff --git a/frame/3/gemm/bli_gemm_ker_var2.h b/frame/3/gemm/bli_gemm_ker_var2.h index e41ee44be..71248819b 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.h +++ b/frame/3/gemm/bli_gemm_ker_var2.h @@ -58,7 +58,8 @@ void PASTEMAC(ch,varname)( \ void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \ void* beta, \ void* c, inc_t rs_c, inc_t cs_c, \ - void* gemm_ukr \ + void* gemm_ukr, \ + thrinfo_t* thread \ ); INSERT_GENTPROT_BASIC( gemm_ker_var2 )