From 7038bbaa05484141195822291cf3ba88cbce4980 Mon Sep 17 00:00:00 2001
From: "Field G. Van Zee"
Date: Fri, 4 Dec 2020 16:08:15 -0600
Subject: [PATCH] Optionally disable trsm diagonal pre-inversion.

Details:
- Implemented a configure-time option, --disable-trsm-preinversion, that
  optionally disables the pre-inversion of diagonal elements of the
  triangular matrix in the trsm operation and instead uses division
  instructions within the gemmtrsm microkernels. Pre-inversion is enabled
  by default. When it is disabled, performance may suffer slightly, but
  numerical robustness should improve for certain pathological cases
  involving denormal (subnormal) numbers that would otherwise result in
  overflow in the pre-inverted value. Thanks to Bhaskar Nallani for
  reporting this issue via #461.
- Added preprocessor macro guards to bli_trsm_cntl.c as well as the
  gemmtrsm microkernels for 'haswell' and 'penryn' kernel sets pursuant
  to the aforementioned feature.
- Added macros to frame/include/bli_x86_asm_macros.h related to division
  instructions.
--- build/bli_config.h.in | 6 ++ configure | 29 +++++ frame/3/trsm/bli_trsm_cntl.c | 8 +- frame/include/bli_x86_asm_macros.h | 34 ++++-- .../3/bli_gemmtrsm_l_haswell_asm_d6x8.c | 100 ++++++++++++++---- .../3/bli_gemmtrsm_u_haswell_asm_d6x8.c | 100 ++++++++++++++---- .../penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c | 24 ++++- .../penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c | 24 ++++- 8 files changed, 271 insertions(+), 54 deletions(-) diff --git a/build/bli_config.h.in b/build/bli_config.h.in index d7f032dde..fa6bbbe12 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -153,6 +153,12 @@ #define BLIS_DISABLE_MEMKIND #endif +#if @enable_trsm_preinversion@ +#define BLIS_ENABLE_TRSM_PREINVERSION +#else +#define BLIS_DISABLE_TRSM_PREINVERSION +#endif + #if @enable_pragma_omp_simd@ #define BLIS_ENABLE_PRAGMA_OMP_SIMD #else diff --git a/configure b/configure index 5396d75d4..35a4dcee9 100755 --- a/configure +++ b/configure @@ -298,6 +298,20 @@ print_usage() echo " which may be ignored in select situations if the" echo " implementation has a good reason to do so." echo " " + echo " --disable-trsm-preinversion, --enable-trsm-preinversion" + echo " " + echo " Disable (enabled by default) pre-inversion of triangular" + echo " matrix diagonals when performing trsm. When pre-inversion" + echo " is enabled, diagonal elements are inverted outside of the" + echo " microkernel (e.g. during packing) so that the microkernel" + echo " can use multiply instructions. When disabled, division" + echo " instructions are used within the microkernel. Executing" + echo " these division instructions within the microkernel will" + echo " incur a performance penalty, but numerical robustness will" + echo " improve for certain cases involving denormal numbers that" + echo " would otherwise result in overflow in the pre-inverted" + echo " values." 
+ echo " " echo " --force-version=STRING" echo " " echo " Force configure to use an arbitrary version string" @@ -2013,6 +2027,7 @@ main() enable_mixed_dt_extra_mem='yes' enable_sup_handling='yes' enable_memkind='' # The default memkind value is determined later on. + enable_trsm_preinversion='yes' force_version='no' complex_return='default' @@ -2210,6 +2225,12 @@ main() without-memkind) enable_memkind='no' ;; + enable-trsm-preinversion) + enable_trsm_preinversion='yes' + ;; + disable-trsm-preinversion) + enable_trsm_preinversion='no' + ;; force-version=*) force_version=${OPTARG#*=} ;; @@ -3048,6 +3069,13 @@ main() echo "${script_name}: small matrix handling is disabled." enable_sup_handling_01=0 fi + if [ "x${enable_trsm_preinversion}" = "xyes" ]; then + echo "${script_name}: trsm diagonal element pre-inversion is enabled." + enable_trsm_preinversion_01=1 + else + echo "${script_name}: trsm diagonal element pre-inversion is disabled." + enable_trsm_preinversion_01=0 + fi # Report integer sizes. if [ "x${int_type_size}" = "x32" ]; then @@ -3313,6 +3341,7 @@ main() | sed -e "s/@enable_mixed_dt_extra_mem@/${enable_mixed_dt_extra_mem_01}/g" \ | sed -e "s/@enable_sup_handling@/${enable_sup_handling_01}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind_01}/g" \ + | sed -e "s/@enable_trsm_preinversion@/${enable_trsm_preinversion_01}/g" \ | sed -e "s/@enable_pragma_omp_simd@/${enable_pragma_omp_simd_01}/g" \ | sed -e "s/@enable_sandbox@/${enable_sandbox_01}/g" \ | sed -e "s/@enable_shared@/${enable_shared_01}/g" \ diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index 845370448..4a7a4de8f 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -99,7 +99,7 @@ cntl_t* bli_trsm_l_cntl_create packa_fp, BLIS_MR, BLIS_MR, - TRUE, // do NOT invert diagonal + FALSE, // do NOT invert diagonal TRUE, // reverse iteration if upper? FALSE, // reverse iteration if lower? 
schema_a, // normally BLIS_PACKED_ROW_PANELS @@ -137,7 +137,11 @@ cntl_t* bli_trsm_l_cntl_create packa_fp, BLIS_MR, BLIS_MR, - TRUE, // do NOT invert diagonal +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + TRUE, // invert diagonal +#else + FALSE, // do NOT invert diagonal +#endif TRUE, // reverse iteration if upper? FALSE, // reverse iteration if lower? schema_a, // normally BLIS_PACKED_ROW_PANELS diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index eca0b6959..a4987b4c5 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -855,8 +855,11 @@ #define SUBPD(_0, _1) INSTR_(subpd, _0, _1) #define MULPS(_0, _1) INSTR_(mulps, _0, _1) #define MULPD(_0, _1) INSTR_(mulpd, _0, _1) +#define DIVPS(_0, _1) INSTR_(divps, _0, _1) +#define DIVPD(_0, _1) INSTR_(divpd, _0, _1) #define XORPS(_0, _1) INSTR_(xorps, _0, _1) #define XORPD(_0, _1) INSTR_(xorpd, _0, _1) + #define UCOMISS(_0, _1) INSTR_(ucomiss, _0, _1) #define UCOMISD(_0, _1) INSTR_(ucomisd, _0, _1) #define COMISS(_0, _1) INSTR_(comiss, _0, _1) @@ -868,8 +871,11 @@ #define subpd(_0, _1) SUBPD(_0, _1) #define mulps(_0, _1) MULPS(_0, _1) #define mulpd(_0, _1) MULPD(_0, _1) +#define divps(_0, _1) DIVPS(_0, _1) +#define divpd(_0, _1) DIVPD(_0, _1) #define xorps(_0, _1) XORPS(_0, _1) #define xorpd(_0, _1) XORPD(_0, _1) + #define ucomiss(_0, _1) UCOMISS(_0, _1) #define ucomisd(_0, _1) UCOMISD(_0, _1) #define cmoiss(_0, _1) COMISS(_0, _1) @@ -879,10 +885,6 @@ #define VADDSUBPD(_0, _1, _2) INSTR_(vaddsubpd, _0, _1, _2) #define VHADDPD(_0, _1, _2) INSTR_(vhaddpd, _0, _1, _2) #define VHADDPS(_0, _1, _2) INSTR_(vhaddps, _0, _1, _2) -#define VUCOMISS(_0, _1) INSTR_(vucomiss, _0, _1) -#define VUCOMISD(_0, _1) INSTR_(vucomisd, _0, _1) -#define VCOMISS(_0, _1) INSTR_(vcomiss, _0, _1) -#define VCOMISD(_0, _1) INSTR_(vcomisd, _0, _1) #define VADDPS(_0, _1, _2) INSTR_(vaddps, _0, _1, _2) #define VADDPD(_0, _1, _2) INSTR_(vaddpd, _0, _1, _2) #define VSUBPS(_0, _1, _2) 
INSTR_(vsubps, _0, _1, _2) @@ -891,6 +893,10 @@ #define VMULSD(_0, _1, _2) INSTR_(vmulsd, _0, _1, _2) #define VMULPS(_0, _1, _2) INSTR_(vmulps, _0, _1, _2) #define VMULPD(_0, _1, _2) INSTR_(vmulpd, _0, _1, _2) +#define VDIVSS(_0, _1, _2) INSTR_(vdivss, _0, _1, _2) +#define VDIVSD(_0, _1, _2) INSTR_(vdivsd, _0, _1, _2) +#define VDIVPS(_0, _1, _2) INSTR_(vdivps, _0, _1, _2) +#define VDIVPD(_0, _1, _2) INSTR_(vdivpd, _0, _1, _2) #define VPMULLD(_0, _1, _2) INSTR_(vpmulld, _0, _1, _2) #define VPMULLQ(_0, _1, _2) INSTR_(vpmullq, _0, _1, _2) #define VPADDD(_0, _1, _2) INSTR_(vpaddd, _0, _1, _2) @@ -898,6 +904,12 @@ #define VXORPS(_0, _1, _2) INSTR_(vxorps, _0, _1, _2) #define VXORPD(_0, _1, _2) INSTR_(vxorpd, _0, _1, _2) #define VPXORD(_0, _1, _2) INSTR_(vpxord, _0, _1, _2) + +#define VUCOMISS(_0, _1) INSTR_(vucomiss, _0, _1) +#define VUCOMISD(_0, _1) INSTR_(vucomisd, _0, _1) +#define VCOMISS(_0, _1) INSTR_(vcomiss, _0, _1) +#define VCOMISD(_0, _1) INSTR_(vcomisd, _0, _1) + #define VFMADD132SS(_0, _1, _2) INSTR_(vfmadd132ss, _0, _1, _2) #define VFMADD213SS(_0, _1, _2) INSTR_(vfmadd213ss, _0, _1, _2) #define VFMADD231SS(_0, _1, _2) INSTR_(vfmadd231ss, _0, _1, _2) @@ -1003,10 +1015,6 @@ #define vaddsubpd(_0, _1, _2) VADDSUBPD(_0, _1, _2) #define vhaddpd(_0, _1, _2) VHADDPD(_0, _1, _2) #define vhaddps(_0, _1, _2) VHADDPS(_0, _1, _2) -#define vucomiss(_0, _1) VUCOMISS(_0, _1) -#define vucomisd(_0, _1) VUCOMISD(_0, _1) -#define vcomiss(_0, _1) VCOMISS(_0, _1) -#define vcomisd(_0, _1) VCOMISD(_0, _1) #define vaddps(_0, _1, _2) VADDPS(_0, _1, _2) #define vaddpd(_0, _1, _2) VADDPD(_0, _1, _2) #define vsubps(_0, _1, _2) VSUBPS(_0, _1, _2) @@ -1015,6 +1023,10 @@ #define vmulps(_0, _1, _2) VMULPS(_0, _1, _2) #define vmulsd(_0, _1, _2) VMULSD(_0, _1, _2) #define vmulpd(_0, _1, _2) VMULPD(_0, _1, _2) +#define vdivss(_0, _1, _2) VDIVSS(_0, _1, _2) +#define vdivps(_0, _1, _2) VDIVPS(_0, _1, _2) +#define vdivsd(_0, _1, _2) VDIVSD(_0, _1, _2) +#define vdivpd(_0, _1, _2) VDIVPD(_0, _1, 
_2) #define vpmulld(_0, _1, _2) VPMULLD(_0, _1, _2) #define vpmullq(_0, _1, _2) VPMULLQ(_0, _1, _2) #define vpaddd(_0, _1, _2) VPADDD(_0, _1, _2) @@ -1022,6 +1034,12 @@ #define vxorps(_0, _1, _2) VXORPS(_0, _1, _2) #define vxorpd(_0, _1, _2) VXORPD(_0, _1, _2) #define vpxord(_0, _1, _2) VPXORD(_0, _1, _2) + +#define vucomiss(_0, _1) VUCOMISS(_0, _1) +#define vucomisd(_0, _1) VUCOMISD(_0, _1) +#define vcomiss(_0, _1) VCOMISS(_0, _1) +#define vcomisd(_0, _1) VCOMISD(_0, _1) + #define vfmadd132ss(_0, _1, _2) VFMADD132SS(_0, _1, _2) #define vfmadd213ss(_0, _1, _2) VFMADD213SS(_0, _1, _2) #define vfmadd231ss(_0, _1, _2) VFMADD231SS(_0, _1, _2) diff --git a/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c index 3d69556ff..a6edf8c49 100644 --- a/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c @@ -374,8 +374,13 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vbroadcastss(mem(0+0*6)*4(rax), ymm0) // ymm0 = (1/alpha00) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm4, ymm4) // ymm4 *= (1/alpha00) vmulps(ymm0, ymm5, ymm5) // ymm5 *= (1/alpha00) +#else + vdivps(ymm0, ymm4, ymm4) // ymm4 /= alpha00 + vdivps(ymm0, ymm5, ymm5) // ymm5 /= alpha00 +#endif vmovups(ymm4, mem(rcx)) // store ( beta00..beta07 ) = ymm4 vmovups(ymm5, mem(rdx)) // store ( beta08..beta0F ) = ymm5 @@ -393,8 +398,13 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vsubps(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubps(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - vmulps(ymm6, ymm1, ymm6) // ymm6 *= (1/alpha11) - vmulps(ymm7, ymm1, ymm7) // ymm7 *= (1/alpha11) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulps(ymm1, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivps(ymm1, ymm6, ymm6) // ymm6 /= alpha11 + vdivps(ymm1, ymm7, ymm7) // ymm7 /= alpha11 +#endif vmovups(ymm6, mem(rcx)) // store ( beta10..beta17 ) = ymm6 vmovups(ymm7, mem(rdx)) // store ( beta18..beta1F ) = ymm7 
@@ -417,8 +427,13 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vsubps(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubps(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - vmulps(ymm8, ymm0, ymm8) // ymm8 *= (1/alpha22) - vmulps(ymm9, ymm0, ymm9) // ymm9 *= (1/alpha22) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulps(ymm0, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivps(ymm0, ymm8, ymm8) // ymm8 /= alpha22 + vdivps(ymm0, ymm9, ymm9) // ymm9 /= alpha22 +#endif vmovups(ymm8, mem(rcx)) // store ( beta20..beta27 ) = ymm8 vmovups(ymm9, mem(rdx)) // store ( beta28..beta2F ) = ymm9 @@ -446,8 +461,13 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vsubps(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubps(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - vmulps(ymm10, ymm1, ymm10) // ymm10 *= (1/alpha33) - vmulps(ymm11, ymm1, ymm11) // ymm11 *= (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulps(ymm1, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivps(ymm1, ymm10, ymm10) // ymm10 /= alpha33 + vdivps(ymm1, ymm11, ymm11) // ymm11 /= alpha33 +#endif vmovups(ymm10, mem(rcx)) // store ( beta30..beta37 ) = ymm10 vmovups(ymm11, mem(rdx)) // store ( beta38..beta3F ) = ymm11 @@ -480,8 +500,13 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vsubps(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubps(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - vmulps(ymm12, ymm0, ymm12) // ymm12 *= (1/alpha44) - vmulps(ymm13, ymm0, ymm13) // ymm13 *= (1/alpha44) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulps(ymm0, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivps(ymm0, ymm12, ymm12) // ymm12 /= alpha44 + vdivps(ymm0, ymm13, ymm13) // ymm13 /= alpha44 +#endif vmovups(ymm12, mem(rcx)) // store ( beta40..beta47 ) = ymm12 vmovups(ymm13, mem(rdx)) // store ( beta48..beta4F ) = ymm13 @@ -519,8 +544,13 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vsubps(ymm2, ymm14, ymm14) // ymm14 -= ymm2 vsubps(ymm3, ymm15, ymm15) 
// ymm15 -= ymm3 - vmulps(ymm14, ymm1, ymm14) // ymm14 *= (1/alpha55) - vmulps(ymm15, ymm1, ymm15) // ymm15 *= (1/alpha55) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm14, ymm14) // ymm14 *= (1/alpha55) + vmulps(ymm1, ymm15, ymm15) // ymm15 *= (1/alpha55) +#else + vdivps(ymm1, ymm14, ymm14) // ymm14 /= alpha55 + vdivps(ymm1, ymm15, ymm15) // ymm15 /= alpha55 +#endif vmovups(ymm14, mem(rcx)) // store ( beta50..beta57 ) = ymm14 vmovups(ymm15, mem(rdx)) // store ( beta58..beta5F ) = ymm15 @@ -1129,8 +1159,13 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vbroadcastsd(mem(0+0*6)*8(rax), ymm0) // ymm0 = (1/alpha00) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm4, ymm4) // ymm4 *= (1/alpha00) vmulpd(ymm0, ymm5, ymm5) // ymm5 *= (1/alpha00) +#else + vdivpd(ymm0, ymm4, ymm4) // ymm4 /= alpha00 + vdivpd(ymm0, ymm5, ymm5) // ymm5 /= alpha00 +#endif vmovupd(ymm4, mem(rcx)) // store ( beta00..beta03 ) = ymm4 vmovupd(ymm5, mem(rdx)) // store ( beta04..beta07 ) = ymm5 @@ -1148,8 +1183,13 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vsubpd(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubpd(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - vmulpd(ymm6, ymm1, ymm6) // ymm6 *= (1/alpha11) - vmulpd(ymm7, ymm1, ymm7) // ymm7 *= (1/alpha11) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulpd(ymm1, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivpd(ymm1, ymm6, ymm6) // ymm6 /= alpha11 + vdivpd(ymm1, ymm7, ymm7) // ymm7 /= alpha11 +#endif vmovupd(ymm6, mem(rcx)) // store ( beta10..beta13 ) = ymm6 vmovupd(ymm7, mem(rdx)) // store ( beta14..beta17 ) = ymm7 @@ -1172,8 +1212,13 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vsubpd(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubpd(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - vmulpd(ymm8, ymm0, ymm8) // ymm8 *= (1/alpha22) - vmulpd(ymm9, ymm0, ymm9) // ymm9 *= (1/alpha22) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulpd(ymm0, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivpd(ymm0, ymm8, ymm8) 
// ymm8 /= alpha22 + vdivpd(ymm0, ymm9, ymm9) // ymm9 /= alpha22 +#endif vmovupd(ymm8, mem(rcx)) // store ( beta20..beta23 ) = ymm8 vmovupd(ymm9, mem(rdx)) // store ( beta24..beta27 ) = ymm9 @@ -1201,8 +1246,13 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vsubpd(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubpd(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - vmulpd(ymm10, ymm1, ymm10) // ymm10 *= (1/alpha33) - vmulpd(ymm11, ymm1, ymm11) // ymm11 *= (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulpd(ymm1, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivpd(ymm1, ymm10, ymm10) // ymm10 /= alpha33 + vdivpd(ymm1, ymm11, ymm11) // ymm11 /= alpha33 +#endif vmovupd(ymm10, mem(rcx)) // store ( beta30..beta33 ) = ymm10 vmovupd(ymm11, mem(rdx)) // store ( beta34..beta37 ) = ymm11 @@ -1235,8 +1285,13 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vsubpd(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubpd(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - vmulpd(ymm12, ymm0, ymm12) // ymm12 *= (1/alpha44) - vmulpd(ymm13, ymm0, ymm13) // ymm13 *= (1/alpha44) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulpd(ymm0, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivpd(ymm0, ymm12, ymm12) // ymm12 /= alpha44 + vdivpd(ymm0, ymm13, ymm13) // ymm13 /= alpha44 +#endif vmovupd(ymm12, mem(rcx)) // store ( beta40..beta43 ) = ymm12 vmovupd(ymm13, mem(rdx)) // store ( beta44..beta47 ) = ymm13 @@ -1274,8 +1329,13 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vsubpd(ymm2, ymm14, ymm14) // ymm14 -= ymm2 vsubpd(ymm3, ymm15, ymm15) // ymm15 -= ymm3 - vmulpd(ymm14, ymm1, ymm14) // ymm14 *= (1/alpha55) - vmulpd(ymm15, ymm1, ymm15) // ymm15 *= (1/alpha55) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm14, ymm14) // ymm14 *= (1/alpha55) + vmulpd(ymm1, ymm15, ymm15) // ymm15 *= (1/alpha55) +#else + vdivpd(ymm1, ymm14, ymm14) // ymm14 /= alpha55 + vdivpd(ymm1, ymm15, ymm15) // ymm15 /= alpha55 +#endif vmovupd(ymm14, mem(rcx)) // 
store ( beta50..beta53 ) = ymm14 vmovupd(ymm15, mem(rdx)) // store ( beta54..beta57 ) = ymm15 diff --git a/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c index ee54d1e3a..b14fb1177 100644 --- a/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c @@ -379,8 +379,13 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vbroadcastss(mem(5+5*6)*4(rax), ymm0) // ymm0 = (1/alpha55) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm14, ymm14) // ymm14 *= (1/alpha55) vmulps(ymm0, ymm15, ymm15) // ymm15 *= (1/alpha55) +#else + vdivps(ymm0, ymm14, ymm14) // ymm14 /= alpha55 + vdivps(ymm0, ymm15, ymm15) // ymm15 /= alpha55 +#endif vmovups(ymm14, mem(rcx)) // store ( beta50..beta57 ) = ymm14 vmovups(ymm15, mem(rdx)) // store ( beta58..beta5F ) = ymm15 @@ -398,8 +403,13 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vsubps(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubps(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - vmulps(ymm12, ymm1, ymm12) // ymm12 *= (1/alpha44) - vmulps(ymm13, ymm1, ymm13) // ymm13 *= (1/alpha44) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulps(ymm1, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivps(ymm1, ymm12, ymm12) // ymm12 /= alpha44 + vdivps(ymm1, ymm13, ymm13) // ymm13 /= alpha44 +#endif vmovups(ymm12, mem(rcx)) // store ( beta40..beta47 ) = ymm12 vmovups(ymm13, mem(rdx)) // store ( beta48..beta4F ) = ymm13 @@ -422,8 +432,13 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vsubps(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubps(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - vmulps(ymm10, ymm0, ymm10) // ymm10 *= (1/alpha33) - vmulps(ymm11, ymm0, ymm11) // ymm11 *= (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulps(ymm0, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivps(ymm0, ymm10, ymm10) // ymm10 /= alpha33 + vdivps(ymm0, ymm11, ymm11) // ymm11 /= alpha33 +#endif 
vmovups(ymm10, mem(rcx)) // store ( beta30..beta37 ) = ymm10 vmovups(ymm11, mem(rdx)) // store ( beta38..beta3F ) = ymm11 @@ -451,8 +466,13 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vsubps(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubps(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - vmulps(ymm8, ymm1, ymm8) // ymm8 *= (1/alpha33) - vmulps(ymm9, ymm1, ymm9) // ymm9 *= (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulps(ymm1, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivps(ymm1, ymm8, ymm8) // ymm8 /= alpha22 + vdivps(ymm1, ymm9, ymm9) // ymm9 /= alpha22 +#endif vmovups(ymm8, mem(rcx)) // store ( beta20..beta27 ) = ymm8 vmovups(ymm9, mem(rdx)) // store ( beta28..beta2F ) = ymm9 @@ -485,8 +505,13 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vsubps(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubps(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - vmulps(ymm6, ymm0, ymm6) // ymm6 *= (1/alpha44) - vmulps(ymm7, ymm0, ymm7) // ymm7 *= (1/alpha44) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm0, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulps(ymm0, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivps(ymm0, ymm6, ymm6) // ymm6 /= alpha11 + vdivps(ymm0, ymm7, ymm7) // ymm7 /= alpha11 +#endif vmovups(ymm6, mem(rcx)) // store ( beta10..beta17 ) = ymm6 vmovups(ymm7, mem(rdx)) // store ( beta18..beta1F ) = ymm7 @@ -524,8 +549,13 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vsubps(ymm2, ymm4, ymm4) // ymm4 -= ymm2 vsubps(ymm3, ymm5, ymm5) // ymm5 -= ymm3 - vmulps(ymm4, ymm1, ymm4) // ymm4 *= (1/alpha00) - vmulps(ymm5, ymm1, ymm5) // ymm5 *= (1/alpha00) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulps(ymm1, ymm4, ymm4) // ymm4 *= (1/alpha00) + vmulps(ymm1, ymm5, ymm5) // ymm5 *= (1/alpha00) +#else + vdivps(ymm1, ymm4, ymm4) // ymm4 /= alpha00 + vdivps(ymm1, ymm5, ymm5) // ymm5 /= alpha00 +#endif vmovups(ymm4, mem(rcx)) // store ( beta00..beta07 ) = ymm4 vmovups(ymm5, mem(rdx)) // store ( beta08..beta0F ) = ymm5 @@ -1138,8 +1168,13 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 
vbroadcastsd(mem(5+5*6)*8(rax), ymm0) // ymm0 = (1/alpha55) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm14, ymm14) // ymm14 *= (1/alpha55) vmulpd(ymm0, ymm15, ymm15) // ymm15 *= (1/alpha55) +#else + vdivpd(ymm0, ymm14, ymm14) // ymm14 /= alpha55 + vdivpd(ymm0, ymm15, ymm15) // ymm15 /= alpha55 +#endif vmovupd(ymm14, mem(rcx)) // store ( beta50..beta53 ) = ymm14 vmovupd(ymm15, mem(rdx)) // store ( beta54..beta57 ) = ymm15 @@ -1157,8 +1192,13 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vsubpd(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubpd(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - vmulpd(ymm12, ymm1, ymm12) // ymm12 *= (1/alpha44) - vmulpd(ymm13, ymm1, ymm13) // ymm13 *= (1/alpha44) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm12, ymm12) // ymm12 *= (1/alpha44) + vmulpd(ymm1, ymm13, ymm13) // ymm13 *= (1/alpha44) +#else + vdivpd(ymm1, ymm12, ymm12) // ymm12 /= alpha44 + vdivpd(ymm1, ymm13, ymm13) // ymm13 /= alpha44 +#endif vmovupd(ymm12, mem(rcx)) // store ( beta40..beta43 ) = ymm12 vmovupd(ymm13, mem(rdx)) // store ( beta44..beta47 ) = ymm13 @@ -1181,8 +1221,13 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vsubpd(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubpd(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - vmulpd(ymm10, ymm0, ymm10) // ymm10 *= (1/alpha33) - vmulpd(ymm11, ymm0, ymm11) // ymm11 *= (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm10, ymm10) // ymm10 *= (1/alpha33) + vmulpd(ymm0, ymm11, ymm11) // ymm11 *= (1/alpha33) +#else + vdivpd(ymm0, ymm10, ymm10) // ymm10 /= alpha33 + vdivpd(ymm0, ymm11, ymm11) // ymm11 /= alpha33 +#endif vmovupd(ymm10, mem(rcx)) // store ( beta30..beta33 ) = ymm10 vmovupd(ymm11, mem(rdx)) // store ( beta34..beta37 ) = ymm11 @@ -1210,8 +1255,13 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vsubpd(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubpd(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - vmulpd(ymm8, ymm1, ymm8) // ymm8 *= (1/alpha33) - vmulpd(ymm9, ymm1, ymm9) // ymm9 *= (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, 
ymm8, ymm8) // ymm8 *= (1/alpha22) + vmulpd(ymm1, ymm9, ymm9) // ymm9 *= (1/alpha22) +#else + vdivpd(ymm1, ymm8, ymm8) // ymm8 /= alpha22 + vdivpd(ymm1, ymm9, ymm9) // ymm9 /= alpha22 +#endif vmovupd(ymm8, mem(rcx)) // store ( beta20..beta23 ) = ymm8 vmovupd(ymm9, mem(rdx)) // store ( beta24..beta27 ) = ymm9 @@ -1244,8 +1294,13 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vsubpd(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubpd(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - vmulpd(ymm6, ymm0, ymm6) // ymm6 *= (1/alpha44) - vmulpd(ymm7, ymm0, ymm7) // ymm7 *= (1/alpha44) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm0, ymm6, ymm6) // ymm6 *= (1/alpha11) + vmulpd(ymm0, ymm7, ymm7) // ymm7 *= (1/alpha11) +#else + vdivpd(ymm0, ymm6, ymm6) // ymm6 /= alpha11 + vdivpd(ymm0, ymm7, ymm7) // ymm7 /= alpha11 +#endif vmovupd(ymm6, mem(rcx)) // store ( beta10..beta13 ) = ymm6 vmovupd(ymm7, mem(rdx)) // store ( beta14..beta17 ) = ymm7 @@ -1283,8 +1338,13 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vsubpd(ymm2, ymm4, ymm4) // ymm4 -= ymm2 vsubpd(ymm3, ymm5, ymm5) // ymm5 -= ymm3 - vmulpd(ymm4, ymm1, ymm4) // ymm4 *= (1/alpha00) - vmulpd(ymm5, ymm1, ymm5) // ymm5 *= (1/alpha00) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + vmulpd(ymm1, ymm4, ymm4) // ymm4 *= (1/alpha00) + vmulpd(ymm1, ymm5, ymm5) // ymm5 *= (1/alpha00) +#else + vdivpd(ymm1, ymm4, ymm4) // ymm4 /= alpha00 + vdivpd(ymm1, ymm5, ymm5) // ymm5 /= alpha00 +#endif vmovupd(ymm4, mem(rcx)) // store ( beta00..beta03 ) = ymm4 vmovupd(ymm5, mem(rdx)) // store ( beta04..beta07 ) = ymm5 diff --git a/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c index 6739e262f..56afcf08c 100644 --- a/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemmtrsm_l_penryn_asm_d4x4.c @@ -415,8 +415,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 movddup(mem(0+0*4)*8(rax), xmm0) // load xmm0 = (1/alpha00) - mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm0, xmm8) 
// xmm8 *= (1/alpha00); mulpd(xmm0, xmm12) // xmm12 *= (1/alpha00); +#else + divpd(xmm0, xmm8) // xmm8 /= alpha00; + divpd(xmm0, xmm12) // xmm12 /= alpha00; +#endif movaps(xmm8, mem(rbx, 0*16)) // store ( beta00 beta01 ) = xmm8 movaps(xmm12, mem(rbx, 1*16)) // store ( beta02 beta03 ) = xmm12 @@ -439,8 +444,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 mulpd(xmm12, xmm4) // xmm4 = alpha10 * ( beta02 beta03 ) subpd(xmm0, xmm9) // xmm9 -= xmm0 subpd(xmm4, xmm13) // xmm13 -= xmm4 - mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); mulpd(xmm1, xmm13) // xmm13 *= (1/alpha11); +#else + divpd(xmm1, xmm9) // xmm9 /= alpha11; + divpd(xmm1, xmm13) // xmm13 /= alpha11; +#endif movaps(xmm9, mem(rbx, 2*16)) // store ( beta10 beta11 ) = xmm9 movaps(xmm13, mem(rbx, 3*16)) // store ( beta12 beta13 ) = xmm13 @@ -469,8 +479,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 addpd(xmm5, xmm4) // xmm4 += xmm5; subpd(xmm0, xmm10) // xmm10 -= xmm0 subpd(xmm4, xmm14) // xmm14 -= xmm4 +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm2, xmm10) // xmm10 *= (1/alpha22); mulpd(xmm2, xmm14) // xmm14 *= (1/alpha22); +#else + divpd(xmm2, xmm10) // xmm10 /= alpha22; + divpd(xmm2, xmm14) // xmm14 /= alpha22; +#endif movaps(xmm10, mem(rbx, 4*16)) // store ( beta20 beta21 ) = xmm10 movaps(xmm14, mem(rbx, 5*16)) // store ( beta22 beta23 ) = xmm14 @@ -505,8 +520,13 @@ void bli_dgemmtrsm_l_penryn_asm_4x4 addpd(xmm6, xmm4) // xmm4 += xmm6; subpd(xmm0, xmm11) // xmm11 -= xmm0 subpd(xmm4, xmm15) // xmm15 -= xmm4 +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm3, xmm11) // xmm11 *= (1/alpha33); mulpd(xmm3, xmm15) // xmm15 *= (1/alpha33); +#else + divpd(xmm3, xmm11) // xmm11 /= alpha33; + divpd(xmm3, xmm15) // xmm15 /= alpha33; +#endif movaps(xmm11, mem(rbx, 6*16)) // store ( beta30 beta31 ) = xmm11 movaps(xmm15, mem(rbx, 7*16)) // store ( beta32 beta33 ) = xmm15 diff --git a/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c 
b/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c index 5c355aac8..9811e0e32 100644 --- a/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemmtrsm_u_penryn_asm_d4x4.c @@ -401,8 +401,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 movddup(mem(3+3*4)*8(rax), xmm3) // load xmm3 = (1/alpha33) +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm3, xmm11) // xmm11 *= (1/alpha33); mulpd(xmm3, xmm15) // xmm15 *= (1/alpha33); +#else + divpd(xmm3, xmm11) // xmm11 /= alpha33; + divpd(xmm3, xmm15) // xmm15 /= alpha33; +#endif movaps(xmm11, mem(rbx, 6*16)) // store ( beta30 beta31 ) = xmm11 movaps(xmm15, mem(rbx, 7*16)) // store ( beta32 beta33 ) = xmm15 @@ -425,8 +430,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 mulpd(xmm15, xmm7) // xmm7 = alpha23 * ( beta32 beta33 ) subpd(xmm3, xmm10) // xmm10 -= xmm3 subpd(xmm7, xmm14) // xmm14 -= xmm7 +#ifdef BLIS_ENABLE_TRSM_PREINVERSION mulpd(xmm2, xmm10) // xmm10 *= (1/alpha22); mulpd(xmm2, xmm14) // xmm14 *= (1/alpha22); +#else + divpd(xmm2, xmm10) // xmm10 /= alpha22; + divpd(xmm2, xmm14) // xmm14 /= alpha22; +#endif movaps(xmm10, mem(rbx, 4*16)) // store ( beta20 beta21 ) = xmm10 movaps(xmm14, mem(rbx, 5*16)) // store ( beta22 beta23 ) = xmm14 @@ -455,8 +465,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 addpd(xmm7, xmm6) // xmm6 += xmm7; subpd(xmm2, xmm9) // xmm9 -= xmm2 subpd(xmm6, xmm13) // xmm13 -= xmm6 - mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm1, xmm9) // xmm9 *= (1/alpha11); mulpd(xmm1, xmm13) // xmm13 *= (1/alpha11); +#else + divpd(xmm1, xmm9) // xmm9 /= alpha11; + divpd(xmm1, xmm13) // xmm13 /= alpha11; +#endif movaps(xmm9, mem(rbx, 2*16)) // store ( beta10 beta11 ) = xmm9 movaps(xmm13, mem(rbx, 3*16)) // store ( beta12 beta13 ) = xmm13 @@ -491,8 +506,13 @@ void bli_dgemmtrsm_u_penryn_asm_4x4 addpd(xmm7, xmm5) // xmm5 += xmm7; subpd(xmm1, xmm8) // xmm8 -= xmm1 subpd(xmm5, xmm12) // xmm12 -= xmm5 - mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); +#ifdef 
BLIS_ENABLE_TRSM_PREINVERSION + mulpd(xmm0, xmm8) // xmm8 *= (1/alpha00); mulpd(xmm0, xmm12) // xmm12 *= (1/alpha00); +#else + divpd(xmm0, xmm8) // xmm8 /= alpha00; + divpd(xmm0, xmm12) // xmm12 /= alpha00; +#endif movaps(xmm8, mem(rbx, 0*16)) // store ( beta00 beta01 ) = xmm8 movaps(xmm12, mem(rbx, 1*16)) // store ( beta02 beta03 ) = xmm12