From 7b2f469d5465ed73b1ca88124bc9a1987388aa27 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Mon, 8 Sep 2014 14:49:50 -0500 Subject: [PATCH] Retired trmm_t control tree definitions, usage. Details: - Replaced all trmm_t control tree instances and usage with that of gemm_t. This change is similar to the recent retirement of the herk_t control tree. - Tweaked packm blocked variants so that the triangular code does NOT assume that k is a multiple of MR (when A is triangular) or NR (when B is triangular). This means that bottom-right micro-panels packed for trmm will have different zero-padding when k is not already a multiple of the relevant register blocksize. While this creates a seemingly arbitrary and unnecessary distinction between trmm and trsm packing, it actually allows trmm to be handled with one control tree, instead of one for left and one for right side cases. Furthermore, since only one tree is required, it can now be handled by the gemm tree, and thus the trmm control tree definitions can be disposed of entirely. - Tweaked trmm macro-kernels so that they do NOT inflate k up to a multiple of MR (when A is triangular) or NR (when B is triangular). - Misc. tweaks and cleanups to bli_packm_struc_cxk_4m.c and _3m.c, some of which are to facilitate above-mentioned changes whereby k is no longer required to be a multiple of register blocksize when packing triangular micro-panels. - Adjusted trmm3 according to above changes. - Retired trmm_t control tree creation/initialization functions. --- frame/1m/packm/bli_packm_blk_var1.c | 3 +- frame/1m/packm/bli_packm_blk_var2.c | 3 +- frame/1m/packm/bli_packm_struc_cxk.c | 11 ++- frame/1m/packm/bli_packm_struc_cxk_3m.c | 83 ++++++++--------- frame/1m/packm/bli_packm_struc_cxk_4m.c | 56 +++++------- frame/3/trmm/3m/bli_trmm3m.h | 1 - frame/3/trmm/3m/bli_trmm3m_cntl.c | 24 ++--- frame/3/trmm/3m/bli_trmm3m_entry.c | 6 +- frame/3/trmm/4m/bli_trmm4m.h | 1 - frame/3/trmm/4m/bli_trmm4m_cntl.c | 24 ++--- frame/3/trmm/4m/bli_trmm4m_entry.c | 6 +- frame/3/trmm/bli_trmm.h | 1 - frame/3/trmm/bli_trmm_blk_var1f.c | 4 +- frame/3/trmm/bli_trmm_blk_var1f.h | 2 +- frame/3/trmm/bli_trmm_blk_var2b.c | 6 +- frame/3/trmm/bli_trmm_blk_var2b.h | 2 +- frame/3/trmm/bli_trmm_blk_var2f.c | 6 +- frame/3/trmm/bli_trmm_blk_var2f.h | 2 +- frame/3/trmm/bli_trmm_blk_var3b.c | 4 +- frame/3/trmm/bli_trmm_blk_var3b.h | 2 +- frame/3/trmm/bli_trmm_blk_var3f.c | 4 +- frame/3/trmm/bli_trmm_blk_var3f.h | 2 +- frame/3/trmm/bli_trmm_check.c | 2 +- frame/3/trmm/bli_trmm_check.h | 2 +- frame/3/trmm/bli_trmm_entry.c | 6 +- frame/3/trmm/bli_trmm_front.c | 11 +-- frame/3/trmm/bli_trmm_front.h | 3 +- frame/3/trmm/bli_trmm_int.c | 4 +- frame/3/trmm/bli_trmm_int.h | 2 +- frame/3/trmm/bli_trmm_ll_ker_var2.c | 27 +++--- frame/3/trmm/bli_trmm_ll_ker_var2.h | 2 +- frame/3/trmm/bli_trmm_lu_ker_var2.c | 19 ++-- frame/3/trmm/bli_trmm_lu_ker_var2.h | 2 +- frame/3/trmm/bli_trmm_rl_ker_var2.c | 19 ++-- frame/3/trmm/bli_trmm_rl_ker_var2.h | 2 +- frame/3/trmm/bli_trmm_ru_ker_var2.c | 19 ++-- frame/3/trmm/bli_trmm_ru_ker_var2.h | 2 +- frame/3/trmm/{ => old}/bli_trmm_cntl.c | 106 ++-------------------- frame/3/trmm/{ => old}/bli_trmm_cntl.h | 2 +- frame/3/trmm/other/bli_trmm_ll_blk_var1.c | 2 +- frame/3/trmm/other/bli_trmm_ll_blk_var4.c | 4 +- frame/3/trmm/other/bli_trmm_lu_blk_var1.c | 2 +- frame/3/trmm/other/bli_trmm_lu_blk_var4.c | 4 +- frame/3/trmm3/3m/bli_trmm33m_entry.c | 6 +- frame/3/trmm3/4m/bli_trmm34m_entry.c | 6 +- frame/3/trmm3/bli_trmm3_entry.c | 6 +- frame/3/trmm3/bli_trmm3_front.c | 11 +-- frame/3/trmm3/bli_trmm3_front.h | 3 +- frame/cntl/bli_cntl_init.c | 6 -- 49 files changed, 195 insertions(+), 338 deletions(-) rename frame/3/trmm/{ => old}/bli_trmm_cntl.c (63%) rename frame/3/trmm/{ => old}/bli_trmm_cntl.h (98%) diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index 5f0ed38ae..43567996e 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -336,7 +336,8 @@ void PASTEMAC(ch,varname)( \ { \ panel_off_i = 0; \ panel_len_i = bli_abs( diagoffc_i ) + panel_dim_i; \ - panel_len_max_i = bli_abs( diagoffc_i ) + panel_dim_max; \ + panel_len_max_i = bli_min( bli_abs( diagoffc_i ) + panel_dim_max, \ + panel_len_max ); \ diagoffp_i = diagoffc_i; \ } \ else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ diff --git a/frame/1m/packm/bli_packm_blk_var2.c b/frame/1m/packm/bli_packm_blk_var2.c index 2c5a75e22..5fc3cedd5 100644 --- a/frame/1m/packm/bli_packm_blk_var2.c +++ b/frame/1m/packm/bli_packm_blk_var2.c @@ -391,7 +391,8 @@ void PASTEMAC(ch,varname)( \ { \ panel_off_i = 0; \ panel_len_i = bli_abs( diagoffc_i ) + panel_dim_i; \ - panel_len_max_i = bli_abs( diagoffc_i ) + panel_dim_max; \ + panel_len_max_i = bli_min( bli_abs( diagoffc_i ) + panel_dim_max, \ + panel_len_max ); \ diagoffp_i = diagoffc_i; \ } \ else /* if ( ( row_stored && bli_is_lower( uploc ) ) || \ diff --git a/frame/1m/packm/bli_packm_struc_cxk.c b/frame/1m/packm/bli_packm_struc_cxk.c index e779f925e..fcdfd943f 100644 --- a/frame/1m/packm/bli_packm_struc_cxk.c +++ b/frame/1m/packm/bli_packm_struc_cxk.c @@ -210,6 +210,16 @@ void PASTEMAC(ch,varname)( \ p_br, rs_p, cs_p ); \ } \ } \ +\ +\ +/* + if ( bli_is_col_packed( schema ) ) \ + PASTEMAC(ch,fprintm)( stdout, "packm_struc_cxk: bp copied", m_panel_max, n_panel_max, \ + p, rs_p, cs_p, "%4.1f", "" ); \ + else if ( bli_is_row_packed( schema ) ) \ + PASTEMAC(ch,fprintm)( stdout, "packm_struc_cxk: ap copied", m_panel_max, n_panel_max, \ + p, rs_p, cs_p, "%4.1f", "" ); \ +*/ \ } INSERT_GENTFUNC_BASIC( packm_struc_cxk, packm_cxk ) @@ -501,7 +511,6 @@ void PASTEMAC(ch,varname)( \ p, rs_p, cs_p ); \ } \ \ -\ } INSERT_GENTFUNC_BASIC( packm_tri_cxk, packm_cxk ) diff --git a/frame/1m/packm/bli_packm_struc_cxk_3m.c b/frame/1m/packm/bli_packm_struc_cxk_3m.c index 9dbfbac7d..a3c32116b 100644 --- a/frame/1m/packm/bli_packm_struc_cxk_3m.c +++ b/frame/1m/packm/bli_packm_struc_cxk_3m.c @@ -557,18 +557,6 @@ void PASTEMAC(ch,varname)( \ inc_t is_p, inc_t ldp \ ) \ { \ - bool_t row_stored; \ - bool_t col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ /* Pack the panel. */ \ PASTEMAC(ch,kername)( conjc, \ panel_dim, \ @@ -580,10 +568,24 @@ void PASTEMAC(ch,varname)( \ \ /* Tweak the panel according to its triangular structure */ \ { \ + ctype_r* p_r = ( ctype_r* )p + 0; \ + ctype_r* p_i = ( ctype_r* )p + is_p; \ + ctype_r* p_rpi = ( ctype_r* )p + 2*is_p; \ +\ dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p11_i = ( ctype_r* )p + is_p + (j )*ldp; \ - ctype_r* p11_rpi = ( ctype_r* )p + 2*is_p + (j )*ldp; \ + ctype_r* p11_r = p_r + (j )*ldp; \ + ctype_r* p11_i = p_i + (j )*ldp; \ + ctype_r* p11_rpi = p_rpi + (j )*ldp; \ +\ + dim_t p11_m = m_panel; \ + dim_t p11_n = n_panel; \ +\ + dim_t min_p11_m_n; \ +\ + if ( diagoffp < 0 ) p11_m -= j; \ + else if ( diagoffp > 0 ) p11_n -= j; \ +\ + min_p11_m_n = bli_min( p11_m, p11_n ); \ \ \ /* If the diagonal of c is implicitly unit, explicitly set the @@ -594,21 +596,21 @@ void PASTEMAC(ch,varname)( \ ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ dim_t i; \ \ - PASTEMAC(chr,setd)( 0, \ + PASTEMAC(chr,setd)( diagoffp, \ m_panel, \ n_panel, \ &kappa_r, \ - p11_r, rs_p, cs_p ); \ - PASTEMAC(chr,setd)( 0, \ + p_r, rs_p, cs_p ); \ + PASTEMAC(chr,setd)( diagoffp, \ m_panel, \ n_panel, \ &kappa_i, \ - p11_i, rs_p, cs_p ); \ + p_i, rs_p, cs_p ); \ \ /* Update the diagonal of the p11 section of the rpi panel. It simply needs to contain the sum of diagonals of p11_r and p11_i. */ \ - for ( i = 0; i < panel_dim; ++i ) \ + for ( i = 0; i < min_p11_m_n; ++i ) \ { \ ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ @@ -626,7 +628,7 @@ void PASTEMAC(ch,varname)( \ { \ dim_t i; \ \ - for ( i = 0; i < panel_dim; ++i ) \ + for ( i = 0; i < min_p11_m_n; ++i ) \ { \ ctype_r* pi11_r = p11_r + (i )*rs_p + (i )*cs_p; \ ctype_r* pi11_i = p11_i + (i )*rs_p + (i )*cs_p; \ @@ -644,34 +646,33 @@ void PASTEMAC(ch,varname)( \ micro-kernel; however, zero-filling is needed for trmm, which uses the gemm micro-kernel.*/ \ { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop11 = uploc; \ - doff_t diagoffp11 = 0; \ + ctype_r* restrict zero_r = PASTEMAC(chr,0); \ + uplo_t uplop = uploc; \ \ - bli_toggle_uplo( uplop11 ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop11, diagoffp11 ); \ + bli_toggle_uplo( uplop ); \ + bli_shift_diag_offset_to_shrink_uplo( uplop, diagoffp ); \ \ - PASTEMAC(chr,setm)( diagoffp11, \ + PASTEMAC(chr,setm)( diagoffp, \ BLIS_NONUNIT_DIAG, \ - uplop11, \ - panel_dim, \ - panel_dim, \ + uplop, \ + m_panel, \ + n_panel, \ zero_r, \ - p11_r, rs_p, cs_p ); \ - PASTEMAC(chr,setm)( diagoffp11, \ + p_r, rs_p, cs_p ); \ + PASTEMAC(chr,setm)( diagoffp, \ BLIS_NONUNIT_DIAG, \ - uplop11, \ - panel_dim, \ - panel_dim, \ + uplop, \ + m_panel, \ + n_panel, \ zero_r, \ - p11_i, rs_p, cs_p ); \ - PASTEMAC(chr,setm)( diagoffp11, \ + p_i, rs_p, cs_p ); \ + PASTEMAC(chr,setm)( diagoffp, \ BLIS_NONUNIT_DIAG, \ - uplop11, \ - panel_dim, \ - panel_dim, \ + uplop, \ + m_panel, \ + n_panel, \ zero_r, \ - p11_rpi, rs_p, cs_p ); \ + p_rpi, rs_p, cs_p ); \ } \ } \ } diff --git a/frame/1m/packm/bli_packm_struc_cxk_4m.c b/frame/1m/packm/bli_packm_struc_cxk_4m.c index a5dcc0de2..c3c3a811c 100644 --- a/frame/1m/packm/bli_packm_struc_cxk_4m.c +++ b/frame/1m/packm/bli_packm_struc_cxk_4m.c @@ -529,18 +529,6 @@ void PASTEMAC(ch,varname)( \ inc_t is_p, inc_t ldp \ ) \ { \ - bool_t row_stored; \ - bool_t col_stored; \ -\ -\ - /* Create flags to incidate row or column storage. Note that the - schema bit that encodes row or column is describing the form of - micro-panel, not the storage in the micro-panel. Hence the - mismatch in "row" and "column" semantics. */ \ - row_stored = bli_is_col_packed( schema ); \ - col_stored = bli_is_row_packed( schema ); \ -\ -\ /* Pack the panel. */ \ PASTEMAC(ch,kername)( conjc, \ panel_dim, \ @@ -552,9 +540,12 @@ void PASTEMAC(ch,varname)( \ \ /* Tweak the panel according to its triangular structure */ \ { \ + ctype_r* p_r = ( ctype_r* )p; \ + ctype_r* p_i = ( ctype_r* )p + is_p; \ +\ dim_t j = bli_abs( diagoffp ); \ - ctype_r* p11_r = ( ctype_r* )p + (j )*ldp; \ - ctype_r* p11_i = ( ctype_r* )p + is_p + (j )*ldp; \ + ctype_r* p11_r = p_r + (j )*ldp; \ + ctype_r* p11_i = p_i + (j )*ldp; \ \ /* If the diagonal of c is implicitly unit, explicitly set the the diagonal of the packed panel to kappa. */ \ @@ -563,16 +554,16 @@ void PASTEMAC(ch,varname)( \ ctype_r kappa_r = PASTEMAC(ch,real)( *kappa ); \ ctype_r kappa_i = PASTEMAC(ch,imag)( *kappa ); \ \ - PASTEMAC(chr,setd)( 0, \ + PASTEMAC(chr,setd)( diagoffp, \ m_panel, \ n_panel, \ &kappa_r, \ - p11_r, rs_p, cs_p ); \ - PASTEMAC(chr,setd)( 0, \ + p_r, rs_p, cs_p ); \ + PASTEMAC(chr,setd)( diagoffp, \ m_panel, \ n_panel, \ &kappa_i, \ - p11_i, rs_p, cs_p ); \ + p_i, rs_p, cs_p ); \ } \ \ \ @@ -600,27 +591,26 @@ void PASTEMAC(ch,varname)( \ micro-kernel; however, zero-filling is needed for trmm, which uses the gemm micro-kernel.*/ \ { \ - ctype_r* restrict zero_r = PASTEMAC(chr,0); \ - uplo_t uplop11 = uploc; \ - doff_t diagoffp11 = 0; \ + ctype_r* restrict zero_r = PASTEMAC(chr,0); \ + uplo_t uplop = uploc; \ \ - bli_toggle_uplo( uplop11 ); \ - bli_shift_diag_offset_to_shrink_uplo( uplop11, diagoffp11 ); \ + bli_toggle_uplo( uplop ); \ + bli_shift_diag_offset_to_shrink_uplo( uplop, diagoffp ); \ \ - PASTEMAC(chr,setm)( diagoffp11, \ + PASTEMAC(chr,setm)( diagoffp, \ BLIS_NONUNIT_DIAG, \ - uplop11, \ - panel_dim, \ - panel_dim, \ + uplop, \ + m_panel, \ + n_panel, \ zero_r, \ - p11_r, rs_p, cs_p ); \ - PASTEMAC(chr,setm)( diagoffp11, \ + p_r, rs_p, cs_p ); \ + PASTEMAC(chr,setm)( diagoffp, \ BLIS_NONUNIT_DIAG, \ - uplop11, \ - panel_dim, \ - panel_dim, \ + uplop, \ + m_panel, \ + n_panel, \ zero_r, \ - p11_i, rs_p, cs_p ); \ + p_i, rs_p, cs_p ); \ } \ } \ } diff --git a/frame/3/trmm/3m/bli_trmm3m.h b/frame/3/trmm/3m/bli_trmm3m.h index 197b5320d..a56b19170 100644 --- a/frame/3/trmm/3m/bli_trmm3m.h +++ b/frame/3/trmm/3m/bli_trmm3m.h @@ -32,7 +32,6 @@ */ -#include "bli_trmm3m_cntl.h" #include "bli_trmm3m_entry.h" diff --git a/frame/3/trmm/3m/bli_trmm3m_cntl.c b/frame/3/trmm/3m/bli_trmm3m_cntl.c index b24460e8d..6bbbf331d 100644 --- a/frame/3/trmm/3m/bli_trmm3m_cntl.c +++ b/frame/3/trmm/3m/bli_trmm3m_cntl.c @@ -53,18 +53,18 @@ packm_t* trmm3m_l_packb_cntl; packm_t* trmm3m_r_packa_cntl; packm_t* trmm3m_r_packb_cntl; -trmm_t* trmm3m_cntl_bp_ke; +gemm_t* trmm3m_cntl_bp_ke; -trmm_t* trmm3m_l_cntl_op_bp; -trmm_t* trmm3m_l_cntl_mm_op; -trmm_t* trmm3m_l_cntl_vl_mm; +gemm_t* trmm3m_l_cntl_op_bp; +gemm_t* trmm3m_l_cntl_mm_op; +gemm_t* trmm3m_l_cntl_vl_mm; -trmm_t* trmm3m_r_cntl_op_bp; -trmm_t* trmm3m_r_cntl_mm_op; -trmm_t* trmm3m_r_cntl_vl_mm; +gemm_t* trmm3m_r_cntl_op_bp; +gemm_t* trmm3m_r_cntl_mm_op; +gemm_t* trmm3m_r_cntl_vl_mm; -trmm_t* trmm3m_l_cntl; -trmm_t* trmm3m_r_cntl; +gemm_t* trmm3m_l_cntl; +gemm_t* trmm3m_r_cntl; void bli_trmm3m_cntl_init() @@ -77,7 +77,7 @@ void bli_trmm3m_cntl_init() // IMPORTANT: for consistency with trsm, "k" dim // multiple is set to mr. gemm3m_mr, - gemm3m_mr, + gemm3m_kr, TRUE, // densify FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? @@ -91,9 +91,9 @@ void bli_trmm3m_cntl_init() BLIS_VARIANT2, // IMPORTANT: m dim multiple here must be mr // since "k" dim multiple is set to mr above. - gemm3m_mr, + gemm3m_kr, gemm3m_nr, - FALSE, // already dense + TRUE, // already dense FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? FALSE, // reverse iteration if lower? diff --git a/frame/3/trmm/3m/bli_trmm3m_entry.c b/frame/3/trmm/3m/bli_trmm3m_entry.c index acb8ec4b9..0b4b7f012 100644 --- a/frame/3/trmm/3m/bli_trmm3m_entry.c +++ b/frame/3/trmm/3m/bli_trmm3m_entry.c @@ -34,8 +34,7 @@ #include "blis.h" -extern trmm_t* trmm3m_l_cntl; -extern trmm_t* trmm3m_r_cntl; +extern gemm_t* gemm3m_cntl; void bli_trmm3m_entry( side_t side, obj_t* alpha, @@ -43,7 +42,6 @@ void bli_trmm3m_entry( side_t side, obj_t* b ) { bli_trmm_front( side, alpha, a, b, - trmm3m_l_cntl, - trmm3m_r_cntl ); + gemm3m_cntl ); } diff --git a/frame/3/trmm/4m/bli_trmm4m.h b/frame/3/trmm/4m/bli_trmm4m.h index 7af55fca4..eeb6d33b3 100644 --- a/frame/3/trmm/4m/bli_trmm4m.h +++ b/frame/3/trmm/4m/bli_trmm4m.h @@ -32,7 +32,6 @@ */ -#include "bli_trmm4m_cntl.h" #include "bli_trmm4m_entry.h" diff --git a/frame/3/trmm/4m/bli_trmm4m_cntl.c b/frame/3/trmm/4m/bli_trmm4m_cntl.c index 0876f2f62..5a979134a 100644 --- a/frame/3/trmm/4m/bli_trmm4m_cntl.c +++ b/frame/3/trmm/4m/bli_trmm4m_cntl.c @@ -53,18 +53,18 @@ packm_t* trmm4m_l_packb_cntl; packm_t* trmm4m_r_packa_cntl; packm_t* trmm4m_r_packb_cntl; -trmm_t* trmm4m_cntl_bp_ke; +gemm_t* trmm4m_cntl_bp_ke; -trmm_t* trmm4m_l_cntl_op_bp; -trmm_t* trmm4m_l_cntl_mm_op; -trmm_t* trmm4m_l_cntl_vl_mm; +gemm_t* trmm4m_l_cntl_op_bp; +gemm_t* trmm4m_l_cntl_mm_op; +gemm_t* trmm4m_l_cntl_vl_mm; -trmm_t* trmm4m_r_cntl_op_bp; -trmm_t* trmm4m_r_cntl_mm_op; -trmm_t* trmm4m_r_cntl_vl_mm; +gemm_t* trmm4m_r_cntl_op_bp; +gemm_t* trmm4m_r_cntl_mm_op; +gemm_t* trmm4m_r_cntl_vl_mm; -trmm_t* trmm4m_l_cntl; -trmm_t* trmm4m_r_cntl; +gemm_t* trmm4m_l_cntl; +gemm_t* trmm4m_r_cntl; void bli_trmm4m_cntl_init() @@ -77,7 +77,7 @@ void bli_trmm4m_cntl_init() // IMPORTANT: for consistency with trsm, "k" dim // multiple is set to mr. gemm4m_mr, - gemm4m_mr, + gemm4m_kr, TRUE, // densify FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? @@ -91,9 +91,9 @@ void bli_trmm4m_cntl_init() BLIS_VARIANT2, // IMPORTANT: m dim multiple here must be mr // since "k" dim multiple is set to mr above. - gemm4m_mr, + gemm4m_kr, gemm4m_nr, - FALSE, // already dense + TRUE, // already dense FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? FALSE, // reverse iteration if lower? diff --git a/frame/3/trmm/4m/bli_trmm4m_entry.c b/frame/3/trmm/4m/bli_trmm4m_entry.c index be9794c37..edb612493 100644 --- a/frame/3/trmm/4m/bli_trmm4m_entry.c +++ b/frame/3/trmm/4m/bli_trmm4m_entry.c @@ -34,8 +34,7 @@ #include "blis.h" -extern trmm_t* trmm4m_l_cntl; -extern trmm_t* trmm4m_r_cntl; +extern gemm_t* gemm4m_cntl; void bli_trmm4m_entry( side_t side, obj_t* alpha, @@ -43,7 +42,6 @@ void bli_trmm4m_entry( side_t side, obj_t* b ) { bli_trmm_front( side, alpha, a, b, - trmm4m_l_cntl, - trmm4m_r_cntl ); + gemm4m_cntl ); } diff --git a/frame/3/trmm/bli_trmm.h b/frame/3/trmm/bli_trmm.h index d58387c13..886824aa1 100644 --- a/frame/3/trmm/bli_trmm.h +++ b/frame/3/trmm/bli_trmm.h @@ -32,7 +32,6 @@ */ -#include "bli_trmm_cntl.h" #include "bli_trmm_check.h" #include "bli_trmm_entry.h" #include "bli_trmm_front.h" diff --git a/frame/3/trmm/bli_trmm_blk_var1f.c b/frame/3/trmm/bli_trmm_blk_var1f.c index e9bf126dd..2f82ddcb2 100644 --- a/frame/3/trmm/bli_trmm_blk_var1f.c +++ b/frame/3/trmm/bli_trmm_blk_var1f.c @@ -37,7 +37,7 @@ void bli_trmm_blk_var1f( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { obj_t b_pack_s; @@ -136,7 +136,7 @@ void bli_trmm_blk_var1f( obj_t* a, b_pack, &BLIS_ONE, c1_pack, - cntl_sub_trmm( cntl ), + cntl_sub_gemm( cntl ), trmm_thread_sub_trmm( thread ) ); // Unpack C1 (if C1 was packed). diff --git a/frame/3/trmm/bli_trmm_blk_var1f.h b/frame/3/trmm/bli_trmm_blk_var1f.h index 82a924d1b..ccf3118a8 100644 --- a/frame/3/trmm/bli_trmm_blk_var1f.h +++ b/frame/3/trmm/bli_trmm_blk_var1f.h @@ -35,6 +35,6 @@ void bli_trmm_blk_var1f( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_blk_var2b.c b/frame/3/trmm/bli_trmm_blk_var2b.c index eb305fad5..5a368ca21 100644 --- a/frame/3/trmm/bli_trmm_blk_var2b.c +++ b/frame/3/trmm/bli_trmm_blk_var2b.c @@ -37,8 +37,8 @@ void bli_trmm_blk_var2b( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, - trmm_thrinfo_t* thread) + gemm_t* cntl, + trmm_thrinfo_t* thread ) { obj_t a_pack_s; obj_t b1_pack_s, c1_pack_s; @@ -124,7 +124,7 @@ void bli_trmm_blk_var2b( obj_t* a, b1_pack, &BLIS_ONE, c1_pack, - cntl_sub_trmm( cntl ), + cntl_sub_gemm( cntl ), trmm_thread_sub_trmm( thread ) ); // Unpack C1 (if C1 was packed). diff --git a/frame/3/trmm/bli_trmm_blk_var2b.h b/frame/3/trmm/bli_trmm_blk_var2b.h index 5c8f41ca5..dda36c6f8 100644 --- a/frame/3/trmm/bli_trmm_blk_var2b.h +++ b/frame/3/trmm/bli_trmm_blk_var2b.h @@ -35,6 +35,6 @@ void bli_trmm_blk_var2b( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_blk_var2f.c b/frame/3/trmm/bli_trmm_blk_var2f.c index aa9aa1bea..e3e32bd02 100644 --- a/frame/3/trmm/bli_trmm_blk_var2f.c +++ b/frame/3/trmm/bli_trmm_blk_var2f.c @@ -37,8 +37,8 @@ void bli_trmm_blk_var2f( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, - trmm_thrinfo_t* thread) + gemm_t* cntl, + trmm_thrinfo_t* thread ) { obj_t a_pack_s; obj_t b1_pack_s, c1_pack_s; @@ -124,7 +124,7 @@ void bli_trmm_blk_var2f( obj_t* a, b1_pack, &BLIS_ONE, c1_pack, - cntl_sub_trmm( cntl ), + cntl_sub_gemm( cntl ), trmm_thread_sub_trmm( thread ) ); // Unpack C1 (if C1 was packed). diff --git a/frame/3/trmm/bli_trmm_blk_var2f.h b/frame/3/trmm/bli_trmm_blk_var2f.h index 75b764e17..4c53ebcac 100644 --- a/frame/3/trmm/bli_trmm_blk_var2f.h +++ b/frame/3/trmm/bli_trmm_blk_var2f.h @@ -35,6 +35,6 @@ void bli_trmm_blk_var2f( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_blk_var3b.c b/frame/3/trmm/bli_trmm_blk_var3b.c index a25356d6b..39772904c 100644 --- a/frame/3/trmm/bli_trmm_blk_var3b.c +++ b/frame/3/trmm/bli_trmm_blk_var3b.c @@ -37,7 +37,7 @@ void bli_trmm_blk_var3b( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { obj_t c_pack_s; @@ -119,7 +119,7 @@ void bli_trmm_blk_var3b( obj_t* a, b1_pack, &BLIS_ONE, c_pack, - cntl_sub_trmm( cntl ), + cntl_sub_gemm( cntl ), trmm_thread_sub_trmm( thread ) ); } diff --git a/frame/3/trmm/bli_trmm_blk_var3b.h b/frame/3/trmm/bli_trmm_blk_var3b.h index a88c89b93..81629c7e3 100644 --- a/frame/3/trmm/bli_trmm_blk_var3b.h +++ b/frame/3/trmm/bli_trmm_blk_var3b.h @@ -35,6 +35,6 @@ void bli_trmm_blk_var3b( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_blk_var3f.c b/frame/3/trmm/bli_trmm_blk_var3f.c index 4b43b6cd9..3d53e3d0f 100644 --- a/frame/3/trmm/bli_trmm_blk_var3f.c +++ b/frame/3/trmm/bli_trmm_blk_var3f.c @@ -37,7 +37,7 @@ void bli_trmm_blk_var3f( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { obj_t c_pack_s; @@ -119,7 +119,7 @@ void bli_trmm_blk_var3f( obj_t* a, b1_pack, &BLIS_ONE, c_pack, - cntl_sub_trmm( cntl ), + cntl_sub_gemm( cntl ), trmm_thread_sub_trmm( thread ) ); } diff --git a/frame/3/trmm/bli_trmm_blk_var3f.h b/frame/3/trmm/bli_trmm_blk_var3f.h index d0596941a..51342567b 100644 --- a/frame/3/trmm/bli_trmm_blk_var3f.h +++ b/frame/3/trmm/bli_trmm_blk_var3f.h @@ -35,6 +35,6 @@ void bli_trmm_blk_var3f( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_check.c b/frame/3/trmm/bli_trmm_check.c index 0f231f2d8..67fcf4ee4 100644 --- a/frame/3/trmm/bli_trmm_check.c +++ b/frame/3/trmm/bli_trmm_check.c @@ -116,7 +116,7 @@ void bli_trmm_int_check( obj_t* alpha, obj_t* b, obj_t* beta, obj_t* c, - trmm_t* cntl ) + gemm_t* cntl ) { err_t e_val; diff --git a/frame/3/trmm/bli_trmm_check.h b/frame/3/trmm/bli_trmm_check.h index 55a0b57a2..426b01003 100644 --- a/frame/3/trmm/bli_trmm_check.h +++ b/frame/3/trmm/bli_trmm_check.h @@ -49,5 +49,5 @@ void bli_trmm_int_check( obj_t* alpha, obj_t* b, obj_t* beta, obj_t* c, - trmm_t* cntl ); + gemm_t* cntl ); diff --git a/frame/3/trmm/bli_trmm_entry.c b/frame/3/trmm/bli_trmm_entry.c index 5b76725b3..0ff119048 100644 --- a/frame/3/trmm/bli_trmm_entry.c +++ b/frame/3/trmm/bli_trmm_entry.c @@ -34,8 +34,7 @@ #include "blis.h" -extern trmm_t* trmm_l_cntl; -extern trmm_t* trmm_r_cntl; +extern gemm_t* gemm_cntl; void bli_trmm_entry( side_t side, obj_t* alpha, @@ -43,7 +42,6 @@ void bli_trmm_entry( side_t side, obj_t* b ) { bli_trmm_front( side, alpha, a, b, - trmm_l_cntl, - trmm_r_cntl ); + gemm_cntl ); } diff --git a/frame/3/trmm/bli_trmm_front.c b/frame/3/trmm/bli_trmm_front.c index 9ff4fbb4c..61bbfa461 100644 --- a/frame/3/trmm/bli_trmm_front.c +++ b/frame/3/trmm/bli_trmm_front.c @@ -38,10 +38,8 @@ void bli_trmm_front( side_t side, obj_t* alpha, obj_t* a, obj_t* b, - trmm_t* l_cntl, - trmm_t* r_cntl ) + gemm_t* cntl ) { - trmm_t* cntl; obj_t a_local; obj_t b_local; obj_t c_local; @@ -101,10 +99,10 @@ void bli_trmm_front( side_t side, if ( ( bli_obj_is_row_stored( c_local ) && bli_func_prefers_contig_cols( bli_obj_datatype( c_local ), - cntl_gemm_ukrs( l_cntl ) ) ) || + cntl_gemm_ukrs( cntl ) ) ) || ( bli_obj_is_col_stored( c_local ) && bli_func_prefers_contig_rows( bli_obj_datatype( c_local ), - cntl_gemm_ukrs( l_cntl ) ) ) + cntl_gemm_ukrs( cntl ) ) ) ) { bli_toggle_side( side ); @@ -129,9 +127,6 @@ void bli_trmm_front( side_t side, bli_obj_set_as_root( b_local ); bli_obj_set_as_root( c_local ); - // Choose the control tree. - if ( bli_is_left( side ) ) cntl = l_cntl; - else cntl = r_cntl; trmm_thrinfo_t** infos = bli_create_trmm_thrinfo_paths( bli_is_right( side ) ); dim_t n_threads = thread_num_threads( infos[0] ); diff --git a/frame/3/trmm/bli_trmm_front.h b/frame/3/trmm/bli_trmm_front.h index bd80ff7cb..e06198c17 100644 --- a/frame/3/trmm/bli_trmm_front.h +++ b/frame/3/trmm/bli_trmm_front.h @@ -36,6 +36,5 @@ void bli_trmm_front( side_t side, obj_t* alpha, obj_t* a, obj_t* b, - trmm_t* l_cntl, - trmm_t* r_cntl ); + gemm_t* cntl ); diff --git a/frame/3/trmm/bli_trmm_int.c b/frame/3/trmm/bli_trmm_int.c index 8ada7ca20..4878aefea 100644 --- a/frame/3/trmm/bli_trmm_int.c +++ b/frame/3/trmm/bli_trmm_int.c @@ -39,7 +39,7 @@ typedef void (*FUNCPTR_T)( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); static FUNCPTR_T vars[2][2][4][3] = @@ -89,7 +89,7 @@ void bli_trmm_int( obj_t* alpha, obj_t* b, obj_t* beta, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { obj_t a_local; diff --git a/frame/3/trmm/bli_trmm_int.h b/frame/3/trmm/bli_trmm_int.h index 36231aacc..b4595fd38 100644 --- a/frame/3/trmm/bli_trmm_int.h +++ b/frame/3/trmm/bli_trmm_int.h @@ -37,5 +37,5 @@ void bli_trmm_int( obj_t* alpha, obj_t* b, obj_t* beta, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index 44a2d2e42..dfe3a36e5 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -56,7 +56,7 @@ static FUNCPTR_T GENARRAY(ftypes,trmm_ll_ker_var2); void bli_trmm_ll_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { num_t dt_exec = bli_obj_execution_datatype( *c ); @@ -217,12 +217,14 @@ void PASTEMAC(ch,varname)( \ if ( bli_is_strictly_above_diag_n( diagoffa, m, k ) ) return; \ \ /* Compute the storage stride for the triangular matrix A, which is - usually PACKMR. However, in the case of 3m, the storage stride - captures the (PACKMR * 3/2) factor embedded in the panel stride. - Notice that we must first inflate k up to a multiple of MR, since - the panel stride was originally computed using this inflated k - dimension. */ \ - k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \ + usually PACKMR. However, in the case of 3m, the storage stride + captures the (PACKMR * 3/2) factor embedded in the panel stride. + Note that trmm does NOT require k to be a multiple of MR or NR + (depending on whether A or B is the triangular matrix), so we can + use k as-is. By contrast, trsm must use an "inflated" version of + k since trsm requires that k be a multiple of MR (when A is + triangular) or NR (when B is triangular). */ \ + k_full = k; \ ss_a = ps_a / k_full; \ \ /* If there is a zero region above where the diagonal of A intersects the @@ -238,13 +240,6 @@ void PASTEMAC(ch,varname)( \ diagoffa = 0; \ c_cast = c_cast + (i )*rs_c; \ } \ -\ - /* For consistency with the trsm macro-kernels, we inflate k to be a - multiple of MR, if necessary. This is needed because we typically - use the same packm variant for trmm as for trsm, and trsm has this - constraint that k must be a multiple of MR so that it can safely - handle bottom-right corner edges of the triangle. */ \ - if ( k % MR != 0 ) k += MR - ( k % MR ); \ \ /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ PASTEMAC(ch,set0s_mxn)( MR, NR, \ @@ -313,7 +308,7 @@ void PASTEMAC(ch,varname)( \ packed so we can index into the corresponding location in b1. */ \ off_a1011 = 0; \ - k_a1011 = diagoffa_i + MR; \ + k_a1011 = bli_min( diagoffa_i + MR, k ); \ \ if( trmm_l_ir_my_iter( i, ir_thread ) ) \ { \ @@ -436,6 +431,8 @@ void PASTEMAC(ch,varname)( \ b1 += cstep_b; \ c1 += cstep_c; \ } \ +/*PASTEMAC(ch,fprintm)( stdout, "trmm_ll_ker_var2: a1", MR, k_a1011, a1, 1, MR, "%4.1f", "" );*/ \ +/*PASTEMAC(ch,fprintm)( stdout, "trmm_ll_ker_var2: b1", k_a1011, NR, b1_i, NR, 1, "%4.1f", "" );*/ \ } INSERT_GENTFUNC_BASIC( trmm_ll_ker_var2, gemm_ukr_t ) diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.h b/frame/3/trmm/bli_trmm_ll_ker_var2.h index fc338374b..8e5b3e066 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.h +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.h @@ -39,7 +39,7 @@ void bli_trmm_ll_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index 20b48f436..1ccb29d6d 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -56,7 +56,7 @@ static FUNCPTR_T GENARRAY(ftypes,trmm_lu_ker_var2); void bli_trmm_lu_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { num_t dt_exec = bli_obj_execution_datatype( *c ); @@ -219,10 +219,12 @@ void PASTEMAC(ch,varname)( \ /* Compute the storage stride for the triangular matrix A, which is usually PACKMR. However, in the case of 3m, the storage stride captures the (PACKMR * 3/2) factor embedded in the panel stride. - Notice that we must first inflate k up to a multiple of MR, since - the panel stride was originally computed using this inflated k - dimension. */ \ - k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \ + Note that trmm does NOT require k to be a multiple of MR or NR + (depending on whether A or B is the triangular matrix), so we can + use k as-is. By contrast, trsm must use an "inflated" version of + k since trsm requires that k be a multiple of MR (when A is + triangular) or NR (when B is triangular). */ \ + k_full = k; \ ss_a = ps_a / k_full; \ \ /* If there is a zero region to the left of where the diagonal of A @@ -245,13 +247,6 @@ void PASTEMAC(ch,varname)( \ { \ m = -diagoffa + k; \ } \ -\ - /* For consistency with the trsm macro-kernels, we inflate k to be a - multiple of MR, if necessary. This is needed because we typically - use the same packm variant for trmm as for trsm, and trsm has this - constraint that k must be a multiple of MR so that it can safely - handle bottom-right corner edges of the triangle. */ \ - if ( k % MR != 0 ) k += MR - ( k % MR ); \ \ /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ PASTEMAC(ch,set0s_mxn)( MR, NR, \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.h b/frame/3/trmm/bli_trmm_lu_ker_var2.h index ff4c49869..1a9ae3352 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.h +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.h @@ -39,7 +39,7 @@ void bli_trmm_lu_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index fcf28cdbd..261b45ba9 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -56,7 +56,7 @@ static FUNCPTR_T GENARRAY(ftypes,trmm_rl_ker_var2); void bli_trmm_rl_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { num_t dt_exec = bli_obj_execution_datatype( *c ); @@ -219,10 +219,12 @@ void PASTEMAC(ch,varname)( \ /* Compute the storage stride for the triangular matrix B, which is usually PACKNR. However, in the case of 3m, the storage stride captures the (PACKNR * 3/2) factor embedded in the panel stride. - Notice that we must first inflate k up to a multiple of NR, since - the panel stride was originally computed using this inflated k - dimension. */ \ - k_full = ( k % NR != 0 ? k + NR - ( k % NR ) : k ); \ + Note that trmm does NOT require k to be a multiple of MR or NR + (depending on whether A or B is the triangular matrix), so we can + use k as-is. By contrast, trsm must use an "inflated" version of + k since trsm requires that k be a multiple of MR (when A is + triangular) or NR (when B is triangular). */ \ + k_full = k; \ ss_b = ps_b / k_full; \ \ /* If there is a zero region above where the diagonal of B intersects @@ -245,13 +247,6 @@ void PASTEMAC(ch,varname)( \ { \ n = diagoffb + k; \ } \ -\ - /* For consistency with the trsm macro-kernels, we inflate k to be a - multiple of NR, if necessary. This is needed because we typically - use the same packm variant for trmm as for trsm, and trsm has this - constraint that k must be a multiple of NR so that it can safely - handle bottom-right corner edges of the triangle. */ \ - if ( k % NR != 0 ) k += NR - ( k % NR ); \ \ /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ PASTEMAC(ch,set0s_mxn)( MR, NR, \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.h b/frame/3/trmm/bli_trmm_rl_ker_var2.h index 3f4132d94..278ae11c0 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.h +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.h @@ -39,7 +39,7 @@ void bli_trmm_rl_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index 7a61c7608..9bbc576b5 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -56,7 +56,7 @@ static FUNCPTR_T GENARRAY(ftypes,trmm_ru_ker_var2); void bli_trmm_ru_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ) { num_t dt_exec = bli_obj_execution_datatype( *c ); @@ -219,10 +219,12 @@ void PASTEMAC(ch,varname)( \ /* Compute the storage stride for the triangular matrix B, which is usually PACKNR. However, in the case of 3m, the storage stride captures the (PACKNR * 3/2) factor embedded in the panel stride. - Notice that we must first inflate k up to a multiple of NR, since - the panel stride was originally computed using this inflated k - dimension. */ \ - k_full = ( k % NR != 0 ? k + NR - ( k % NR ) : k ); \ + Note that trmm does NOT require k to be a multiple of MR or NR + (depending on whether A or B is the triangular matrix), so we can + use k as-is. By contrast, trsm must use an "inflated" version of + k since trsm requires that k be a multiple of MR (when A is + triangular) or NR (when B is triangular). */ \ + k_full = k; \ ss_b = ps_b / k_full; \ \ /* If there is a zero region to the left of where the diagonal of B @@ -246,13 +248,6 @@ void PASTEMAC(ch,varname)( \ { \ k = -diagoffb + n; \ } \ -\ - /* For consistency with the trsm macro-kernels, we inflate k to be a - multiple of NR, if necessary. This is needed because we typically - use the same packm variant for trmm as for trsm, and trsm has this - constraint that k must be a multiple of NR so that it can safely - handle bottom-right corner edges of the triangle. */ \ - if ( k % NR != 0 ) k += NR - ( k % NR ); \ \ /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ PASTEMAC(ch,set0s_mxn)( MR, NR, \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.h b/frame/3/trmm/bli_trmm_ru_ker_var2.h index 49840791c..b7cdc0944 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.h +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.h @@ -39,7 +39,7 @@ void bli_trmm_ru_ker_var2( obj_t* a, obj_t* b, obj_t* c, - trmm_t* cntl, + gemm_t* cntl, trmm_thrinfo_t* thread ); diff --git a/frame/3/trmm/bli_trmm_cntl.c b/frame/3/trmm/old/bli_trmm_cntl.c similarity index 63% rename from frame/3/trmm/bli_trmm_cntl.c rename to frame/3/trmm/old/bli_trmm_cntl.c index 6c46cd40d..d3b2fb8ac 100644 --- a/frame/3/trmm/bli_trmm_cntl.c +++ b/frame/3/trmm/old/bli_trmm_cntl.c @@ -50,21 +50,13 @@ extern gemm_t* gemm_cntl_bp_ke; packm_t* trmm_l_packa_cntl; packm_t* trmm_l_packb_cntl; -packm_t* trmm_r_packa_cntl; -packm_t* trmm_r_packb_cntl; - trmm_t* trmm_cntl_bp_ke; trmm_t* trmm_l_cntl_op_bp; trmm_t* trmm_l_cntl_mm_op; trmm_t* trmm_l_cntl_vl_mm; -trmm_t* trmm_r_cntl_op_bp; -trmm_t* trmm_r_cntl_mm_op; -trmm_t* trmm_r_cntl_vl_mm; - trmm_t* trmm_l_cntl; -trmm_t* trmm_r_cntl; void bli_trmm_cntl_init() @@ -74,10 +66,10 @@ void bli_trmm_cntl_init() = bli_packm_cntl_obj_create( BLIS_BLOCKED, BLIS_VARIANT1, - // IMPORTANT: for consistency with trsm, "k" dim - // multiple is set to mr. - gemm_mr, + // IMPORTANT: Unlike trsm, trmm does not require a + // "k" dim multiple equal to mr. gemm_mr, + gemm_kr, TRUE, // densify FALSE, // do NOT invert diagonal FALSE, // reverse iteration if upper? @@ -89,40 +81,9 @@ void bli_trmm_cntl_init() = bli_packm_cntl_obj_create( BLIS_BLOCKED, BLIS_VARIANT1, - // IMPORTANT: m dim multiple here must be mr - // since "k" dim multiple is set to mr above. - gemm_mr, - gemm_nr, - FALSE, // already dense - FALSE, // do NOT invert diagonal - FALSE, // reverse iteration if upper? - FALSE, // reverse iteration if lower? - BLIS_PACKED_COL_PANELS, - BLIS_BUFFER_FOR_B_PANEL ); - - // Create control tree objects for packm operations (right side). - trmm_r_packa_cntl - = - bli_packm_cntl_obj_create( BLIS_BLOCKED, - BLIS_VARIANT1, - // IMPORTANT: for consistency with trsm, "k" dim - // multiple is set to nr. - gemm_mr, - gemm_nr, - FALSE, // already dense - FALSE, // do NOT invert diagonal - FALSE, // reverse iteration if upper? - FALSE, // reverse iteration if lower? - BLIS_PACKED_ROW_PANELS, - BLIS_BUFFER_FOR_A_BLOCK ); - - trmm_r_packb_cntl - = - bli_packm_cntl_obj_create( BLIS_BLOCKED, - BLIS_VARIANT1, - // IMPORTANT: m dim multiple here must be nr - // since "k" dim multiple is set to nr above. - gemm_nr, + // IMPORTANT: Unlike trsm, trmm does not require a + // "k" dim multiple equal to mr. + gemm_kr, gemm_nr, TRUE, // densify FALSE, // do NOT invert diagonal @@ -131,7 +92,6 @@ void bli_trmm_cntl_init() BLIS_PACKED_COL_PANELS, BLIS_BUFFER_FOR_B_PANEL ); - // Create control tree object for lowest-level block-panel kernel. trmm_cntl_bp_ke = @@ -190,74 +150,20 @@ void bli_trmm_cntl_init() NULL, NULL ); - // Create control tree object for outer panel (to block-panel) - // problem (right side). - trmm_r_cntl_op_bp - = - bli_trmm_cntl_obj_create( BLIS_BLOCKED, - BLIS_VARIANT1, - gemm_mc, - gemm_ukrs, - NULL, - trmm_r_packa_cntl, - trmm_r_packb_cntl, - NULL, - trmm_cntl_bp_ke, - gemm_cntl_bp_ke, - NULL ); - - // Create control tree object for general problem via multiple - // rank-k (outer panel) updates (right side). - trmm_r_cntl_mm_op - = - bli_trmm_cntl_obj_create( BLIS_BLOCKED, - BLIS_VARIANT3, - gemm_kc, - gemm_ukrs, - NULL, - NULL, - NULL, - NULL, - trmm_r_cntl_op_bp, - NULL, - NULL ); - - // Create control tree object for very large problem via multiple - // general problems (right side). - trmm_r_cntl_vl_mm - = - bli_trmm_cntl_obj_create( BLIS_BLOCKED, - BLIS_VARIANT2, - gemm_nc, - gemm_ukrs, - NULL, - NULL, - NULL, - NULL, - trmm_r_cntl_mm_op, - NULL, - NULL ); - // Alias the "master" trmm control trees to shorter names. trmm_l_cntl = trmm_l_cntl_vl_mm; - trmm_r_cntl = trmm_r_cntl_vl_mm; } void bli_trmm_cntl_finalize() { bli_cntl_obj_free( trmm_l_packa_cntl ); bli_cntl_obj_free( trmm_l_packb_cntl ); - bli_cntl_obj_free( trmm_r_packa_cntl ); - bli_cntl_obj_free( trmm_r_packb_cntl ); bli_cntl_obj_free( trmm_cntl_bp_ke ); bli_cntl_obj_free( trmm_l_cntl_op_bp ); bli_cntl_obj_free( trmm_l_cntl_mm_op ); bli_cntl_obj_free( trmm_l_cntl_vl_mm ); - bli_cntl_obj_free( trmm_r_cntl_op_bp ); - bli_cntl_obj_free( trmm_r_cntl_mm_op ); - bli_cntl_obj_free( trmm_r_cntl_vl_mm ); } trmm_t* bli_trmm_cntl_obj_create( impl_t impl_type, diff --git a/frame/3/trmm/bli_trmm_cntl.h b/frame/3/trmm/old/bli_trmm_cntl.h similarity index 98% rename from frame/3/trmm/bli_trmm_cntl.h rename to frame/3/trmm/old/bli_trmm_cntl.h index dfb1bd5d8..7c47014ca 100644 --- a/frame/3/trmm/bli_trmm_cntl.h +++ b/frame/3/trmm/old/bli_trmm_cntl.h @@ -48,7 +48,7 @@ struct trmm_s }; typedef struct trmm_s trmm_t; -#define cntl_sub_trmm( cntl ) cntl->sub_trmm +#define cntl_sub_gemm( cntl ) cntl->sub_trmm void bli_trmm_cntl_init( void ); void bli_trmm_cntl_finalize( void ); diff --git a/frame/3/trmm/other/bli_trmm_ll_blk_var1.c b/frame/3/trmm/other/bli_trmm_ll_blk_var1.c index 34ae5cf84..7c68ca917 100644 --- a/frame/3/trmm/other/bli_trmm_ll_blk_var1.c +++ b/frame/3/trmm/other/bli_trmm_ll_blk_var1.c @@ -115,7 +115,7 @@ void bli_trmm_ll_blk_var1( obj_t* alpha, &b_pack, beta, &c1_pack, - cntl_sub_trmm( cntl ) ); + cntl_sub_gemm( cntl ) ); // Unpack C1 (if C1 was packed). bli_unpackm_int( &c1_pack, &c1, diff --git a/frame/3/trmm/other/bli_trmm_ll_blk_var4.c b/frame/3/trmm/other/bli_trmm_ll_blk_var4.c index cf7765da4..3b69f3824 100644 --- a/frame/3/trmm/other/bli_trmm_ll_blk_var4.c +++ b/frame/3/trmm/other/bli_trmm_ll_blk_var4.c @@ -127,7 +127,7 @@ void bli_trmm_ll_blk_var4( obj_t* alpha, &b_pack_inc, beta, &c1_pack_inc, - cntl_sub_trmm( cntl ) ); + cntl_sub_gemm( cntl ) ); } // Unpack C1 (if C1 was packed). @@ -172,7 +172,7 @@ void bli_trmm_ll_blk_var4( obj_t* alpha, &b_pack, beta, &c1_pack, - cntl_sub_trmm( cntl ) ); + cntl_sub_gemm( cntl ) ); else bli_gemm_int( alpha, &a1_pack, diff --git a/frame/3/trmm/other/bli_trmm_lu_blk_var1.c b/frame/3/trmm/other/bli_trmm_lu_blk_var1.c index 1606790c5..0aab4b414 100644 --- a/frame/3/trmm/other/bli_trmm_lu_blk_var1.c +++ b/frame/3/trmm/other/bli_trmm_lu_blk_var1.c @@ -112,7 +112,7 @@ void bli_trmm_lu_blk_var1( obj_t* alpha, &b_pack, beta, &c1_pack, - cntl_sub_trmm( cntl ) ); + cntl_sub_gemm( cntl ) ); // Unpack C1 (if C1 was packed). bli_unpackm_int( &c1_pack, &c1, diff --git a/frame/3/trmm/other/bli_trmm_lu_blk_var4.c b/frame/3/trmm/other/bli_trmm_lu_blk_var4.c index d256fb89a..2d54f3d44 100644 --- a/frame/3/trmm/other/bli_trmm_lu_blk_var4.c +++ b/frame/3/trmm/other/bli_trmm_lu_blk_var4.c @@ -125,7 +125,7 @@ void bli_trmm_lu_blk_var4( obj_t* alpha, &b_pack_inc, beta, &c1_pack_inc, - cntl_sub_trmm( cntl ) ); + cntl_sub_gemm( cntl ) ); } // Unpack C1 (if C1 was packed). @@ -170,7 +170,7 @@ void bli_trmm_lu_blk_var4( obj_t* alpha, &b_pack, beta, &c1_pack, - cntl_sub_trmm( cntl ) ); + cntl_sub_gemm( cntl ) ); else bli_gemm_int( alpha, &a1_pack, diff --git a/frame/3/trmm3/3m/bli_trmm33m_entry.c b/frame/3/trmm3/3m/bli_trmm33m_entry.c index dd5196b2c..2cf4ed12e 100644 --- a/frame/3/trmm3/3m/bli_trmm33m_entry.c +++ b/frame/3/trmm3/3m/bli_trmm33m_entry.c @@ -34,8 +34,7 @@ #include "blis.h" -extern trmm_t* trmm3m_l_cntl; -extern trmm_t* trmm3m_r_cntl; +extern gemm_t* gemm3m_cntl; void bli_trmm33m_entry( side_t side, obj_t* alpha, @@ -45,7 +44,6 @@ void bli_trmm33m_entry( side_t side, obj_t* c ) { bli_trmm3_front( side, alpha, a, b, beta, c, - trmm3m_l_cntl, - trmm3m_r_cntl ); + gemm3m_cntl ); } diff --git a/frame/3/trmm3/4m/bli_trmm34m_entry.c b/frame/3/trmm3/4m/bli_trmm34m_entry.c index 20ebb1b69..c6a2b8b51 100644 --- a/frame/3/trmm3/4m/bli_trmm34m_entry.c +++ b/frame/3/trmm3/4m/bli_trmm34m_entry.c @@ -34,8 +34,7 @@ #include "blis.h" -extern trmm_t* trmm4m_l_cntl; -extern trmm_t* trmm4m_r_cntl; +extern gemm_t* gemm4m_cntl; void bli_trmm34m_entry( side_t side, obj_t* alpha, @@ -45,7 +44,6 @@ void bli_trmm34m_entry( side_t side, obj_t* c ) { bli_trmm3_front( side, alpha, a, b, beta, c, - trmm4m_l_cntl, - trmm4m_r_cntl ); + gemm4m_cntl ); } diff --git a/frame/3/trmm3/bli_trmm3_entry.c b/frame/3/trmm3/bli_trmm3_entry.c index 3e68e48f0..1e243d609 100644 --- a/frame/3/trmm3/bli_trmm3_entry.c +++ b/frame/3/trmm3/bli_trmm3_entry.c @@ -34,8 +34,7 @@ #include "blis.h" -extern trmm_t* trmm_l_cntl; -extern trmm_t* trmm_r_cntl; +extern gemm_t* gemm_cntl; void bli_trmm3_entry( side_t side, obj_t* alpha, @@ -45,7 +44,6 @@ void bli_trmm3_entry( side_t side, obj_t* c ) { bli_trmm3_front( side, alpha, a, b, beta, c, - trmm_l_cntl, - trmm_r_cntl ); + gemm_cntl ); } diff --git a/frame/3/trmm3/bli_trmm3_front.c b/frame/3/trmm3/bli_trmm3_front.c index e8ab2b942..12b747bc9 100644 --- a/frame/3/trmm3/bli_trmm3_front.c +++ b/frame/3/trmm3/bli_trmm3_front.c @@ -40,10 +40,8 @@ void bli_trmm3_front( side_t side, obj_t* b, obj_t* beta, obj_t* c, - trmm_t* l_cntl, - trmm_t* r_cntl ) + gemm_t* cntl ) { - trmm_t* cntl; obj_t a_local; obj_t b_local; obj_t c_local; @@ -103,10 +101,10 @@ void bli_trmm3_front( side_t side, if ( ( bli_obj_is_row_stored( c_local ) && bli_func_prefers_contig_cols( bli_obj_datatype( c_local ), - cntl_gemm_ukrs( l_cntl ) ) ) || + cntl_gemm_ukrs( cntl ) ) ) || ( bli_obj_is_col_stored( c_local ) && bli_func_prefers_contig_rows( bli_obj_datatype( c_local ), - cntl_gemm_ukrs( l_cntl ) ) ) + cntl_gemm_ukrs( cntl ) ) ) ) { bli_toggle_side( side ); @@ -131,9 +129,6 @@ void bli_trmm3_front( side_t side, bli_obj_set_as_root( b_local ); bli_obj_set_as_root( c_local ); - // Choose the control tree. - if ( bli_is_left( side ) ) cntl = l_cntl; - else cntl = r_cntl; trmm_thrinfo_t** infos = bli_create_trmm_thrinfo_paths( FALSE ); dim_t n_threads = thread_num_threads( infos[0] ); diff --git a/frame/3/trmm3/bli_trmm3_front.h b/frame/3/trmm3/bli_trmm3_front.h index 4a169cc9f..f6ebbf27d 100644 --- a/frame/3/trmm3/bli_trmm3_front.h +++ b/frame/3/trmm3/bli_trmm3_front.h @@ -38,5 +38,4 @@ void bli_trmm3_front( side_t side, obj_t* b, obj_t* beta, obj_t* c, - trmm_t* l_cntl, - trmm_t* r_cntl ); + gemm_t* cntl ); diff --git a/frame/cntl/bli_cntl_init.c b/frame/cntl/bli_cntl_init.c index 71eb951c3..3f885cd7e 100644 --- a/frame/cntl/bli_cntl_init.c +++ b/frame/cntl/bli_cntl_init.c @@ -57,17 +57,14 @@ void bli_cntl_init( void ) // Level-3 bli_gemm_cntl_init(); - bli_trmm_cntl_init(); bli_trsm_cntl_init(); // Level-3 via 4m bli_gemm4m_cntl_init(); - bli_trmm4m_cntl_init(); bli_trsm4m_cntl_init(); // Level-3 via 3m bli_gemm3m_cntl_init(); - bli_trmm3m_cntl_init(); bli_trsm3m_cntl_init(); } @@ -94,17 +91,14 @@ void bli_cntl_finalize( void ) // Level-3 bli_gemm_cntl_finalize(); - bli_trmm_cntl_finalize(); bli_trsm_cntl_finalize(); // Level-3 via 4m bli_gemm4m_cntl_finalize(); - bli_trmm4m_cntl_finalize(); bli_trsm4m_cntl_finalize(); // Level-3 via 3m bli_gemm3m_cntl_finalize(); - bli_trmm3m_cntl_finalize(); bli_trsm3m_cntl_finalize(); }