From 963d0393b023f4134bb0c682923faf9964c0e645 Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Mon, 25 Jul 2016 14:40:53 -0500 Subject: [PATCH] Add 24xk pack kernel. --- frame/1m/bli_l1m_ker.h | 1 + frame/1m/packm/bli_packm_cxk.c | 5 +- frame/1m/packm/ukernels/bli_packm_cxk_ref.c | 482 +++++++++++++------- frame/1m/packm/ukernels/bli_packm_cxk_ref.h | 2 + frame/include/bli_kernel_macro_defs.h | 18 + frame/include/bli_kernel_pre_macro_defs.h | 7 + kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c | 138 ++++-- 7 files changed, 448 insertions(+), 205 deletions(-) diff --git a/frame/1m/bli_l1m_ker.h b/frame/1m/bli_l1m_ker.h index 794609f44..9e9dbbdb7 100644 --- a/frame/1m/bli_l1m_ker.h +++ b/frame/1m/bli_l1m_ker.h @@ -60,6 +60,7 @@ INSERT_GENTPROT_BASIC( packm_10xk_ker_name ) INSERT_GENTPROT_BASIC( packm_12xk_ker_name ) INSERT_GENTPROT_BASIC( packm_14xk_ker_name ) INSERT_GENTPROT_BASIC( packm_16xk_ker_name ) +INSERT_GENTPROT_BASIC( packm_24xk_ker_name ) INSERT_GENTPROT_BASIC( packm_30xk_ker_name ) diff --git a/frame/1m/packm/bli_packm_cxk.c b/frame/1m/packm/bli_packm_cxk.c index 3c2ab6fd0..4f167c355 100644 --- a/frame/1m/packm/bli_packm_cxk.c +++ b/frame/1m/packm/bli_packm_cxk.c @@ -166,7 +166,10 @@ static FUNCPTR_T ftypes[FUNCPTR_ARRAY_LENGTH][BLIS_NUM_FP_TYPES] = }, /* micro-panel width = 24 */ { - NULL, NULL, NULL, NULL, + BLIS_SPACKM_24XK_KERNEL, + BLIS_CPACKM_24XK_KERNEL, + BLIS_DPACKM_24XK_KERNEL, + BLIS_ZPACKM_24XK_KERNEL, }, /* micro-panel width = 25 */ { diff --git a/frame/1m/packm/ukernels/bli_packm_cxk_ref.c b/frame/1m/packm/ukernels/bli_packm_cxk_ref.c index b33df08cf..c3dbb8d5c 100644 --- a/frame/1m/packm/ukernels/bli_packm_cxk_ref.c +++ b/frame/1m/packm/ukernels/bli_packm_cxk_ref.c @@ -1014,172 +1014,332 @@ void PASTEMAC(ch,varname) \ void* restrict p, inc_t ldp \ ) \ { \ - ctype* restrict kappa_cast = kappa; \ - ctype* restrict alpha1 = a; \ - ctype* restrict pi1 = p; \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict alpha1 = a; \ + ctype* restrict pi1 = p; \ \ - if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( ; n != 0; --n ) \ - { \ - PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 6*inca), *(pi1 + 6) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 7*inca), *(pi1 + 7) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 8*inca), *(pi1 + 8) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 + 9*inca), *(pi1 + 9) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +10*inca), *(pi1 +10) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +11*inca), *(pi1 +11) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +12*inca), *(pi1 +12) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +13*inca), *(pi1 +13) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +14*inca), *(pi1 +14) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +15*inca), *(pi1 +15) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +16*inca), *(pi1 +16) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +17*inca), *(pi1 +17) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +18*inca), *(pi1 +18) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +19*inca), *(pi1 +19) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +20*inca), *(pi1 +20) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +21*inca), *(pi1 +21) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +22*inca), *(pi1 +22) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +23*inca), *(pi1 +23) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +24*inca), *(pi1 +24) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +25*inca), *(pi1 +25) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +26*inca), *(pi1 +26) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +27*inca), *(pi1 +27) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +28*inca), *(pi1 +28) ); \ - PASTEMAC(ch,copyjs)( *(alpha1 +29*inca), *(pi1 +29) ); \ + if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +23*inca), *(pi1 +23) ); \ \ - alpha1 += lda; \ - pi1 += ldp; \ - } \ - } \ - else \ - { \ - for ( ; n != 0; --n ) \ - { \ - PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 6*inca), *(pi1 + 6) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 7*inca), *(pi1 + 7) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 8*inca), *(pi1 + 8) ); \ - PASTEMAC(ch,copys)( *(alpha1 + 9*inca), *(pi1 + 9) ); \ - PASTEMAC(ch,copys)( *(alpha1 +10*inca), *(pi1 +10) ); \ - PASTEMAC(ch,copys)( *(alpha1 +11*inca), *(pi1 +11) ); \ - PASTEMAC(ch,copys)( *(alpha1 +12*inca), *(pi1 +12) ); \ - PASTEMAC(ch,copys)( *(alpha1 +13*inca), *(pi1 +13) ); \ - PASTEMAC(ch,copys)( *(alpha1 +14*inca), *(pi1 +14) ); \ - PASTEMAC(ch,copys)( *(alpha1 +15*inca), *(pi1 +15) ); \ - PASTEMAC(ch,copys)( *(alpha1 +16*inca), *(pi1 +16) ); \ - PASTEMAC(ch,copys)( *(alpha1 +17*inca), *(pi1 +17) ); \ - PASTEMAC(ch,copys)( *(alpha1 +18*inca), *(pi1 +18) ); \ - PASTEMAC(ch,copys)( *(alpha1 +19*inca), *(pi1 +19) ); \ - PASTEMAC(ch,copys)( *(alpha1 +20*inca), *(pi1 +20) ); \ - PASTEMAC(ch,copys)( *(alpha1 +21*inca), *(pi1 +21) ); \ - PASTEMAC(ch,copys)( *(alpha1 +22*inca), *(pi1 +22) ); \ - PASTEMAC(ch,copys)( *(alpha1 +23*inca), *(pi1 +23) ); \ - PASTEMAC(ch,copys)( *(alpha1 +24*inca), *(pi1 +24) ); \ - PASTEMAC(ch,copys)( *(alpha1 +25*inca), *(pi1 +25) ); \ - PASTEMAC(ch,copys)( *(alpha1 +26*inca), *(pi1 +26) ); \ - PASTEMAC(ch,copys)( *(alpha1 +27*inca), *(pi1 +27) ); \ - PASTEMAC(ch,copys)( *(alpha1 +28*inca), *(pi1 +28) ); \ - PASTEMAC(ch,copys)( *(alpha1 +29*inca), *(pi1 +29) ); \ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copys)( *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,copys)( *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,copys)( *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,copys)( *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,copys)( *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,copys)( *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,copys)( *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,copys)( *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,copys)( *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,copys)( *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,copys)( *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,copys)( *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,copys)( *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,copys)( *(alpha1 +23*inca), *(pi1 +23) ); \ \ - alpha1 += lda; \ - pi1 += ldp; \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_conj( conja ) ) \ - { \ - for ( ; n != 0; --n ) \ - { \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \ - PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + else \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \ \ - alpha1 += lda; \ - pi1 += ldp; \ - } \ - } \ - else \ - { \ - for ( ; n != 0; --n ) \ - { \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \ - PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \ \ - alpha1 += lda; \ - pi1 += ldp; \ - } \ - } \ - } \ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( packm_24xk_ref ) + + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + dim_t n, \ + void* restrict kappa, \ + void* restrict a, inc_t inca, inc_t lda, \ + void* restrict p, inc_t ldp \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict alpha1 = a; \ + ctype* restrict pi1 = p; \ +\ + if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +23*inca), *(pi1 +23) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +24*inca), *(pi1 +24) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +25*inca), *(pi1 +25) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +26*inca), *(pi1 +26) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +27*inca), *(pi1 +27) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +28*inca), *(pi1 +28) ); \ + PASTEMAC(ch,copyjs)( *(alpha1 +29*inca), *(pi1 +29) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,copys)( *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,copys)( *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,copys)( *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,copys)( *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,copys)( *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,copys)( *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,copys)( *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,copys)( *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,copys)( *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,copys)( *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,copys)( *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,copys)( *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,copys)( *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,copys)( *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,copys)( *(alpha1 +23*inca), *(pi1 +23) ); \ + PASTEMAC(ch,copys)( *(alpha1 +24*inca), *(pi1 +24) ); \ + PASTEMAC(ch,copys)( *(alpha1 +25*inca), *(pi1 +25) ); \ + PASTEMAC(ch,copys)( *(alpha1 +26*inca), *(pi1 +26) ); \ + PASTEMAC(ch,copys)( *(alpha1 +27*inca), *(pi1 +27) ); \ + PASTEMAC(ch,copys)( *(alpha1 +28*inca), *(pi1 +28) ); \ + PASTEMAC(ch,copys)( *(alpha1 +29*inca), *(pi1 +29) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ + else \ + { \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \ + PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + else \ + { \ + for ( ; n != 0; --n ) \ + { \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \ + PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \ +\ + alpha1 += lda; \ + pi1 += ldp; \ + } \ + } \ + } \ } INSERT_GENTFUNC_BASIC0( packm_30xk_ref ) diff --git a/frame/1m/packm/ukernels/bli_packm_cxk_ref.h b/frame/1m/packm/ukernels/bli_packm_cxk_ref.h index 9a55e20ea..a20573cdb 100644 --- a/frame/1m/packm/ukernels/bli_packm_cxk_ref.h +++ b/frame/1m/packm/ukernels/bli_packm_cxk_ref.h @@ -52,6 +52,8 @@ #define packm_14xk_ker_name packm_14xk_ref #undef packm_16xk_ker_name #define packm_16xk_ker_name packm_16xk_ref +#undef packm_24xk_ker_name +#define packm_24xk_ker_name packm_24xk_ref #undef packm_30xk_ker_name #define packm_30xk_ker_name packm_30xk_ref diff --git a/frame/include/bli_kernel_macro_defs.h b/frame/include/bli_kernel_macro_defs.h index 00a2aa4b9..305253241 100644 --- a/frame/include/bli_kernel_macro_defs.h +++ b/frame/include/bli_kernel_macro_defs.h @@ -445,6 +445,24 @@ #define BLIS_ZPACKM_16XK_KERNEL BLIS_ZPACKM_16XK_KERNEL_REF #endif +// packm_24xk kernels + +#ifndef BLIS_SPACKM_24XK_KERNEL +#define BLIS_SPACKM_24XK_KERNEL BLIS_SPACKM_24XK_KERNEL_REF +#endif + +#ifndef BLIS_DPACKM_24XK_KERNEL +#define BLIS_DPACKM_24XK_KERNEL BLIS_DPACKM_24XK_KERNEL_REF +#endif + +#ifndef BLIS_CPACKM_24XK_KERNEL +#define BLIS_CPACKM_24XK_KERNEL BLIS_CPACKM_24XK_KERNEL_REF +#endif + +#ifndef BLIS_ZPACKM_24XK_KERNEL +#define BLIS_ZPACKM_24XK_KERNEL BLIS_ZPACKM_24XK_KERNEL_REF +#endif + // packm_30xk kernels #ifndef BLIS_SPACKM_30XK_KERNEL diff --git a/frame/include/bli_kernel_pre_macro_defs.h b/frame/include/bli_kernel_pre_macro_defs.h index 98e4c3928..81f1deb98 100644 --- a/frame/include/bli_kernel_pre_macro_defs.h +++ b/frame/include/bli_kernel_pre_macro_defs.h @@ -143,6 +143,13 @@ #define BLIS_CPACKM_16XK_KERNEL_REF bli_cpackm_16xk_ref #define BLIS_ZPACKM_16XK_KERNEL_REF bli_zpackm_16xk_ref +// packm_24xk kernels + +#define BLIS_SPACKM_24XK_KERNEL_REF bli_spackm_24xk_ref +#define BLIS_DPACKM_24XK_KERNEL_REF bli_dpackm_24xk_ref +#define BLIS_CPACKM_24XK_KERNEL_REF bli_cpackm_24xk_ref +#define BLIS_ZPACKM_24XK_KERNEL_REF bli_zpackm_24xk_ref + // packm_30xk kernels #define BLIS_SPACKM_30XK_KERNEL_REF bli_spackm_30xk_ref diff --git a/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c b/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c index 92c86a02e..5846f4a82 100644 --- a/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c +++ b/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c @@ -139,37 +139,37 @@ // #define SUBITER(n,a,b,...) \ \ - VMOVAPD(ZMM(a), MEM(RBX,(n+1)*64)) \ + VMOVAPD(ZMM(0), MEM(RBX,(n+1-1)*64)) \ \ PREFETCH_A_L1(n) \ PREFETCH_B_L1(n) \ PREFETCH_A_L2(n) \ PREFETCH_B_L2(n) \ \ - VFMADD231PD(ZMM( 8), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 0)*8)) \ - VFMADD231PD(ZMM( 9), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 1)*8)) \ - VFMADD231PD(ZMM(10), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 2)*8)) \ - VFMADD231PD(ZMM(11), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 3)*8)) \ - VFMADD231PD(ZMM(12), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 4)*8)) \ - VFMADD231PD(ZMM(13), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 5)*8)) \ - VFMADD231PD(ZMM(14), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 6)*8)) \ - VFMADD231PD(ZMM(15), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 7)*8)) \ - VFMADD231PD(ZMM(16), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 8)*8)) \ - VFMADD231PD(ZMM(17), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 9)*8)) \ - VFMADD231PD(ZMM(18), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+10)*8)) \ - VFMADD231PD(ZMM(19), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+11)*8)) \ - VFMADD231PD(ZMM(20), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+12)*8)) \ - VFMADD231PD(ZMM(21), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+13)*8)) \ - VFMADD231PD(ZMM(22), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+14)*8)) \ - VFMADD231PD(ZMM(23), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+15)*8)) \ - VFMADD231PD(ZMM(24), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+16)*8)) \ - VFMADD231PD(ZMM(25), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+17)*8)) \ - VFMADD231PD(ZMM(26), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+18)*8)) \ - VFMADD231PD(ZMM(27), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+19)*8)) \ - VFMADD231PD(ZMM(28), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+20)*8)) \ - VFMADD231PD(ZMM(29), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+21)*8)) \ - VFMADD231PD(ZMM(30), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+22)*8)) \ - VFMADD231PD(ZMM(31), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+23)*8)) + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 0)*8)) \ + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 1)*8)) \ + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 2)*8)) \ + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 3)*8)) \ + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 4)*8)) \ + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 5)*8)) \ + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 6)*8)) \ + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 7)*8)) \ + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 8)*8)) \ + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 9)*8)) \ + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+10)*8)) \ + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+11)*8)) \ + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+12)*8)) \ + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+13)*8)) \ + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+14)*8)) \ + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+15)*8)) \ + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+16)*8)) \ + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+17)*8)) \ + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+18)*8)) \ + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+19)*8)) \ + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+20)*8)) \ + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+21)*8)) \ + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+22)*8)) \ + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+23)*8)) #define TAIL_LOOP(NAME) \ \ @@ -418,6 +418,22 @@ void bli_dgemm_opt_24x8( cntx_t* restrict cntx ) { + /* + for (dim_t i = 0;i < 24;i++) + { + for (dim_t j = 0;j < 8;j++) + { + c[i*rs_c+j*cs_c] *= *beta; + for (dim_t p = 0;p < k;p++) + { + c[i*rs_c+j*cs_c] += (*alpha)*a[i+p*24]*b[j+p*8]; + } + } + } + + return; + */ + const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); @@ -446,21 +462,21 @@ void bli_dgemm_opt_24x8( VMOVAPS(ZMM(11), ZMM(8)) MOV(RAX, VAR(a)) //load address of a VMOVAPS(ZMM(12), ZMM(8)) VMOVAPS(ZMM(13), ZMM(8)) MOV(RBX, VAR(b)) //load address of b - VMOVAPS(ZMM(14), ZMM(8)) VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b + VMOVAPS(ZMM(14), ZMM(8)) //VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b VMOVAPS(ZMM(15), ZMM(8)) VMOVAPS(ZMM(16), ZMM(8)) MOV(RCX, VAR(c)) //load address of c VMOVAPS(ZMM(17), ZMM(8)) //set up indexing information for prefetching C - VMOVAPS(ZMM(18), ZMM(8)) MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(rs_c)) - VMOVAPS(ZMM(20), ZMM(8)) VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15) - VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15) - VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23) - VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23) + VMOVAPS(ZMM(18), ZMM(8)) //MOV(RDI, VAR(offsetPtr)) + VMOVAPS(ZMM(19), ZMM(8)) //VBROADCASTSS(ZMM(4), VAR(rs_c)) + VMOVAPS(ZMM(20), ZMM(8)) //VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15) + VMOVAPS(ZMM(21), ZMM(8)) //VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15) + VMOVAPS(ZMM(22), ZMM(8)) //VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23) + VMOVAPS(ZMM(23), ZMM(8)) //VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23) VMOVAPS(ZMM(24), ZMM(8)) - VMOVAPS(ZMM(25), ZMM(8)) MOV(R8, IMM(4*24*8)) //offset for 4 iterations - VMOVAPS(ZMM(26), ZMM(8)) LEA(R9, MEM(R8,R8,2)) //*3 - VMOVAPS(ZMM(27), ZMM(8)) LEA(R10, MEM(R8,R8,4)) //*5 - VMOVAPS(ZMM(28), ZMM(8)) LEA(R11, MEM(R9,R8,4)) //*7 + VMOVAPS(ZMM(25), ZMM(8)) //MOV(R8, IMM(4*24*8)) //offset for 4 iterations + VMOVAPS(ZMM(26), ZMM(8)) //LEA(R9, MEM(R8,R8,2)) //*3 + VMOVAPS(ZMM(27), ZMM(8)) //LEA(R10, MEM(R8,R8,4)) //*5 + VMOVAPS(ZMM(28), ZMM(8)) //LEA(R11, MEM(R9,R8,4)) //*7 VMOVAPS(ZMM(29), ZMM(8)) VMOVAPS(ZMM(30), ZMM(8)) VMOVAPS(ZMM(31), ZMM(8)) @@ -487,7 +503,7 @@ void bli_dgemm_opt_24x8( //MOV(RSI, IMM(0+C_L1_ITERS)) - LABEL(PREFETCH_C_L1) + //LABEL(PREFETCH_C_L1) //prefetch C into L1 //KXNORW(K(1), K(0), K(0)) @@ -495,7 +511,43 @@ void bli_dgemm_opt_24x8( //VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1)) //VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2)) - MAIN_LOOP_L1 + //MAIN_LOOP_L1 + + LABEL(MAINLOOP) + + VMOVAPD(ZMM(0), MEM(RBX)) + + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RAX, 0*8)) + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RAX, 1*8)) + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RAX, 2*8)) + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RAX, 3*8)) + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RAX, 4*8)) + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RAX, 5*8)) + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RAX, 6*8)) + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RAX, 7*8)) + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RAX, 8*8)) + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RAX, 9*8)) + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RAX,10*8)) + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RAX,11*8)) + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RAX,12*8)) + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RAX,13*8)) + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RAX,14*8)) + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RAX,15*8)) + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RAX,16*8)) + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RAX,17*8)) + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RAX,18*8)) + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RAX,19*8)) + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RAX,20*8)) + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RAX,21*8)) + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RAX,22*8)) + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RAX,23*8)) + + ADD(RAX, IMM(24*8)) + ADD(RBX, IMM( 8*8)) + + SUB(RSI, IMM(1)) + + JNZ(MAINLOOP) LABEL(POSTACCUM) @@ -505,8 +557,10 @@ void bli_dgemm_opt_24x8( MOV(VAR(mid2h), EDX) #endif - VBROADCASTSD(ZMM(0), VAR(alpha)) - VBROADCASTSD(ZMM(1), VAR(beta)) + MOV(RAX, VAR(alpha)) + MOV(RBX, VAR(beta)) + VBROADCASTSD(ZMM(0), MEM(RAX)) + VBROADCASTSD(ZMM(1), MEM(RBX)) // Check if C is row stride. If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) @@ -519,7 +573,6 @@ void bli_dgemm_opt_24x8( VMOVQ(RDX, XMM(1)) SAL1(RDX) //shift out sign bit //JZ(COLSTORBZ) - JMP(COLSTORBZ) UPDATE_C_FOUR_ROWS( 8, 9,10,11) UPDATE_C_FOUR_ROWS(12,13,14,15) @@ -546,13 +599,12 @@ void bli_dgemm_opt_24x8( MOV(RDI, VAR(offsetPtr)) VMOVAPS(ZMM(2), MEM(RDI)) /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) + VBROADCASTSS(ZMM(3), VAR(cs_c)) VPMULLD(ZMM(2), ZMM(3), ZMM(2)) VMOVQ(RDX, XMM(1)) SAL1(RDX) //shift out sign bit //JZ(SCATTERBZ) - JMP(SCATTERBZ) UPDATE_C_ROW_SCATTERED( 8) UPDATE_C_ROW_SCATTERED( 9)