diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 942e4010c..fe8b85a98 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -240,11 +240,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/gemm/bli_gemm_ker_var5.c b/frame/3/gemm/bli_gemm_ker_var5.c index 35552d36d..ba3f3e6eb 100644 --- a/frame/3/gemm/bli_gemm_ker_var5.c +++ b/frame/3/gemm/bli_gemm_ker_var5.c @@ -261,7 +261,7 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ } \ diff --git a/frame/3/herk/bli_herk_l_ker_var2.c b/frame/3/herk/bli_herk_l_ker_var2.c index fe858fb8e..0355ce19f 100644 --- a/frame/3/herk/bli_herk_l_ker_var2.c +++ b/frame/3/herk/bli_herk_l_ker_var2.c @@ -274,11 +274,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/herk/bli_herk_u_ker_var2.c b/frame/3/herk/bli_herk_u_ker_var2.c index 6faf5a38b..d6fad2543 100644 --- a/frame/3/herk/bli_herk_u_ker_var2.c +++ b/frame/3/herk/bli_herk_u_ker_var2.c @@ -274,11 +274,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index a81a9e211..2581380d9 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -274,11 +274,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + k_a1011 * PACKMR; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -329,11 +329,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index b6d1a6db2..03e7522fb 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -282,11 +282,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + k_a1112 * PACKMR; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -337,11 +337,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 681e980a9..b2a311b79 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -282,11 +282,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + k_b1121 * PACKNR; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -344,11 +344,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index cf46ae53a..7b482b374 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -282,11 +282,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + k_b0111 * PACKNR; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -344,11 +344,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index 372caf6eb..abea139e3 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -286,11 +286,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + k_a1011 * PACKMR; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -338,11 +338,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 95027e450..cf589a793 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -297,11 +297,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + k_a1112 * PACKMR; \ - if ( ib == m_iter - 1 ) \ + if ( bli_is_last_iter( ib, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -349,11 +349,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( ib == m_iter - 1 ) \ + if ( bli_is_last_iter( ib, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 1751b39f1..5be5ace9a 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -302,11 +302,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + k_b1121 * PACKNR; \ - if ( jb == n_iter - 1 ) \ + if ( bli_is_last_iter( jb, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -362,11 +362,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( jb == n_iter - 1 ) \ + if ( bli_is_last_iter( jb, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 420a13997..d66e459ec 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -296,11 +296,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + k_b0111 * PACKNR; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ @@ -356,11 +356,11 @@ void PASTEMAC(ch,varname)( \ \ /* Compute the addresses of the next panels of A and B. */ \ a2 = a1 + rstep_a; \ - if ( i == m_iter - 1 ) \ + if ( bli_is_last_iter( i, m_iter ) ) \ { \ a2 = a_cast; \ b2 = b1 + cstep_b; \ - if ( j == n_iter - 1 ) \ + if ( bli_is_last_iter( j, n_iter ) ) \ b2 = b_cast; \ } \ \ diff --git a/frame/include/bli_param_macro_defs.h b/frame/include/bli_param_macro_defs.h index 4ef2fd6f6..186df305c 100644 --- a/frame/include/bli_param_macro_defs.h +++ b/frame/include/bli_param_macro_defs.h @@ -485,6 +485,10 @@ \ ( i1 != 0 || left == 0 ) +#define bli_is_last_iter( i1, iter ) \ +\ + ( i1 == iter - 1 ) + // packbuf_t-related diff --git a/testsuite/src/test_gemm_ukr.c b/testsuite/src/test_gemm_ukr.c index 9200d667e..8ca519b0e 100644 --- a/testsuite/src/test_gemm_ukr.c +++ b/testsuite/src/test_gemm_ukr.c @@ -363,12 +363,13 @@ void libblis_test_gemm_ukr_check( obj_t* alpha, #define FUNCPTR_T gemm_ukr_fp typedef void (*FUNCPTR_T)( - dim_t k, - void* alpha, - void* a, - void* b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c + dim_t k, + void* alpha, + void* a, + void* b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + auxinfo_t* data ); static FUNCPTR_T GENARRAY(ftypes,gemm_ukr); @@ -396,8 +397,20 @@ void bli_gemm_ukr( obj_t* alpha, void* buf_beta = bli_obj_buffer_for_1x1( dt, *beta ); + inc_t ps_a = bli_obj_panel_stride( *a ); + inc_t ps_b = bli_obj_panel_stride( *b ); + FUNCPTR_T f; + auxinfo_t data; + + + // Fill the auxinfo_t struct in case the micro-kernel uses it. + bli_auxinfo_set_next_a( buf_a, data ); + bli_auxinfo_set_next_b( buf_b, data ); + bli_auxinfo_set_ps_a( ps_a, data ); + bli_auxinfo_set_ps_b( ps_b, data ); + // Index into the type combination array to extract the correct // function pointer. f = ftypes[dt]; @@ -408,7 +421,8 @@ void bli_gemm_ukr( obj_t* alpha, buf_a, buf_b, buf_beta, - buf_c, rs_c, cs_c ); + buf_c, rs_c, cs_c, + &data ); } @@ -416,12 +430,13 @@ void bli_gemm_ukr( obj_t* alpha, #define GENTFUNC( ctype, ch, varname, ukrname ) \ \ void PASTEMAC(ch,varname)( \ - dim_t k, \ - void* alpha, \ - void* a, \ - void* b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c \ + dim_t k, \ + void* alpha, \ + void* a, \ + void* b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* data \ ) \ { \ PASTEMAC(ch,ukrname)( k, \ @@ -430,7 +445,7 @@ void PASTEMAC(ch,varname)( \ b, \ beta, \ c, rs_c, cs_c, \ - NULL ); \ + data ); \ } INSERT_GENTFUNC_BASIC( gemm_ukr, GEMM_UKERNEL ) diff --git a/testsuite/src/test_gemm_ukr.h b/testsuite/src/test_gemm_ukr.h index 284fc796f..94b33312c 100644 --- a/testsuite/src/test_gemm_ukr.h +++ b/testsuite/src/test_gemm_ukr.h @@ -48,12 +48,13 @@ void bli_gemm_ukr( obj_t* alpha, #define GENTPROT( ctype, ch, varname ) \ \ void PASTEMAC(ch,varname)( \ - dim_t k, \ - void* alpha, \ - void* a, \ - void* b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c \ + dim_t k, \ + void* alpha, \ + void* a, \ + void* b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* data \ ); INSERT_GENTPROT_BASIC( gemm_ukr ) diff --git a/testsuite/src/test_gemmtrsm_ukr.c b/testsuite/src/test_gemmtrsm_ukr.c index f8d514665..2cf551a80 100644 --- a/testsuite/src/test_gemmtrsm_ukr.c +++ b/testsuite/src/test_gemmtrsm_ukr.c @@ -508,48 +508,53 @@ void bli_gemmtrsm_ukr( obj_t* alpha, { dim_t k = bli_obj_width( *a1x ); - num_t dt = bli_obj_datatype( *c11 ); + num_t dt = bli_obj_datatype( *c11 ); - void* buf_a1x = bli_obj_buffer_at_off( *a1x ); + void* buf_a1x = bli_obj_buffer_at_off( *a1x ); - void* buf_a11 = bli_obj_buffer_at_off( *a11 ); + void* buf_a11 = bli_obj_buffer_at_off( *a11 ); - void* buf_bx1 = bli_obj_buffer_at_off( *bx1 ); + void* buf_bx1 = bli_obj_buffer_at_off( *bx1 ); - void* buf_b11 = bli_obj_buffer_at_off( *b11 ); + void* buf_b11 = bli_obj_buffer_at_off( *b11 ); - void* buf_c11 = bli_obj_buffer_at_off( *c11 ); - inc_t rs_c = bli_obj_row_stride( *c11 ); - inc_t cs_c = bli_obj_col_stride( *c11 ); + void* buf_c11 = bli_obj_buffer_at_off( *c11 ); + inc_t rs_c = bli_obj_row_stride( *c11 ); + inc_t cs_c = bli_obj_col_stride( *c11 ); void* buf_alpha = bli_obj_buffer_for_1x1( dt, *alpha ); - FUNCPTR_T f; + inc_t ps_a = bli_obj_panel_stride( *a1x ); + inc_t ps_b = bli_obj_panel_stride( *bx1 ); + + FUNCPTR_T f; auxinfo_t data; // Fill the auxinfo_t struct in case the micro-kernel uses it. - if ( bli_obj_is_lower( *a11 ) ) { bli_auxinfo_set_next_a( buf_a1x, data ); } - else { bli_auxinfo_set_next_a( buf_a11, data ); } + if ( bli_obj_is_lower( *a11 ) ) + { bli_auxinfo_set_next_a( buf_a1x, data ); } + else + { bli_auxinfo_set_next_a( buf_a11, data ); } bli_auxinfo_set_next_b( buf_bx1, data ); - // STILL NEED TO FILL IN PANEL STRIDE FIELDS! + bli_auxinfo_set_ps_a( ps_a, data ); + bli_auxinfo_set_ps_b( ps_b, data ); + // Index into the type combination array to extract the correct + // function pointer. + if ( bli_obj_is_lower( *a11 ) ) f = ftypes_l[dt]; + else f = ftypes_u[dt]; - // Index into the type combination array to extract the correct - // function pointer. - if ( bli_obj_is_lower( *a11 ) ) f = ftypes_l[dt]; - else f = ftypes_u[dt]; - - // Invoke the function. - f( k, + // Invoke the function. + f( k, buf_alpha, buf_a1x, buf_a11, buf_bx1, - buf_b11, - buf_c11, rs_c, cs_c, + buf_b11, + buf_c11, rs_c, cs_c, &data ); } @@ -568,13 +573,13 @@ void PASTEMAC(ch,varname)( \ auxinfo_t* data \ ) \ { \ - PASTEMAC(ch,ukrname)( k, \ - alpha, \ - a1x, \ - a11, \ - bx1, \ - b11, \ - c11, rs_c, cs_c, \ + PASTEMAC(ch,ukrname)( k, \ + alpha, \ + a1x, \ + a11, \ + bx1, \ + b11, \ + c11, rs_c, cs_c, \ data ); \ } diff --git a/testsuite/src/test_trsm_ukr.c b/testsuite/src/test_trsm_ukr.c index bf2f6c9a9..b3409baca 100644 --- a/testsuite/src/test_trsm_ukr.c +++ b/testsuite/src/test_trsm_ukr.c @@ -391,37 +391,40 @@ void bli_trsm_ukr( obj_t* a, obj_t* b, obj_t* c ) { - num_t dt = bli_obj_datatype( *c ); + num_t dt = bli_obj_datatype( *c ); - void* buf_a = bli_obj_buffer_at_off( *a ); + void* buf_a = bli_obj_buffer_at_off( *a ); - void* buf_b = bli_obj_buffer_at_off( *b ); + void* buf_b = bli_obj_buffer_at_off( *b ); - void* buf_c = bli_obj_buffer_at_off( *c ); - inc_t rs_c = bli_obj_row_stride( *c ); - inc_t cs_c = bli_obj_col_stride( *c ); + void* buf_c = bli_obj_buffer_at_off( *c ); + inc_t rs_c = bli_obj_row_stride( *c ); + inc_t cs_c = bli_obj_col_stride( *c ); - FUNCPTR_T f; + inc_t ps_a = bli_obj_panel_stride( *a ); + inc_t ps_b = bli_obj_panel_stride( *b ); + + FUNCPTR_T f; auxinfo_t data; // Fill the auxinfo_t struct in case the micro-kernel uses it. - bli_auxinfo_set_next_a( buf_a, data ); + bli_auxinfo_set_next_a( buf_a, data ); bli_auxinfo_set_next_b( buf_b, data ); - // STILL NEED TO FILL IN PANEL STRIDE FIELDS! + bli_auxinfo_set_ps_a( ps_a, data ); + bli_auxinfo_set_ps_b( ps_b, data ); + // Index into the type combination array to extract the correct + // function pointer. + if ( bli_obj_is_lower( *a ) ) f = ftypes_l[dt]; + else f = ftypes_u[dt]; - // Index into the type combination array to extract the correct - // function pointer. - if ( bli_obj_is_lower( *a ) ) f = ftypes_l[dt]; - else f = ftypes_u[dt]; - - // Invoke the function. - f( buf_a, - buf_b, - buf_c, rs_c, cs_c, + // Invoke the function. + f( buf_a, + buf_b, + buf_c, rs_c, cs_c, &data ); } @@ -436,9 +439,9 @@ void PASTEMAC(ch,varname)( \ auxinfo_t* data \ ) \ { \ - PASTEMAC(ch,ukrname)( a, \ - b, \ - c, rs_c, cs_c, \ + PASTEMAC(ch,ukrname)( a, \ + b, \ + c, rs_c, cs_c, \ data ); \ }