diff --git a/.appveyor.yml b/.appveyor.yml
index efc98f555..87aee9c97 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -39,7 +39,7 @@ build_script:
 - bash -lc "cd /c/projects/blis && ./configure %CONFIGURE_OPTS% --enable-threading=%THREADING% --enable-arg-max-hack --prefix=/c/blis %CONFIG%"
 - bash -lc "cd /c/projects/blis && mingw32-make -j4 V=1"
 - bash -lc "cd /c/projects/blis && mingw32-make install"
-- ps: Compress-Archive -Path C:\blis -DestinationPath C:\blis.zip
+- 7z a C:\blis.zip C:\blis
 - ps: Push-AppveyorArtifact C:\blis.zip
 
 test_script:
diff --git a/CHANGELOG b/CHANGELOG
index 784c9f5fd..80150e185 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,10 +1,1703 @@
-commit e0408c3ca3d53bc8e6fedac46ea42c86e06c922d (HEAD -> master, tag: 0.5.1)
+commit 18c876b989fd0dcaa27becd14e4f16bdac7e89b3 (HEAD -> master, tag: 0.6.0)
+Author: Field G. Van Zee
+Date:   Mon Jun 3 18:37:19 2019 -0500
+
+    Version file update (0.6.0)
+
+commit 0f1b3bf49eb593ca7bb08b68a7209f7cd550f912 (origin/master, origin/HEAD)
+Author: Field G. Van Zee
+Date:   Mon Jun 3 18:35:19 2019 -0500
+
+    ReleaseNotes.md update in advance of next version.
+
+    Details:
+    - Updated ReleaseNotes.md in preparation for next version.
+    - CREDITS file update.
+
+commit 27da2e8400d900855da0d834b5417d7e83f21de1
+Author: Field G. Van Zee
+Date:   Mon Jun 3 17:14:56 2019 -0500
+
+    Minor edits to docs/PerformanceSmall.md.
+
+    Details:
+    - Added performance analysis to "Comments" section of both Kaby Lake
+      and Epyc sections.
+    - Added emphasis to certain passages.
+
+commit 09ba05c6f87efbaadf085497dc137845f16ee9c5
+Author: Field G. Van Zee
+Date:   Mon Jun 3 16:53:19 2019 -0500
+
+    Added sup performance graphs/document to 'docs'.
+
+    Details:
+    - Added a new markdown document, docs/PerformanceSmall.md, which
+      publishes new performance graphs for Kaby Lake and Epyc showcasing
+      the new BLIS sup (small/skinny/unpacked) framework logic and
+      kernels. For now, only single-threaded dgemm performance is shown.
+    - Reorganized graphs in docs/graphs into docs/graphs/large, with new
+      graphs being placed in docs/graphs/sup.
+    - Updates to scripts in test/sup/octave, mostly to allow decent
+      output in both GNU Octave and Matlab.
+    - Updated README.md to mention and refer to the new
+      PerformanceSmall.md document.
+
+commit 6bf449cc6941734748034de0e9af22b75f1d6ba1
+Merge: abd8a9fa a4e8801d
+Author: Field G. Van Zee
+Date:   Fri May 31 17:42:40 2019 -0500
+
+    Merge branch 'amd'
+
+commit a4e8801d08d81fa42ebea6a05a990de8dcedc803 (origin/amd, amd)
+Author: Field G. Van Zee
+Date:   Fri May 31 17:30:51 2019 -0500
+
+    Increased MT sup threshold for double to 201.
+
+    Details:
+    - Fine-tuned the double-precision real MT threshold (which controls
+      whether the sup implementation kicks in for smaller m dimension
+      values) from 180 to 201 for haswell and 180 to 256 for zen.
+    - Updated octave scripts in test/sup/octave to include a seventh
+      column to display performance for m = n = k.
+
+commit abd8a9fa7df4569aa2711964c19888b8e248901f (origin/pfhp)
+Author: Field G. Van Zee
+Date:   Tue May 28 12:49:44 2019 -0500
+
+    Inadvertently hidden xerbla_() in blastest (#313).
+
+    Details:
+    - Attempted a fix to issue #313, which reports that when building
+      only a shared library (i.e., the static library build is
+      disabled), running the BLAS test drivers can fail because those
+      drivers provide their own local version of xerbla_() as a clever
+      (albeit still rather hackish) way of checking the error codes that
+      result from the individual tests. This local xerbla_() function is
+      never found at link-time because the BLAS test drivers' Makefile
+      imports BLIS compilation flags via the get-user-cflags-for()
+      function, which currently conveys the -fvisibility=hidden flag,
+      which hides symbols unless they are explicitly annotated for
+      export. The -fvisibility=hidden flag was only ever intended for
+      use when building BLIS (not for applications), and so the
+      attempted solution here is to omit the symbol export flag(s) from
+      get-user-cflags-for() by storing the symbol export flag(s) to a
+      new BUILD_SYMFLAGS variable instead of appending it to the
+      subconfigurations' CMISCFLAGS variable (which is returned by every
+      get-*-cflags-for() function). Thanks to M. Zhou for reporting this
+      issue and also to Isuru Fernando for suggesting the fix.
+    - Renamed BUILD_FLAGS to BUILD_CPPFLAGS to harmonize with the newly
+      created BUILD_SYMFLAGS.
+    - Fixed typo in entry for --export-shared flag in 'configure --help'
+      text.
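The mechanics behind the hidden xerbla_() above: under -fvisibility=hidden, every symbol a translation unit defines is withheld from the shared library's export table unless annotated otherwise. A minimal sketch, assuming GCC/Clang semantics (the EXPORT_SYM macro name is illustrative, not BLIS's actual macro):

    /* compiled with -fvisibility=hidden */
    #define EXPORT_SYM __attribute__((visibility("default")))

    EXPORT_SYM void api_call( void );   /* exported: visible to the linker */
    void local_xerbla_( void );         /* hidden: never found at link-time */
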
+
+commit 755730608d923538273a90c48bfdf77571f86519
+Author: Field G. Van Zee
+Date:   Thu May 23 17:34:36 2019 -0500
+
+    Minor rewording of language around mt env. vars.
+
+commit ba31abe73c97c16c78fffc59a215761b8d9fd1f6
+Author: Field G. Van Zee
+Date:   Thu May 23 14:59:53 2019 -0500
+
+    Added BLIS threading info to Performance.md.
+
+    Details:
+    - Documented the BLIS environment variables that were set
+      (e.g. BLIS_JC_NT, BLIS_IC_NT, BLIS_JR_NT) for each machine and
+      threading configuration in order to achieve the parallelism
+      reported on in docs/Performance.md.
+
+commit cb788ffc89cac03b44803620412a5e83450ca949
+Author: Field G. Van Zee
+Date:   Thu May 23 13:00:53 2019 -0500
+
+    Increased MT sup threshold for double to 180.
+
+    Details:
+    - Increased the double-precision real MT threshold (which controls
+      whether the sup implementation kicks in for smaller m dimension
+      values) from 80 to 180, and this change was made for both haswell
+      and zen subconfigurations. This is less about the m dimension in
+      particular and more about facilitating a smoother performance
+      transition when m = n = k.
+
+commit 057f5f3d211e7513f457ee6ca6c9555d00ad1e57
+Author: Field G. Van Zee
+Date:   Thu May 23 12:51:17 2019 -0500
+
+    Minor build system housekeeping.
+
+    Details:
+    - Commented out redundant setting of LIBBLIS_LINK within all
+      driver-level Makefiles. This variable is already set within
+      common.mk, and so the only time it should be overridden is if the
+      user wants to link to a different copy of libblis.
+    - Very minor changes to build/gen-make-frags/gen-make-frag.sh.
+    - Whitespace and inconsequential quoting change to configure.
+    - Moved top-level 'windows' directory into a new 'attic' directory.
+
+commit 32392cfc72af7f42da817a129748349fb1951346
+Author: Jeff Hammond
+Date:   Tue May 14 15:52:30 2019 -0400
+
+    add info about CXX in configure (#311)
+
+commit fa7e6b182b8365465ade178b0e4cd344ff6f6460
+Author: Field G. Van Zee
+Date:   Wed May 1 19:13:00 2019 -0500
+
+    Define _POSIX_C_SOURCE in bli_system.h.
+
+    Details:
+    - Added
+        #ifndef _POSIX_C_SOURCE
+        #define _POSIX_C_SOURCE 200809L
+        #endif
+      to bli_system.h so that an application that uses BLIS
+      (specifically, an application that #includes blis.h) does not need
+      to remember to #define the macro itself (either on the command
+      line or in the code that includes blis.h) in order to activate
+      things like pthreads. Thanks to Christos Psarras for reporting
+      this issue and suggesting this fix.
+    - Commented out #include in bli_system.h, since I don't think this
+      header is used/needed anymore.
+    - Comment update to function macro for bli_?normiv_unb_var1() in
+      frame/util/bli_util_unb_var1.c.
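Because the guard above uses #ifndef, an application that needs a different feature-test level keeps control; a minimal sketch:

    /* The application's value wins: blis.h only defines
       _POSIX_C_SOURCE if it is not already defined. */
    #define _POSIX_C_SOURCE 200112L
    #include "blis.h"
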
+
+commit 3df84f1b5d5e1146bb01bfc466ac20c60a9cc859
+Author: Field G. Van Zee
+Date:   Sat Apr 27 21:27:32 2019 -0500
+
+    Minor bugfixes in sup dgemm implementation.
+
+    Details:
+    - Fixed an obscure bug in the bli_dgemmsup_rv_haswell_asm_5x8n()
+      kernel that only affected the beta == 0, column-storage output
+      case. Thanks to the BLAS test drivers for catching this bug.
+    - Previously, bli_gemmsup_ref_var1n() and _var2m() were returning if
+      k = 0, when the correct action would be to scale by beta (and then
+      return). Thanks to the BLAS test drivers for catching this bug.
+    - Changed the sup threshold behavior such that the sup
+      implementation only kicks in if a matrix dimension is strictly
+      less than (rather than less than or equal to) the threshold in
+      question.
+    - Initialize all thresholds to zero (instead of 10) by default in
+      ref_kernels/bli_cntx_ref.c. This, combined with the above change
+      to threshold testing, means that calls to BLIS or BLAS with one or
+      more matrix dimensions of zero will no longer trigger the sup
+      implementation. (See the sketch following this entry.)
+    - Added disabled debugging output to frame/3/bli_l3_sup.c (for
+      future use, perhaps).
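Taken together, the two threshold changes above imply a dimension test of roughly the following shape — an illustrative sketch, not the actual BLIS source:

    #include <stdbool.h>
    typedef long dim_t;   /* stand-in for BLIS's integer type */

    /* Hypothetical helper; mt/nt/kt are the per-dimension thresholds. */
    static bool consider_sup( dim_t m, dim_t n, dim_t k,
                              dim_t mt, dim_t nt, dim_t kt )
    {
        /* Strict '<' plus zero-initialized thresholds: neither a
           zero-sized dimension nor (when mt = nt = kt = 0) any problem
           at all can activate the sup path. */
        return ( m < mt ) || ( n < nt ) || ( k < kt );
    }
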
+
+commit ecbdd1c42dcebfecd729fe351e6bb0076aba7d81
+Author: Field G. Van Zee
+Date:   Sat Apr 27 19:38:11 2019 -0500
+
+    Ceased use of BLIS_ENABLE_SUP_MR/NR_EXT macros.
+
+    Details:
+    - Removed already limited use of the BLIS_ENABLE_SUP_MR_EXT and
+      BLIS_ENABLE_SUP_NR_EXT macros in bli_gemmsup_ref_var1n() and
+      bli_gemmsup_ref_var2m(). Their purpose was merely to avoid a long
+      conditional that would determine whether to allow the last
+      iteration to be merged with the second-to-last iteration.
+      Functionally, the macros were not needed, and they ended up
+      causing problems when building configuration families such as
+      intel64 and x86_64.
+
+commit aa8a6bec3036a41e1bff2034f8ef6766a704ec49
+Author: Field G. Van Zee
+Date:   Sat Apr 27 18:53:33 2019 -0500
+
+    Fixed typo in --disable-sup-handling macro guard.
+
+    Details:
+    - Fixed an incorrectly-named macro guard that is intended to allow
+      disabling of the sup framework via the configure option
+      --disable-sup-handling. In this case, the preprocessor macro,
+      BLIS_DISABLE_SUP_HANDLING, still bore its name from an older,
+      uncommitted version of the code (BLIS_DISABLE_SM_HANDLING).
+
+commit b9c9f03502c78a63cfcc21654b06e9089e2a3822
+Author: Field G. Van Zee
+Date:   Sat Apr 27 18:44:50 2019 -0500
+
+    Implemented gemm on skinny/unpacked matrices.
+
+    Details:
+    - Implemented a new sub-framework within BLIS to support the
+      management of code and kernels that specifically target matrix
+      problems for which at least one dimension is deemed to be small,
+      which can result in long and skinny matrix operands that are
+      ill-suited for the conventional level-3 implementations in BLIS.
+      The new framework tackles the problem in two ways. First, the
+      stripped-down algorithmic loops forgo the packing that is famously
+      performed in the classic code path. That is, the computation is
+      performed by a new family of kernels tailored specifically for
+      operating on the source matrices as-is (unpacked). Second, these
+      new kernels will typically (and in the case of haswell and zen, do
+      in fact) include separate assembly sub-kernels for handling edge
+      cases, which helps smooth performance when performing problems
+      whose m and n dimensions are not natural multiples of the register
+      blocksizes. In a reference to the sub-framework's purpose of
+      supporting skinny/unpacked level-3 operations, the "sup" operation
+      suffix (e.g. gemmsup) is typically used to denote a separate
+      namespace for related code and kernels. NOTE: Since the sup
+      framework does not perform any packing, it targets row- and
+      column-stored matrices A, B, and C. For now, if any matrix has
+      non-unit strides in both dimensions, the problem is computed by
+      the conventional implementation.
+    - Implemented the default sup handler as a front-end to two
+      variants. bli_gemmsup_ref_var2() provides a block-panel variant
+      (in which the 2nd loop around the microkernel iterates over n and
+      the 1st loop iterates over m), while bli_gemmsup_ref_var1()
+      provides a panel-block variant (2nd loop over m and 1st loop over
+      n). However, these variants are not used by default and are
+      provided for reference only. Instead, the default sup handler
+      calls _var2m() and _var1n(), which are similar to _var2() and
+      _var1(), respectively, except that they defer to the sup kernel
+      itself to iterate over the m and n dimension, respectively. In
+      other words, these variants rely not on microkernels, but on
+      so-called "millikernels" that iterate along m and k, or n and k.
+      The benefit of using millikernels is a reduction of function call
+      and related (local integer typecast) overhead as well as the
+      ability for the kernel to know which micropanel (A or B) will
+      change during the next iteration of the 1st loop, which allows it
+      to focus its prefetching on that micropanel. (In _var2m()'s
+      millikernel, the upanel of A changes while the same upanel of B is
+      reused. In _var1n()'s, the upanel of B changes while the upanel of
+      A is reused. See the sketch following this entry.)
+    - Added a new configure option, --[en|dis]able-sup-handling, which
+      is enabled by default. However, the default thresholds at which
+      the default sup handler is activated are set to zero for each of
+      the m, n, and k dimensions, which effectively disables the
+      implementation. (The default sup handler only accepts the problem
+      if at least one dimension is smaller than or equal to its
+      corresponding threshold. If all dimensions are larger than their
+      thresholds, the problem is rejected by the sup front-end and
+      control is passed back to the conventional implementation, which
+      proceeds normally.)
+    - Added support to the cntx_t structure to track new fields related
+      to the sup framework, most notably:
+      - sup thresholds: the thresholds at which the sup handler is
+        called.
+      - sup handlers: the address of the function to call to implement
+        the level-3 skinny/unpacked matrix implementation.
+      - sup blocksizes: the register and cache blocksizes used by the
+        sup implementation (which may be the same or different from
+        those used by the conventional packm-based approach).
+      - sup kernels: the kernels that the handler will use in
+        implementing the sup functionality.
+      - sup kernel prefs: the IO preference of the sup kernels, which
+        may differ from that of the conventional gemm microkernels.
+    - Added a bool_t to the rntm_t structure that indicates whether sup
+      handling should be enabled/disabled.
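A rough, hypothetical rendering of the _var2m()-style control flow described above (names and signatures are illustrative; BLIS's actual variants operate on obj_t/cntx_t structures):

    typedef long dim_t;   /* stand-in for BLIS's integer type */

    /* Millikernel: computes C(0:m-1, 0:nr-1) += A * B, iterating over
       both m (the 1st loop) and k internally; it therefore knows the
       upanel of A changes next and can focus prefetching there.
       (Declared only: a sketch, not a linkable implementation.) */
    void gemmsup_milli( dim_t m, dim_t nr, dim_t k,
                        const double* a, const double* b,
                        double* c, dim_t rs_c, dim_t cs_c );

    /* Variant: owns only the 2nd loop (over n); one millikernel call
       per column block replaces one microkernel call per MR x NR tile,
       cutting call and typecast overhead. */
    void gemmsup_var2m( dim_t m, dim_t n, dim_t k, dim_t nr_blk,
                        const double* a, const double* b, dim_t cs_b,
                        double* c, dim_t rs_c, dim_t cs_c )
    {
        for ( dim_t j = 0; j < n; j += nr_blk )
        {
            dim_t nr_cur = ( nr_blk < n - j ? nr_blk : n - j );
            gemmsup_milli( m, nr_cur, k, a, b + j * cs_b,
                           c + j * cs_c, rs_c, cs_c );
        }
    }
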
+      This allows per-call control of whether the sup implementation is
+      used, which is useful for test drivers that wish to switch
+      between the conventional and sup codes without having to link to
+      different copies of BLIS. The corresponding accessor functions
+      for this new bool_t are defined in bli_rntm.h.
+    - Implemented several row-preferential gemmsup kernels in a new
+      directory, kernels/haswell/3/sup. These kernels include two
+      general implementation types--'rd' and 'rv'--for the 6x8 base
+      shape, with two specialized millikernels that embed the 1st loop
+      within the kernel itself.
+    - Added ref_kernels/3/bli_gemmsup_ref.c, which provides reference
+      gemmsup microkernels. NOTE: These microkernels, unlike the
+      current crop of conventional (pack-based) microkernels, do not
+      use constant loop bounds. Additionally, their inner loop iterates
+      over the k dimension.
+    - Defined new typedef enums:
+      - stor3_t: captures the effective storage combination of the
+        level-3 problem. Valid values are BLIS_RRR, BLIS_RRC, BLIS_RCR,
+        etc. A special value of BLIS_XXX is used to denote an arbitrary
+        combination which, in practice, means that at least one of the
+        operands is stored according to general stride.
+      - threshid_t: captures each of the three dimension thresholds.
+    - Changed bli_adjust_strides() in bli_obj.c so that bli_obj_create()
+      can be passed "-1, -1" as a lazy request for row storage. (Note
+      that "0, 0" is still accepted as a lazy request for column
+      storage.)
+    - Added support for various instructions to bli_x86_asm_macros.h,
+      including imul, vhaddps/pd, and other instructions related to
+      integer vectors.
+    - Disabled the older small matrix handling code inserted by AMD in
+      bli_gemm_front.c, since the sup framework introduced in this
+      commit is intended to provide a more generalized solution.
+    - Added test/sup directory, which contains standalone performance
+      test drivers, a Makefile, a runme.sh script, and an 'octave'
+      directory containing scripts compatible with GNU Octave. (They
+      also may work with Matlab, but if not, they are probably close to
+      working.)
+    - Reinterpreted the storage combination string (sc_str) in the
+      various level-3 testsuite modules (e.g. src/test_gemm.c) so that
+      the order of each matrix storage char is "cab" rather than "abc".
+    - Comment updates in level-3 BLAS API wrappers in frame/compat.
+
+commit 0d549ceda822833bec192bbf80633599620c15d9
+Author: Isuru Fernando
+Date:   Sat Apr 27 22:56:02 2019 +0000
+
+    make unix friendly archives on appveyor (#310)
+
+commit 945928c650051c04d6900c7f4e9e29cd0e5b299f
+Merge: 663f6629 74e513eb
+Author: Field G. Van Zee
+Date:   Wed Apr 17 15:58:56 2019 -0500
+
+    Merge branch 'amd' of github.com:flame/blis into amd
+
+commit 74e513eb6a6787a925d43cd1500277d54d86ab8f (origin/dev)
+Author: Field G. Van Zee
+Date:   Wed Apr 17 13:34:44 2019 -0500
+
+    Support row storage in Eigen gemm test/3 driver.
+
+    Details:
+    - Added preprocessor branches to test/3/test_gemm.c to explicitly
+      support row-stored matrices. Column-stored matrices are also
+      still supported (and remain the default for now). (This is mainly
+      residual work leftover from initial integration of Eigen into the
+      test drivers, so if we ever want to test Eigen with row-stored
+      matrices, the code will be ready to use, even if it is not yet
+      integrated into the Makefile in test/3.)
+
+commit b5d457fae9bd75c4ca67f7bc7214e527aa248127
+Author: Field G. Van Zee
+Date:   Tue Apr 16 12:50:01 2019 -0500
+
+    Applied forgotten variable rename from 89a70cc.
+
+    Details:
+    - Somehow the variable name change (root_file_name ->
+      root_inputname) in flatten-headers.py mentioned in the commit log
+      entry for 89a70cc didn't make it into the actual commit. This
+      commit applies that change.
+
+commit 89a70cccf869333147eb2559cdfa5a23dc915824
+Author: Field G. Van Zee
+Date:   Thu Apr 11 18:33:08 2019 -0500
+
+    GNU-like handling of installation prefix et al.
+
+    Details:
+    - Changed the default installation prefix from $HOME/lib to
+      /usr/local.
+    - Modified the way configure internally handles the prefix, libdir,
+      includedir, and sharedir (and also added an --exec-prefix
+      option). The defaults for these variables are set as follows:
+        prefix:      /usr/local
+        exec_prefix: ${prefix}
+        libdir:      ${exec_prefix}/lib
+        includedir:  ${prefix}/include
+        sharedir:    ${prefix}/share
+      The key change, aside from the addition of exec_prefix and its
+      use to define the default for libdir, is that the variables are
+      substituted into config.mk with quoting that delays evaluation,
+      meaning the substituted values may contain unevaluated references
+      to other variables (namely, ${prefix} and ${exec_prefix}). This
+      more closely follows GNU conventions, including those used by GNU
+      autoconf, and also allows make to override any one of the
+      variables *after* configure has already been run (e.g. during
+      'make install').
+    - Updates to build/config.mk.in pursuant to above changes.
+    - Updates to output of 'configure --help' pursuant to above changes.
+    - Updated docs/BuildSystem.md to reflect the new default
+      installation prefix, as well as mention EXECPREFIX and SHAREDIR.
+    - Changed the definitions of the UNINSTALL_OLD_* variables in the
+      top-level Makefile to use $(wildcard ...) instead of 'find'. This
+      was motivated by the new way of handling prefix and friends,
+      which leads to the 'find' command being run on /usr/local (by
+      default), which can take a while and almost never yields any
+      benefit (since the user will very rarely use the uninstall-old
+      targets).
+    - Removed periods from the end of descriptive output statements
+      (i.e., non-verbose output) since those statements often end with
+      file or directory paths, which get confusing to read when
+      punctuated by a period.
+    - Trivial change to 'make showconfig' output.
+    - Removed my name from 'configure --help'. (Many have contributed
+      to it over the years.)
+    - In configure script, changed the default state of threading_model
+      variable from 'no' to 'off' to match that of debug_type, where
+      there are similarly more than two valid states. ('no' is still
+      accepted if given via the --enable-threading= option, though it
+      will be standardized to 'off' prior to config.mk being written
+      out.)
+    - Minor variable name change in flatten-headers.py that was
+      intended for 32812ff.
+    - CREDITS file update.
+
+commit 32812ff5aba05d34c421fe1024a61f3e2d5e7052
+Author: Field G. Van Zee
+Date:   Tue Apr 9 12:20:19 2019 -0500
+
+    Minor bugfix to flatten-headers.py.
+
+    Details:
+    - Fixed a minor bug in flatten-headers.py whereby the script, upon
+      encountering a #include directive for the root header file, would
+      erroneously recurse and inline the contents of that root header.
+      The script has been modified to avoid recursion into any headers
+      that share the same name as the root-level header that was passed
+      into the script. (Note: this bug didn't actually manifest in
+      BLIS, so it's merely a precaution for usage of flatten-headers.py
+      in other contexts.)
+
+commit bec90e0b6aeb3c9b19589c2b700fda2d66f6ccdf
+Author: Field G. Van Zee
+Date:   Tue Apr 2 17:45:13 2019 -0500
+
+    Minor update to docs/HardwareSupport.md document.
+
+    Details:
+    - Added more details and clarifying language to implications of 1m
+      and the recycling of microkernels between microarchitectures.
+
+commit 89cd650e7be01b59aefaa85885a3ea78970351e4
+Author: Field G. Van Zee
+Date:   Tue Apr 2 17:23:55 2019 -0500
+
+    Use void_fp for function pointers instead of void*.
+
+    Change void*-typed function pointers to void_fp.
+    - Updated all instances of void* variables that store function
+      pointers to variables of a new type, void_fp. Originally, I
+      wanted to define the type of void_fp as "void (*void_fp)( void )"
+      --that is, a pointer to a function with no return value and no
+      arguments. However, once I did this, I realized that gcc
+      complains with incompatible pointer type
+      (-Wincompatible-pointer-types) warnings every time any such
+      pointer is assigned to its final, type-accurate function pointer
+      type. That is, gcc will silently typecast a void* to another
+      defined function pointer type (e.g. dscalv_ker_ft) during an
+      assignment from the former to the latter, but the same statement
+      will trigger a warning when typecasting from a void_fp type. I
+      suspect an explicit typecast is needed in order to avoid the
+      warning, which I'm not willing to insert at this time.
+    - Added a typedef to bli_type_defs.h defining void_fp as void*,
+      along with a commented-out version of the aborted definition
+      described above. (Note that POSIX requires that void* and
+      function pointers be interchangeable; it is the C standard that
+      does not provide this guarantee.)
+    - Comment updates to various _oapi.c files.
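The warning behavior described above can be reproduced in a few lines; a sketch (the dscalv_ker_ft signature here is illustrative, not BLIS's actual kernel type):

    typedef void* void_fp;              /* definition BLIS adopted      */
    typedef void (*gen_fp)( void );     /* the aborted alternative      */
    typedef void (*dscalv_ker_ft)( long n, double* alpha, double* x );

    void dscalv_kernel( long n, double* alpha, double* x ) {}

    void example( void )
    {
        void_fp p1 = ( void_fp )dscalv_kernel;
        gen_fp  p2 = ( gen_fp )dscalv_kernel;

        dscalv_ker_ft k1 = p1; /* gcc: converts silently from void*     */
        dscalv_ker_ft k2 = p2; /* gcc: -Wincompatible-pointer-types     */
        (void)k1; (void)k2;
    }
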
+
+commit ffce3d632b284eb52474036096815ec38ca8dd5f
+Author: Field G. Van Zee
+Date:   Tue Apr 2 14:40:50 2019 -0500
+
+    Renamed armv8a gemm kernel filename.
+
+    Details:
+    - Renamed
+        kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c
+      to
+        kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c.
+      This follows the naming convention used by other kernel sets,
+      most notably haswell.
+
+commit 77867478af02144544b4e7b6df5d54d874f3f93b
+Author: Isuru Fernando
+Date:   Tue Apr 2 13:33:11 2019 -0500
+
+    Use pthreads on MinGW and Cygwin (#307)
+
+commit 7bc75882f02ce3470a357950878492e87e688cec
+Author: Field G. Van Zee
+Date:   Thu Mar 28 17:40:50 2019 -0500
+
+    Updated Eigen results in docs/graphs with 3.3.90.
+
+    Details:
+    - Updated the level-3 performance graphs in docs/graphs with new
+      Eigen results, this time using a development version cloned from
+      their git mirror on March 27, 2019 (version 3.3.90). Performance
+      is improved over 3.3.7, though still noticeably short of BLIS/MKL
+      in most cases.
+    - Very minor updates to docs/Performance.md and matlab scripts in
+      test/3/matlab.
+
+commit 20ea7a1217d3833db89a96158c42da2d6e968ed8
+Author: Field G. Van Zee
+Date:   Wed Mar 27 18:09:17 2019 -0500
+
+    Minor text updates (Eigen) to docs/Performance.md.
+
+    Details:
+    - Added/updated a few more details, mostly regarding Eigen.
+
+commit bfb7e1bc6af468e4ff22f7e27151ea400dcd318a
+Merge: 044df950 2c85e1dd
+Author: Field G. Van Zee
+Date:   Wed Mar 27 17:58:19 2019 -0500
+
+    Merge branch 'dev'
+
+commit 2c85e1dd9d5d84da7228ea4ae6deec56a89b3a8f (dev)
+Author: Field G. Van Zee
+Date:   Wed Mar 27 16:29:51 2019 -0500
+
+    Added Eigen results to performance graphs.
+
+    Details:
+    - Updated the Haswell, SkylakeX, and Epyc performance graphs in
+      docs/graphs to report on Eigen implementations, where applicable.
+      Specifically, Eigen implements all level-3 operations
+      sequentially; however, of those operations, it provides a
+      multithreaded implementation only for gemm. Thus, mt results for
+      symm/hemm, syrk/herk, trmm, and trsm are omitted. Thanks to
+      Sameer Agarwal for his help configuring and using Eigen.
+    - Updated docs/Performance.md to note the new implementation tested.
+    - CREDITS file update.
+
+commit bfac7e385f8061f2e6591de208b0acf852f04580
+Author: Field G. Van Zee
+Date:   Wed Mar 27 16:04:48 2019 -0500
+
+    Added ability to plot with Eigen in test/3/matlab.
+
+    Details:
+    - Updated matlab scripts in test/3/matlab to optionally plot/display
+      Eigen performance curves. Whether Eigen is plotted is determined
+      by a new boolean function parameter, with_eigen.
+    - Updated runme.m scratchpad to reflect the latest invocations of
+      the plot_panel_4x5() function (with Eigen plotting enabled).
+
+commit 67535317b9411c90de7fa4cb5b0fdb8f61fdcd79
+Author: Field G. Van Zee
+Date:   Wed Mar 27 13:32:18 2019 -0500
+
+    Fixed mislabeled eigen output from test/3 drivers.
+
+    Details:
+    - Fixed the Makefile in test/3 so that it no longer incorrectly
+      labels the matlab output variables from Eigen-linked hemm, herk,
+      trmm, and trsm drivers as "vendor". (The gemm drivers were
+      already correctly outputting matlab variables containing the
+      "eigen" label.)
+
+commit 044df9506f823643c0cdd53e81ad3c27a9f9d4ff
+Author: Isuru Fernando
+Date:   Wed Mar 27 12:39:31 2019 -0500
+
+    Test with shared on windows (#306)
+
+    Export macros can't support both shared and static at the same
+    time. When blis is built with both shared and static, headers
+    assume that shared is used at link time and dllimports the symbols
+    with the __imp_ prefix.
+
+    To use the headers with static libraries, a user can give
+    -DBLIS_EXPORT= to import the symbol without the __imp_ prefix.
+
+commit 5e6b160c8a85e5e23bab0f64958a8acf4918a4ed
+Author: Field G. Van Zee
+Date:   Tue Mar 26 19:10:59 2019 -0500
+
+    Link to Eigen BLAS for non-gemm drivers in test/3.
+
+    Details:
+    - Adjusted test/3/Makefile so that the test drivers are linked
+      against Eigen's BLAS library for hemm, herk, trmm, and trsm. We
+      have to do this since Eigen's headers don't define
+      implementations of the standard BLAS APIs.
+    - Simplified #included headers in hemm, herk, trmm, and trsm source
+      driver files, since nothing specific to Eigen is needed at
+      compile-time for those operations.
+
+commit e593221383aae19dfdc3f30539de80ed05cfec7f
+Merge: 92fb9c87 c208b9dc
+Author: Field G. Van Zee
+Date:   Tue Mar 26 15:51:45 2019 -0500
+
+    Merge branch 'master' into dev
+
+commit 92fb9c87bf88b9f9c401eeecd9aa9c3521bc2adb
+Author: Field G. Van Zee
+Date:   Tue Mar 26 15:43:23 2019 -0500
+
+    Add more support for Eigen to drivers in test/3.
+
+    Details:
+    - Use compile-time implementations of Eigen in test_gemm.c via new
+      EIGEN cpp macro, defined on command line. (Linking to Eigen's
+      BLAS library is not necessary.) However, as of Eigen 3.3.7, Eigen
+      only parallelizes the gemm operation and not hemm, herk, trmm,
+      trsm, or any other level-3 operation.
+    - Fixed a bug in trmm and trsm drivers whereby the wrong function
+      (bli_does_trans()) was being called to determine whether the
+      object for matrix A should be created for a left- or right-side
+      case. This was corrected by changing the function to
+      bli_is_left(), as is done in the hemm driver. (See the sketch
+      following this entry.)
+    - Added support for running Eigen test drivers from runme.sh.
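The trmm/trsm driver fix above hinges on the dimensions of the triangular matrix following the side parameter, not the transposition. A minimal sketch using the object API calls named in the entry:

    #include "blis.h"

    void create_triangular_a( side_t side, num_t dt,
                              dim_t m, dim_t n, obj_t* a )
    {
        /* A is m x m when applied from the left, n x n from the right;
           "0, 0" is a lazy request for column storage. */
        dim_t dim_a = bli_is_left( side ) ? m : n;
        bli_obj_create( dt, dim_a, dim_a, 0, 0, a );
    }
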
+
+commit c208b9dc46852c877197d53b6dd913a046b6ebb6
+Author: Isuru Fernando
+Date:   Mon Mar 25 13:03:44 2019 -0500
+
+    Fix clang version detection (#305)
+
+    clang -dumpversion gives 4.2.1 for all clang versions, as clang was
+    originally compatible with gcc 4.2.1.
+
+    Apple clang version and clang version are two different things, and
+    the real clang version cannot be deduced from the apple clang
+    version programmatically. Rely on Wikipedia to map apple clang to
+    clang version.
+
+    Also fixes assembly detection with clang.
+
+    clang 3.8 can't build knl as it doesn't recognize zmm0.
+
+commit feefcab4427a75b0b55af215486b85abcda314f7
+Author: Field G. Van Zee
+Date:   Thu Mar 21 18:11:20 2019 -0500
+
+    Allow disabling of BLAS prototypes at compile-time.
+
+    Details:
+    - Modified bli_blas.h so that:
+      - By default, if the BLAS layer is enabled at configure-time,
+        BLAS prototypes are also enabled within blis.h;
+      - But if the user #defines BLIS_DISABLE_BLAS_DEFS prior to
+        including blis.h, BLAS prototypes are skipped over entirely so
+        that, for example, the application or some other header pulled
+        in by the application may prototype the BLAS functions without
+        causing any duplication.
+    - Updated docs/BuildSystem.md to document the feature above, and
+      related text.
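From the application's side, the opt-out described above is a one-line affair:

    /* Application supplies its own BLAS prototypes elsewhere. */
    #define BLIS_DISABLE_BLAS_DEFS
    #include "blis.h"   /* BLAS prototypes are now skipped entirely */
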
+
+commit 288843b06d91e1b4fade337959aef773090bd1c9
+Author: Field G. Van Zee
+Date:   Wed Mar 20 17:52:23 2019 -0500
+
+    Added Eigen support to test/3 Makefile, runme.sh.
+
+    Details:
+    - Added targets to test/3/Makefile that link against a BLAS library
+      built by Eigen. It appears, however, that Eigen's BLAS library
+      does not support multithreading. (It may be that multithreading
+      is only available when using the native C++ APIs.)
+    - Updated runme.sh with a few Eigen-related tweaks.
+    - Minor tweaks to docs/Performance.md.
+
+commit 153e0be21d9ff413e370511b68d553dd02abada9
+Author: Field G. Van Zee
+Date:   Tue Mar 19 17:53:18 2019 -0500
+
+    More minor tweaks to docs/Performance.md.
+
+    Details:
+    - Defined GFLOPS as billions of floating-point operations per
+      second, and reworded the subsequent sentence about normalization.
+
+commit 05c4e42642cc0c8dbfa94a6c21e975ac30c0517a
+Author: Field G. Van Zee
+Date:   Tue Mar 19 17:07:20 2019 -0500
+
+    CHANGELOG update (0.5.2)
+
+commit 9204cd0cb0cc27790b8b5a2deb0233acd9edeb9b (tag: 0.5.2)
+Author: Field G. Van Zee
+Date:   Tue Mar 19 17:07:18 2019 -0500
+
+    Version file update (0.5.2)
+
+commit 64560cd9248ebf4c02c4a1eeef958e1ca434e510
+Author: Field G. Van Zee
+Date:   Tue Mar 19 17:04:20 2019 -0500
+
+    ReleaseNotes.md update in advance of next version.
+
+    Details:
+    - Updated ReleaseNotes.md in preparation for next version.
+
+commit ab5ad557ea69479d487c9a3cb516f43fa1089863
+Author: Field G. Van Zee
+Date:   Tue Mar 19 16:50:41 2019 -0500
+
+    Very minor tweaks to Performance.md.
+
+commit 03c4a25e1aa8a6c21abbb789baa599ac419c3641
+Author: Field G. Van Zee
+Date:   Tue Mar 19 16:47:15 2019 -0500
+
+    Minor fixes to docs/Performance.md.
+
+    Details:
+    - Fixed some incorrect labels associated with the pdf/png graphs,
+      apparently the result of copy-pasting.
+
+commit fe6dd8b132f39ecb8893d54cd8e75d4bbf6dab83
+Author: Field G. Van Zee
+Date:   Tue Mar 19 16:30:23 2019 -0500
+
+    Fixed broken section links in docs/Performance.md.
+
+    Details:
+    - Fixed a few broken section links in the Contents section.
+
+commit 913cf97653f5f9a40aa89a5b79e2b0a8882dd509
+Author: Field G. Van Zee
+Date:   Tue Mar 19 16:15:24 2019 -0500
+
+    Added docs/Performance.md and docs/graphs subdir.
+
+    Details:
+    - Added a new markdown document, docs/Performance.md, which reports
+      performance of a representative set of level-3 operations across
+      a variety of hardware architectures, comparing BLIS to OpenBLAS
+      and a vendor library (MKL on Intel/AMD, ARMPL on ARM).
+      Performance graphs, in pdf and png formats, reside in
+      docs/graphs.
+    - Updated README.md to link to new Performance.md document.
+    - Minor updates to CREDITS, docs/Multithreading.md.
+    - Minor updates to matlab scripts in test/3/matlab.
+
+commit 9945ef24fd758396b698b19bb4e23e53b9d95725
+Author: Field G. Van Zee
+Date:   Tue Mar 19 15:28:44 2019 -0500
+
+    Adjusted cache blocksizes for zen subconfig.
+
+    Details:
+    - Adjusted the zen sub-configuration's cache blocksizes for float,
+      scomplex, and dcomplex based on the existing values for double.
+      (The previous values were taken directly from the haswell
+      subconfig, which targets Intel Haswell/Broadwell/Skylake
+      systems.)
+
+commit d202d008d51251609d08d3c278bb6f4ca9caf8e4
+Author: Field G. Van Zee
+Date:   Mon Mar 18 18:18:25 2019 -0500
+
+    Renamed --enable-export-all to --export-shared=[].
+
+    Details:
+    - Replaced the existing --enable-export-all / --disable-export-all
+      configure option with --export-shared=[public|all], with the
+      'public' instance of the latter corresponding to
+      --disable-export-all and the 'all' instance corresponding to
+      --enable-export-all. Nothing else semantically about the option,
+      or its default, has changed.
+
+commit ff78089870f714663026a7136e696603b5259560
+Author: Field G. Van Zee
+Date:   Mon Mar 18 13:22:55 2019 -0500
+
+    Updates to docs/Multithreading.md.
+
+    Details:
+    - Made extra explicit the fact that: (a) multithreading in BLIS is
+      disabled by default; and (b) even with multithreading enabled,
+      the user must specify multithreading at runtime in order to
+      observe parallelism. Thanks to M. Zhou for suggesting these
+      clarifications in #292.
+    - Also made explicit that only the environment variable and global
+      runtime API methods are available when using the BLAS API. If the
+      user wishes to use the local runtime API (specify multithreading
+      on a per-call basis), one of the native BLIS APIs must be used.
+
+commit 663f662932c3f182fefc3c77daa1bf8c3394bb8b
+Merge: 938c05ef 6bfe3812
+Author: Field G. Van Zee
+Date:   Sat Mar 16 16:17:12 2019 -0500
+
+    Merge branch 'amd' of github.com:flame/blis into amd
+
+commit 938c05ef8654e2fc013d39a57f51d91d40cc40fb
+Merge: 4ed39c09 5a5f494e
+Author: Field G. Van Zee
+Date:   Sat Mar 16 16:01:43 2019 -0500
+
+    Merge branch 'amd' of github.com:flame/blis into amd
+
+commit 6bfe3812e29b86c95b828822e4e5473b48891167
+Author: Field G. Van Zee
+Date:   Fri Mar 15 13:57:49 2019 -0500
+
+    Use -fvisibility=[...] with clang on Linux/BSD/OSX.
+
+    Details:
+    - Modified common.mk to use the -fvisibility=[hidden|default]
+      option when compiling with clang on non-Windows platforms (Linux,
+      BSD, OS X, etc.). Thanks to Isuru Fernando for pointing out this
+      option works with clang on these OSes.
+
+commit 809395649c5bbf48778ede4c03c1df705dd49566
+Author: Field G. Van Zee
+Date:   Wed Mar 13 18:21:35 2019 -0500
+
+    Annotated additional symbols for export.
+
+    Details:
+    - Added export annotations to additional function prototypes in
+      order to accommodate the testsuite.
+    - Disabled calling bli_amaxv_check() from within the testsuite's
+      test_amaxv.c.
+
+commit e095926c643fd9c9c2220ebecd749caae0f71d42
+Author: Field G. Van Zee
+Date:   Wed Mar 13 17:35:18 2019 -0500
+
+    Support shared lib export of only public symbols.
+
+    Details:
+    - Introduced a new configure option, --enable-export-all, which
+      will cause all shared library symbols to be exported by default,
+      or, alternatively, --disable-export-all, which will cause all
+      symbols to be hidden by default, with only those symbols that are
+      annotated for visibility, via BLIS_EXPORT_BLIS (and
+      BLIS_EXPORT_BLAS for BLAS symbols), to be exported. The default
+      for this configure option is --disable-export-all. Thanks to
+      Isuru Fernando for consulting on this commit.
+    - Removed BLIS_EXPORT_BLIS annotations from
+      frame/1m/bli_l1m_unb_var1.h (a change intended for 5a5f494).
+    - Relocated BLIS_EXPORT-related cpp logic from bli_config.h.in to
+      frame/include/bli_config_macro_defs.h.
+    - Provided appropriate logic within common.mk to implement variable
+      symbol visibility for gcc, clang, and icc (to the extent that
+      each of these compilers allows).
+    - Relocated the --help text associated with the debug option (-d)
+      in configure slightly further down in the list.
+
+commit 5a5f494e428372c7c27ed1f14802e15a83221e87
+Author: Field G. Van Zee
+Date:   Tue Mar 12 18:45:09 2019 -0500
+
+    Removed export macros from all internal prototypes.
+
+    Details:
+    - After merging PR #303, at Isuru's request, I removed the use of
+      BLIS_EXPORT_BLIS from all function prototypes *except* those that
+      we potentially wish to be exported in shared/dynamic libraries.
+      In other words, I removed the use of BLIS_EXPORT_BLIS from all
+      prototypes of functions that can be considered private or for
+      internal use only. This is likely the last big modification along
+      the path towards implementing the functionality spelled out in
+      issue #248. Thanks again to Isuru Fernando for his initial
+      efforts of sprinkling the export macros throughout BLIS, which
+      made removing them where necessary relatively painless. Also, I'd
+      like to thank Tony Kelman, Nathaniel Smith, Ian Henriksen, Marat
+      Dukhan, and Matthew Brett for participating in the initial
+      discussion in issue #37 that was later summarized and restated in
+      issue #248.
+    - CREDITS file update.
+
+commit 3dc18920b6226026406f1d2a8b2c2b405a2649d5
+Merge: b938c16b 766769ee
+Author: Field G. Van Zee
+Date:   Tue Mar 12 11:20:25 2019 -0500
+
+    Merge branch 'master' into dev
+
+commit 766769eeb944bd28641a6f72c49a734da20da755
+Author: Isuru Fernando
+Date:   Mon Mar 11 19:05:32 2019 -0500
+
+    Export functions without def file (#303)
+
+    * Revert "restore bli_extern_defs exporting for now"
+
+      This reverts commit 09fb07c350b2acee17645e8e9e1b8d829c73dca8.
+
+    * Remove symbols not intended to be public
+
+    * No need of def file anymore
+
+    * Fix whitespace
+
+    * No need of configure option
+
+    * Remove export macro from definitions
+
+    * Remove blas export macro from definitions
+
+commit 4ed39c0971c7917e2675cf5449f563b1f4751ccc
+Merge: 540ec1b4 b938c16b
+Author: Field G. Van Zee
+Date:   Fri Mar 8 11:56:58 2019 -0600
+
+    Merge branch 'amd' of github.com:flame/blis into amd
+
+commit b938c16b0c9e839335ac2c14944b82890143d02f
+Author: Field G. Van Zee
+Date:   Thu Mar 7 16:40:39 2019 -0600
+
+    Renamed test/3m4m to test/3.
+
+    Details:
+    - Renamed '3m4m' directory to '3', which captures the directory
+      nicely since it builds test drivers to test level-3 operations.
+    - These test drivers ceased to be used to test the 3m and 4m (or
+      even 1m) induced methods long ago, hence the name change.
+
+commit ab89a40582ec7acf802e59b0763bed099a02edd8
+Author: Field G. Van Zee
+Date:   Thu Mar 7 16:26:12 2019 -0600
+
+    More minor updates and edits to test/3m4m.
+
+    Details:
+    - Further updates to matlab scripts, mostly for compatibility with
+      GNU Octave.
+    - More tweaks to runme.sh.
+    - Updates to runme.m that allow copy-paste into matlab interactive
+      session to generate graphs.
+
+commit f0e70dfbf3fee4c4e382c2c4e87c25454cbc79a1
+Author: Field G. Van Zee
+Date:   Thu Mar 7 01:04:05 2019 +0000
+
+    Very minor updates to test/3m4m for ul252.
+
+    Details:
+    - Very minor updates to the newly revamped test/3m4m drivers when
+      used on a Xeon Platinum (SkylakeX).
+
+commit 9f1dbe572b1fd5e7dd30d5649bdf59259ad770d5
+Author: Field G. Van Zee
+Date:   Tue Mar 5 17:47:55 2019 -0600
+
+    Overhauled test/3m4m Makefile and scripts.
+
+    Details:
+    - Rewrote much of Makefile to generate executables for single- and
+      dual-socket multithreading as well as single-threaded. Each of
+      the three can also use a different problem size range/increment,
+      as is often appropriate when doubling/halving the number of
+      threads.
+    - Rewrote runme.sh script to flexibly execute as many threading
+      parameter scenarios as are given in the input parameter string
+      (currently set within the script itself). The string also encodes
+      the maximum problem size for each threading scenario, which is
+      used to identify the executable to run. Also improved the
+      "progress" output of the script to reduce redundant info and
+      improve readability in terminals that are not especially wide.
+    - Minor updates to test_*.c source files.
+    - Updated matlab scripts according to changes made to the Makefile,
+      test drivers, and runme.sh script, and renamed 'plot_all.m' to
+      'runme.m'.
+
+commit 3bdab823fa93342895bf45d812439324a37db77c
+Merge: 70f12f20 e2a02ebd
+Author: Field G. Van Zee
+Date:   Thu Feb 28 14:07:24 2019 -0600
+
+    Merge branch 'master' into dev
+
+commit e2a02ebd005503c63138d48a2b7d18978ee29205
+Author: Field G. Van Zee
+Date:   Thu Feb 28 13:58:59 2019 -0600
+
+    Updates (from ls5) to test/3m4m/runme.sh.
+
+    Details:
+    - Lonestar5-specific updates to runme.sh.
+
+commit f0dcc8944fa379d53770f5cae5d670140918f00c
+Author: Isuru Fernando
+Date:   Wed Feb 27 17:27:23 2019 -0600
+
+    Add symbol export macro for all functions (#302)
+
+    * initial export of blis functions
+
+    * Regenerate def file for master
+
+    * restore bli_extern_defs exporting for now
+
+commit 540ec1b479712d5e1da637a718927249c15d867f
+Author: Field G. Van Zee
+Date:   Sun Feb 24 19:09:10 2019 -0600
+
+    Updated level-3 BLAS to call object API directly.
+
+    Details:
+    - Updated the BLAS compatibility layer for level-3 operations so
+      that the corresponding BLIS object API is called directly rather
+      than first calling the typed BLIS API. The previous code based on
+      the typed BLIS API calls is still available in a deactivated cpp
+      macro branch, which may be re-activated by #defining
+      BLIS_BLAS3_CALLS_TAPI. (This does not yet correspond to a
+      configure option. If it seems like people might want to toggle
+      this behavior more regularly, a configure option can be added in
+      the future.)
+    - Updated the BLIS typed API to statically "pre-initialize" objects
+      via new initializer macros. Initialization is then finished via
+      calls to static functions bli_obj_init_finish_1x1() and
+      bli_obj_init_finish(), which are similar to the previously-called
+      functions, bli_obj_create_1x1_with_attached_buffer() and
+      bli_obj_create_with_attached_buffer(), respectively. (The BLAS
+      compatibility layer updates mentioned above employ this new
+      technique as well. See the sketch following this entry.)
+    - Transformed certain routines in bli_param_map.c--specifically,
+      the ones that convert netlib-style parameters to BLIS
+      equivalents--into static functions, now in bli_param_map.h. (The
+      remaining three classes of conversion routines were left
+      unchanged.)
+    - Added the aforementioned pre-initializer macros to
+      bli_type_defs.h.
+    - Relocated bli_obj_init_const() and bli_obj_init_constdata() from
+      bli_obj_macro_defs.h to bli_type_defs.h.
+    - Added a few macros to bli_param_macro_defs.h for testing domains
+      for real/complexness and precisions for single/double-ness.
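A sketch of the pre-initialize/finish pattern described above, assuming an initializer macro along the lines of BLIS_OBJECT_INITIALIZER_1X1 (the macro name and the exact argument list of the finish function are assumptions; the commit does not spell them out):

    #include "blis.h"

    void scalar_example( double* alpha )
    {
        /* Compile-time part: no function call, no malloc(). */
        obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;

        /* Runtime part: attach the datatype and buffer. */
        bli_obj_init_finish_1x1( BLIS_DOUBLE, alpha, &alphao );

        /* alphao may now be passed to the object API, e.g. bli_gemm(). */
    }
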
+
+commit 8e023bc914e9b4ac1f13614feb360b105fbe44d2
+Author: Field G. Van Zee
+Date:   Fri Feb 22 16:55:30 2019 -0600
+
+    Updates to 3m4m/matlab scripts.
+
+    Details:
+    - Minor updates to matlab graph-generating scripts.
+    - Added a plot_all.m script that is more of a scratchpad for
+      copying and pasting function invocations into matlab to generate
+      plots that are presently of interest to us.
+
+commit 70f12f209bc1901b5205902503707134cf2991a0
+Author: Field G. Van Zee
+Date:   Wed Feb 20 16:10:10 2019 -0600
+
+    Changed unsafe-loop to unsafe-math optimizations.
+
+    Details:
+    - Changed -funsafe-loop-optimizations (re-)introduced in 7690855
+      for make_defs.mk files' CRVECFLAGS to
+      -funsafe-math-optimizations (to account for a miscommunication in
+      issue #300). Thanks to Dave Love for this suggestion and Jeff
+      Hammond for his feedback on the topic.
+
+commit 7690855c5106a56e5b341a350f8db1c78caacd89
+Author: Field G. Van Zee
+Date:   Mon Feb 18 19:16:01 2019 -0600
+
+    Restored -funsafe-loop-optimizations to subconfigs.
+
+    Details:
+    - Restored use of -funsafe-loop-optimizations in the definitions of
+      CRVECFLAGS (when using gcc), but only for sub-configurations (and
+      not configuration families such as amd64, intel64, and x86_64).
+      This more or less reverts 5190d05 and 6cf1550.
+
+commit 44994d1490897b08cde52a615a2e37ddae8b2061
+Author: Field G. Van Zee
+Date:   Mon Feb 18 18:35:30 2019 -0600
+
+    Disable TBM, XOP, LWP instructions in AMD configs.
+
+    Details:
+    - Added -mno-tbm -mno-xop -mno-lwp to CKVECFLAGS in bulldozer,
+      piledriver, steamroller, and excavator configurations to
+      explicitly disable AMD's bulldozer-era TBM, XOP, and LWP
+      instruction sets in an attempt to fix the invalid instruction
+      error that has plagued Travis CI builds since 6a014a3. Thanks to
+      Devin Matthews for pointing out that the offending instruction
+      was part of TBM (issue #300).
+    - Restored -O3 to piledriver configuration's COPTFLAGS.
+
+commit 1e5b530744c1906140d47f43c5cad235eaa619cf
+Author: Field G. Van Zee
+Date:   Mon Feb 18 18:04:38 2019 -0600
+
+    Reverted piledriver COPTFLAGS from -O3 to -O2.
+
+    Details:
+    - Debugging continues; changing COPTFLAGS for piledriver subconfig
+      from -O3 to -O2, its original value prior to 6a014a3.
+
+commit 6cf155049168652c512aefdd16d74e7ff39b98df
+Author: Field G. Van Zee
+Date:   Mon Feb 18 17:29:51 2019 -0600
+
+    Removed -funsafe-loop-optimizations from all configs.
+
+    Details:
+    - Error persists. Removed -funsafe-loop-optimizations from all
+      remaining sub-configurations.
+
+commit 5190d05a27c5fa4c7942e20094f76eb9a9785c3e
+Author: Field G. Van Zee
+Date:   Mon Feb 18 17:07:35 2019 -0600
+
+    Removed -funsafe-loop-optimizations from piledriver.
+
+    Details:
+    - Error persists; continuing debugging from bf0fb78c by removing
+      -funsafe-loop-optimizations from piledriver configuration.
+
+commit bf0fb78c5e575372060d22f5ceeb5b332e8978ec
+Author: Field G. Van Zee
+Date:   Mon Feb 18 16:51:38 2019 -0600
+
+    Removed -funsafe-loop-optimizations from families.
+
+    Details:
+    - Removed -funsafe-loop-optimizations from the configuration
+      families affected by 6a014a3, specifically: intel64, amd64, and
+      x86_64. This is part of an attempt to debug why the sde, as
+      executed by Travis CI, is crashing via the following error:
+
+        TID 0 SDE-ERROR: Executed instruction not valid for specified
+        chip (ICELAKE): 0x9172a5: bextr_xop rax, rcx, 0x103
+
+commit 6a014a3377a2e829dbc294b814ca257a2bfcb763
+Author: Field G. Van Zee
+Date:   Mon Feb 18 14:52:29 2019 -0600
+
+    Standardized optimization flags in make_defs.mk.
+
+    Details:
+    - Per Dave Love's recommendation in issue #300, this commit defines
+        COPTFLAGS  := -O3
+      and
+        CRVECFLAGS := $(CKVECFLAGS) -funsafe-loop-optimizations
+      in the make_defs.mk for all Intel- and AMD-based configurations.
+
+commit 565fa3853b381051ac92cff764625909d105644d
+Author: Field G. Van Zee
+Date:   Mon Feb 18 11:43:58 2019 -0600
+
+    Redirect trsm pc, ir parallelism to ic, jr loops.
+
+    Details:
+    - trsm parallelization was temporarily simplified in 075143d to
+      entirely ignore any parallelism specified via the pc or ir loops.
+      Now, any parallelism specified to the pc loop will be redirected
+      to the ic loop, and any parallelism specified to the ir loop will
+      be redirected to the jr loop. (Note that because of
+      inter-iteration dependencies, trsm cannot parallelize the ir
+      loop. Parallelism via the pc loop is at least somewhat feasible
+      in theory, but it would require tracking dependencies between
+      blocks--something for which BLIS currently lacks the necessary
+      supporting infrastructure.)
+
+commit a023c643f25222593f4c98c2166212561d030621
+Author: Field G. Van Zee
+Date:   Thu Feb 14 20:18:55 2019 -0600
+
+    Regenerated symbols in build/libblis-symbols.def.
+
+    Details:
+    - Reran ./build/regen-symbols.sh after running
+      'configure --enable-cblas auto'.
+
+commit 075143dfd92194647da9022c1a58511b20fc11f3
+Author: Field G. Van Zee
+Date:   Thu Feb 14 18:52:45 2019 -0600
+
+    Added support for IC loop parallelism to trsm.
+
+    Details:
+    - Parallelism within the IC loop (3rd loop around the microkernel)
+      is now supported within the trsm operation. This is done via a
+      new branch on each of the control and thread trees, which guide
+      execution of a new trsm-only subproblem from within
+      bli_trsm_blk_var1(). This trsm subproblem corresponds to the
+      macrokernel computation on only the block of A that contains the
+      diagonal (labeled as A11 in algorithms with FLAME-like
+      partitioning), and the corresponding row panel of C. During the
+      trsm subproblem, all threads within the JC communicator
+      participate and parallelize along the JR loop, including any
+      parallelism that was specified for the IC loop. (IR loop
+      parallelism is not supported for trsm due to inter-iteration
+      dependencies.) After this trsm subproblem is complete, a barrier
+      synchronizes all participating threads and then they proceed to
+      apply the prescribed BLIS_IC_NT (or equivalent) ways of
+      parallelism (and any BLIS_JR_NT parallelism specified within) to
+      the remaining gemm subproblem (the rank-k update that is
+      performed using the newly updated row-panel of B). Thus, trsm now
+      supports JC, IC, and JR loop parallelism.
+    - Modified bli_trsm_l_cntl_create() to create the new "prenode"
+      branch of the trsm_l cntl_t tree. The trsm_r tree was left
+      unchanged, for now, since it is not currently used. (All trsm
+      problems are cast in terms of left-side trsm.)
+    - Updated bli_cntl_free_w_thrinfo() to be able to free the newly
+      shaped trsm cntl_t trees. Fixed a potentially latent bug whereby
+      a cntl_t subnode is only recursed upon if there existed a
+      corresponding thrinfo_t node, which may not always exist (for
+      problems too small to employ full parallelization due to the
+      minimum granularity imposed by micropanels).
+    - Updated other functions in frame/base/bli_cntl.c, such as
+      bli_cntl_copy() and bli_cntl_mark_family(), to recurse on
+      sub-prenodes if they exist.
+    - Updated bli_thrinfo_free() to recurse into sub-nodes and prenodes
+      when they exist, and added support for growing a prenode branch
+      to bli_thrinfo_grow() via a corresponding set of help functions
+      named with the _prenode() suffix.
+    - Added a bszid_t field to thrinfo_t nodes. This field comes in
+      handy when debugging the allocation/release of thrinfo_t nodes,
+      as it helps trace the "identity" of each node as it is
+      created/destroyed.
+    - Renamed
+        bli_l3_thrinfo_print_paths() -> bli_l3_thrinfo_print_gemm_paths()
+      and created a separate bli_l3_thrinfo_print_trsm_paths() function
+      to print out the newly reconfigured thrinfo_t trees for the trsm
+      operation.
+    - Trivial changes to bli_gemm_blk_var?.c and bli_trsm_blk_var?.c
+      regarding variable declarations.
+    - Removed subpart_t enum values BLIS_SUBPART1T, BLIS_SUBPART1B,
+      BLIS_SUBPART1L, BLIS_SUBPART1R. Then added support for two new
+      labels (semantically speaking): BLIS_SUBPART1A and BLIS_SUBPART1B,
+      which represent the subpartition ahead of and behind,
+      respectively, BLIS_SUBPART1. Updated check functions in
+      bli_check.c accordingly.
+    - Shuffled layering/APIs for bli_acquire_mpart_[mn]dim() and
+      bli_acquire_mpart_t2b/b2t(), _l2r/r2l().
+    - Deprecated old functions in frame/3/bli_l3_thrinfo.c.
+
+commit 78bc0bc8b6b528c79b11f81ea19250a1db7450ed
+Author: Nicholai Tukanov
+Date:   Thu Feb 14 13:29:02 2019 -0600
+
+    Power9 sub-configuration (#298)
+
+    Formally registered power9 sub-configuration.
+
+    Details:
+    - Added and registered power9 sub-configuration into the build
+      system. Thanks to Nicholai Tukanov and Devangi Parikh for these
+      contributions.
+    - Note: The sub-configuration does not yet have a corresponding
+      architecture-specific kernel set registered, and so for now the
+      sub-config is using the generic kernel set.
+
+commit 6b832731261f9e7ad003a9ea4682e9ca973ef844
+Author: Field G. Van Zee
+Date:   Tue Feb 12 16:01:28 2019 -0600
+
+    Generalized ref kernels' pragma omp simd usage.
+
+    Details:
+    - Replaced direct usage of _Pragma( "omp simd" ) in reference
+      kernels with PRAGMA_SIMD, which is defined as a function of the
+      compiler being used in a new bli_pragma_macro_defs.h file. That
+      definition is cleared when BLIS detects that the -fopenmp-simd
+      command line option is unsupported. Thanks to Devin Matthews and
+      Jeff Hammond for suggestions that guided this commit. (See the
+      sketch following this entry.)
+    - Updated configure and bli_config.h.in so that the appropriate
+      anchor is substituted in (when the corresponding pragma omp simd
+      support is present).
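A conditional definition in the spirit of PRAGMA_SIMD might look like the following; the guard macro below is an assumption (the actual conditions live in bli_pragma_macro_defs.h and a configure-substituted anchor):

    /* Expand to the pragma only when configure verified -fopenmp-simd. */
    #ifdef ENABLE_OMP_SIMD   /* hypothetical configure-substituted anchor */
      #define PRAGMA_SIMD _Pragma( "omp simd" )
    #else
      #define PRAGMA_SIMD
    #endif

    /* Typical use in a reference kernel: */
    void ref_axpyv( long m, double alpha, const double* x, double* y )
    {
        PRAGMA_SIMD
        for ( long i = 0; i < m; ++i ) y[ i ] += alpha * x[ i ];
    }
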
+
+commit b1f5ce8622b682b79f956fed83f04a60daa8e0fc
+Author: Field G. Van Zee
+Date:   Tue Feb 5 17:38:50 2019 -0600
+
+    Minor updates to scripts in test/mixeddt/matlab.
+
+commit 38203ecd15b1fa50897d733daeac6850d254e581
+Author: Devangi N. Parikh
+Date:   Mon Feb 4 15:28:28 2019 -0500
+
+    Added thunderx2 system in the mixeddt test scripts
+
+    Details:
+    - Added thunderx2 (tx2) as a system in the runme.sh in test/mixeddt
+
+commit dfc91843ea52297bf636147793029a0c1345be04
+Author: Devangi N. Parikh
+Date:   Mon Feb 4 15:23:40 2019 -0500
+
+    Fixed gcc flags for thunderx2 subconfiguration
+
+    Details:
+    - Fixed -march flag. Thunderx2 is an armv8.1a architecture, not
+      armv8a.
+
+commit c665eb9b888ec7e41bd0a28c4c8ac4094d0a01b5
+Author: Field G. Van Zee
+Date:   Mon Jan 28 16:22:23 2019 -0600
+
+    Minor updates to docs, Makefiles.
+
+    Details:
+    - Changed all occurrences of
+        micro-kernel -> microkernel
+        macro-kernel -> macrokernel
+        micro-panel  -> micropanel
+      in all markdown documents in 'docs' directory. This change is
+      being made since we've reached the point in adoption and
+      acceptance of BLIS's insights where words such as "microkernel"
+      are no longer new, and therefore now merit being unhyphenated.
+    - Updated "Implementation Notes" sections of KernelsHowTo.md, which
+      still contained references to nonexistent cpp macros such as
+      BLIS_DEFAULT_MR_? and BLIS_PACKDIM_MR_?.
+    - Added 'run-fast' and 'check-fast' targets to testsuite/Makefile.
+    - Minor updates to Testsuite.md, including suggesting use of
+      'make check' and 'make check-fast' when running from the local
+      testsuite directory.
+    - Added a comment to top-level Makefile explaining the purpose
+      behind the TESTSUITE_WRAPPER variable, which at first glance
+      appears to serve no purpose.
+
+commit 1aa280d0520ed5eaea3b119b4e92b789ecad78a4
+Author: M. Zhou <5723047+cdluminate@users.noreply.github.com>
+Date:   Sun Jan 27 21:40:48 2019 +0000
+
+    Amend OS detection for kFreeBSD. (#295)
+
+commit fffc23bb35d117a433886eb52ee684ff5cf6997f
+Author: Field G. Van Zee
+Date:   Fri Jan 25 13:35:31 2019 -0600
+
+    CREDITS file update.
+
+commit 26c5cf495ce22521af5a36a1012491213d5a4551
+Author: Field G. Van Zee
+Date:   Thu Jan 24 18:49:31 2019 -0600
+
+    Fixed bug in skx subconfig related to bdd46f9.
+
+    Details:
+    - Fixed code in the skx subconfiguration that became a bug after
+      committing bdd46f9. Specifically, the bli_cntx_init_skx()
+      function was overwriting default blocksizes for the scomplex and
+      dcomplex microkernels despite the fact that only single and
+      double real microkernels were being registered. This was not a
+      problem prior to bdd46f9 since all microkernels used
+      dynamically-queried (at runtime) register blocksizes for loop
+      bounds. However, post-bdd46f9, this became a bug because the
+      reference ukernels for scomplex and dcomplex were written with
+      their register blocksizes hard-coded as constant loop bounds,
+      which conflicted with the erroneous scomplex and dcomplex values
+      that bli_cntx_init_skx() was setting in the context. The lesson
+      here is that going forward, all subconfigurations must not set
+      any blocksizes for datatypes corresponding to default/reference
+      microkernels. (Note that a blocksize is left unchanged by the
+      bli_cntx_set_blkszs() function if it was set to -1.)
+
+commit 180f8e42e167b83a757340ad4bd4a5c7a1d6437b
+Author: Field G. Van Zee
+Date:   Thu Jan 24 18:01:15 2019 -0600
+
+    Fixed undefined behavior trsm ukr bug in bdd46f9.
+
+    Details:
+    - Fixed a bug that manifested anytime a configuration was used in
+      which optimized microkernels were registered and the trsm
+      operation (or kernel) was invoked. The bug resulted from the
+      optimized microkernels' register blocksizes conflicting with the
+      hard-coded values--expressed in the form of constant loop
+      bounds--used in the new reference trsm ukernels that were
+      introduced in bdd46f9.
+      The fix was easy: reverting to the implementation that uses
+      variable-bound loops, which amounted to changing an #if 0 to
+      #if 1 (since I preserved the older implementation in the file
+      alongside the new code based on constant-bound loops). It should
+      be noted that this fix must be permanent, since the trsm kernel
+      code with constant-bound loops can never work with gemm ukernels
+      that use different register blocksizes.
+
+commit bdd46f9ee88057d52610161966a11c224e5a026c
+Author: Field G. Van Zee
+Date:   Thu Jan 24 17:23:18 2019 -0600
+
+    Rewrote reference kernels to use #pragma omp simd.
+
+    Details:
+    - Rewrote level-1v, -1f, and -3 reference kernels in terms of
+      simplified indexing annotated by the #pragma omp simd directive,
+      which a compiler can use to vectorize certain constant-bounded
+      loops. (The new kernels actually use _Pragma("omp simd") since
+      the kernels are defined via templatizing macros.) Modest speedup
+      was observed in most cases using gcc 5.4.0, which may improve
+      with newer versions. Thanks to Devin Matthews for suggesting this
+      via issue #286 and #259.
+    - Updated default blocksizes defined in ref_kernels/bli_cntx_ref.c
+      to be 4x16, 4x8, 4x8, and 4x4 for single, double, scomplex and
+      dcomplex, respectively, with a default row preference for the
+      gemm ukernel. Also updated axpyf, dotxf, and dotxaxpyf fusing
+      factors to 8, 6, and 4, respectively, for all datatypes.
+    - Modified configure to verify that -fopenmp-simd is a valid
+      compiler option (via a new detect/omp_simd/omp_simd_detect.c
+      file).
+    - Added a new header in which prefetch macros are defined according
+      to which compiler is detected (via macros such as __GNUC__).
+      These prefetch macros are not yet employed anywhere, though.
+    - Updated the year in copyrights of template license headers in
+      build/templates and removed AMD as a default copyright holder.
+
+commit 63de2b0090829677755eb5cdb27e73bc738da32d
+Author: Field G. Van Zee
+Date:   Wed Jan 23 12:16:27 2019 -0600
+
+    Prevent redef of ftnlen in blastest f2c_types.h.
+
+    Details:
+    - Guard typedef of ftnlen in f2c_types.h with an #ifndef
+      HAVE_BLIS_H directive to prevent the redefinition of that type.
+      Thanks to Jeff Diamond for reporting this compiler warning (and
+      apologies for the delay in committing a fix).
+
+commit eec2e183a7b7d67702dbd1f39c153f38148b2446
+Author: Field G. Van Zee
+Date:   Mon Jan 21 12:12:18 2019 -0600
+
+    Added escaping to '/' in os_name in configure.
+
+    Details:
+    - Add os_name to the list of variables in which the '/' character
+      is escaped. This is meant to address (or at least make progress
+      toward addressing) #293. Thanks to Isuru Fernando for spotting
+      this as the potential fix, and also thanks to M. Zhou for the
+      original report.
+
+commit adf5c17f0839fdbc1f4a1780f637928b1e78e389
+Author: Field G. Van Zee
+Date:   Fri Jan 18 15:14:45 2019 -0600
+
+    Formally registered thunderx2 subconfiguration.
+
+    Details:
+    - Added a separate subconfiguration for thunderx2, which now uses
+      different optimization flags than cortexa57/cortexa53.
+
+commit 094cfdf7df6c2764c25fcbfce686ba29b933942c
+Author: M. Zhou <5723047+cdluminate@users.noreply.github.com>
+Date:   Fri Jan 18 18:46:13 2019 +0000
+
+    Port BLIS to GNU Hurd OS. (#294)
+
+    Prevent blis.h from misidentifying Hurd as OSX.
+
+commit 5d7d616e8e591c2f3c7c2d73220eb27ea484f9c9
+Author: Field G. Van Zee
+Date:   Tue Jan 15 20:52:51 2019 -0600
+
+    README.md update re: mixeddt TOMS paper.
+
+commit 58c7fb4788177487f73a3964b7a910fe4dc75941
+Author: Field G. Van Zee
Van Zee +Date: Tue Jan 8 17:00:27 2019 -0600 + + Added more matlab scripts for mixeddt paper. + + Details: + - Added a variant set of matlab scripts geared to producing plots that + reflect performance data gathered with and without extra memory + optimizations enabled. These scripts reside (for now) in + test/mixeddt/matlab/wawoxmem. + +commit 34286eb914b48b56cdda4dfce192608b9f86d053 +Author: Field G. Van Zee +Date: Tue Jan 8 11:41:20 2019 -0600 + + Minor update to docs/HardwareSupport.md. + +commit 108b04dc5b1b1288db95f24088d1e40407d7bc88 +Author: Field G. Van Zee +Date: Mon Jan 7 20:16:31 2019 -0600 + + Regenerated symbols in build/libblis-symbols.def. + + Details: + - Reran ./build/regen-symbols.sh after running + 'configure --enable-cblas auto' to reflect removal of + bli_malloc_pool() and bli_free_pool(). + +commit 706cbd9d5622f4690e6332a89cf41ab5c8771899 +Author: Field G. Van Zee +Date: Mon Jan 7 18:28:19 2019 -0600 + + Minor tweaks/cleanups to bli_malloc.c, _apool.c. + + Details: + - Removed malloc_ft and free_ft function pointer arguments from the + interface to bli_apool_init() after deciding that there is no need to + specify the malloc()/free() for blocks within the apool. (The apool + blocks are actually just array_t structs.) Instead, we simply call + bli_malloc_intl()/_free_intl() directly. This has the added benefit + of allowing additional output when memory tracing is enabled via + --enable-mem-tracing. Also made corresponding changes elsewhere in + the apool API. + - Changed the inner pools (elements of the array_t within the apool_t) + to use BLIS_MALLOC_POOL and BLIS_FREE_POOL instead of BLIS_MALLOC_INTL + and BLIS_FREE_INTL. + - Disabled definitions of bli_malloc_pool() and bli_free_pool() since + there are no longer any consumers of these functions. + - Very minor comment / printf() updates. + +commit 579145039d945adbcad1177b1d53fb2d3f2e6573 +Author: Minh Quan Ho <1337056+hominhquan@users.noreply.github.com> +Date: Mon Jan 7 23:00:15 2019 +0100 + + Initialize error messages at compile time (#289) + + * Initialize error messages at compile time + + - Assigning strings directly to the bli_error_string array, instead of + snprintf() at execution-time. + + * Retired bli_error_init(), _finalize(). + + Details: + - Removed functions obviated by changes in 80e8dc6: bli_error_init(), + bli_error_finalize(), and bli_error_init_msgs(), as well as calls to + the former two in bli_init.c. + + * Regenerated symbols in build/libblis-symbols.def. + + Details: + - Reran ./build/regen-symbols.sh after running + 'configure --enable-cblas auto'. + +commit aafbca086e36b6727d7be67e21fef5bd9ff7bfd9 +Author: Field G. Van Zee +Date: Mon Jan 7 12:38:21 2019 -0600 + + Updated external package language in README.md. + + Details: + - Updated/added comments about Fedora, OpenSUSE, and GNU Guix under the + newly-renamed "External GNU/Linux packages" section. Thanks to Dave + Love for providing these revisions. + +commit daacfe68404c9cc8078e5e7ba49a8c7d93e8cda3 +Author: Field G. Van Zee +Date: Mon Jan 7 12:12:47 2019 -0600 + + Allow running configure with python 3.4. + + Details: + - Relax version blacklisting of python3 to allow 3.4 or later instead + of 3.5 or later. Thanks to Dave Love for pointing out that 3.4 was + sufficient for the purpose of BLIS's build system. 
(It should be + noted that we're not sure which, if any, python3 versions prior to + 3.4 are insufficient, and that the only thing stopping us from + determining this is the fact that these earlier versions of python3 + are not readily available for us to test with.) + - Updated docs/BuildSystem.md to be explicit about current python2 vs + python3 version requirements. + +commit ad8d9adb09a7dd267bbdeb2bd1fbbf9daf64ee76 +Author: Field G. Van Zee +Date: Thu Jan 3 16:08:24 2019 -0600 + + README.md, CREDITS update. + + Details: + - Added "What's New" and "What People Are Saying About BLIS" sections to + README.md. + - Added missing github handles to various individuals' entries in the + CREDITS file. + +commit 7052fca5aef430241278b67d24cef6fe33106904 +Author: Field G. Van Zee +Date: Wed Jan 2 13:48:40 2019 -0600 + + Apply f272c289 to bli_fmalloc_noalign(). + + Details: + - Perform the same check for NULL return values and error message output + in bli_fmalloc_noalign() as is performed by bli_fmalloc_align(). (This + change was intended for f272c289.) + +commit 528e3ad16a42311a852a8376101959b4ccd801a5 +Merge: 3126c52e f272c289 +Author: Field G. Van Zee +Date: Wed Jan 2 13:39:19 2019 -0600 + + Merge branch 'amd' + +commit 3126c52ea795ffb7d30b16b7f7ccc2a288a6158d +Merge: 61441b24 8091998b +Author: Field G. Van Zee +Date: Wed Jan 2 13:37:37 2019 -0600 + + Merge branch 'amd' + +commit f272c2899a6764eedbe05cea874ee3bd258dbff3 +Author: Field G. Van Zee +Date: Wed Jan 2 12:34:15 2019 -0600 + + Add error message to malloc() check for NULL. + + Details: + - Output an error message if and when the malloc()-equivalent called by + bli_fmalloc_align() ever returns NULL. Everything was already in place + for this to happen, including the error return code, the error string + sprintf(), the error checking function bli_check_valid_malloc_buf() + definition, and its prototype. Thanks to Minh Quan Ho for pointing out + the missing error message. + - Increased the default block_ptrs_len for each inner pool stored in the + small block allocator from 10 to 25. Under normal execution, each + thread uses only 21 blocks, so this change will prevent the sba from + needing to resize the block_ptrs array of any given inner pool as + threads initially populate the pool with small blocks upon first + execution of a level-3 operation. + - Nix stray newline echo in configure. + +commit eb97f778a1e13ee8d3b3aade05e479c4dfcfa7c0 +Author: Field G. Van Zee +Date: Tue Dec 25 20:17:09 2018 -0600 + + Added missing AMD copyrights to previous commit. + + Details: + - Forgot to add AMD copyrights to several touched files that did not + already have them in 2f31743. + +commit 2f3174330fb29164097d664b7c84e05c7ced7d95 +Author: Field G. Van Zee +Date: Tue Dec 25 19:35:01 2018 -0600 + + Implemented a pool-based small block allocator. + + Details: + - Implemented a sophisticated data structure and set of APIs that track + the small blocks of memory (around 80-100 bytes each) used when + creating nodes for control and thread trees (cntl_t and thrinfo_t) as + well as thread communicators (thrcomm_t). The purpose of the small + block allocator, or sba, is to allow the library to transition into a + runtime state in which it does not perform any calls to malloc() or + free() during normal execution of level-3 operations, regardless of + the threading environment (potentially multiple application threads + as well as multiple BLIS threads). 
The functionality relies on a new + data structure, apool_t, which is (roughly speaking) a pool of + arrays, where each array element is a pool of small blocks. The outer + pool, which is protected by a mutex, provides separate arrays for each + application thread while the arrays each handle multiple BLIS threads + for any given application thread. The design minimizes the potential + for lock contention, as only concurrent application threads would + need to fight for the apool_t lock, and only if they happen to begin + their level-3 operations at precisely the same time. Thanks to Kiran + Varaganti and AMD for requesting this feature. + - Added a configure option to disable the sba pools, which are enabled + by default; renamed the --[dis|en]able-packbuf-pools option to + --[dis|en]able-pba-pools; and rewrote the --help text associated with + this new option and consolidated it with the --help text for the + option associated with the sba (--[dis|en]able-sba-pools). + - Moved the membrk field from the cntx_t to the rntm_t. We now pass in + a rntm_t* to the bli_membrk_acquire() and _release() APIs, just as we + do for bli_sba_acquire() and _release(). + - Replaced all calls to bli_malloc_intl() and bli_free_intl() that are + used for small blocks with calls to bli_sba_acquire(), which takes a + rntm (in addition to the bytes requested), and bli_sba_release(). + These latter two functions reduce to the former two when the sba pools + are disabled at configure time. + - Added rntm_t* arguments to various cntl_t and thrinfo_t functions, as + required by the new usage of bli_sba_acquire() and _release(). + - Moved the freeing of "old" blocks (those allocated prior to a change + in the block_size) from bli_membrk_acquire_m() to the implementation + of the pool_t checkout function. + - Miscellaneous improvements to the pool_t API. + - Added a block_size field to the pblk_t. + - Harmonized the way that the trsm_ukr testsuite module performs packing + relative to that of gemmtrsm_ukr, in part to avoid the need to create + a packm control tree node, which now requires a rntm_t that has been + initialized with an sba and membrk. + - Re-enabled the explicit call to bli_finalize() in the testsuite so that + users who run the testsuite with memory tracing enabled can check for + memory leaks. + - Manually imported the compact/minor changes from 61441b24 that cause + the rntm to be copied locally when it is passed in via one of the + expert APIs. + - Reordered parameters to various bli_thrcomm_*() functions so that the + thrcomm_t* to the comm being modified is last, not first. + - Added more descriptive tracing for allocating/freeing small blocks and + formalized it via a new configure option: --[dis|en]able-mem-tracing. + - Moved some unused scalm code and headers into frame/1m/other. + - Whitespace changes to bli_pthread.c. + - Regenerated build/libblis-symbols.def. + +commit 61441b24f3244a4b202c29611a4899dd5c51d3a1 +Author: Field G. Van Zee +Date: Thu Dec 20 19:38:11 2018 -0600 + + Make local copy of user's rntm_t in level-3 ops. + + Details: + - In the case that the caller passes a non-NULL rntm_t pointer into + one of the expert APIs for a level-3 operation (e.g. bli_gemm_ex()), + make a local copy of the rntm_t and use the address of that local copy + in all subsequent execution (which may change the contents of the + rntm_t).
This prevents a potentially confusing situation whereby a + user-initialized rntm_t is used once (in, say, gemm), and then found + by the user to be in a different state before it is used a second + time. + +commit e809b5d2f1023b4249969e2f516291c9a3a00b80 +Merge: 76016691 0476f706 +Author: Field G. Van Zee +Date: Thu Dec 20 16:27:26 2018 -0600 + + Merge branch 'master' into amd + +commit 0476f706b93e83f6b74a3d7b7e6e9cc9a1a52c3b +Author: Field G. Van Zee +Date: Tue Dec 18 14:56:20 2018 -0600 + + CHANGELOG update (0.5.1) + +commit e0408c3ca3d53bc8e6fedac46ea42c86e06c922d (tag: 0.5.1) Author: Field G. Van Zee Date: Tue Dec 18 14:56:16 2018 -0600 Version file update (0.5.1) -commit 3ab231afc9f69d14493908c53c85a84c5fba58aa (origin/master, origin/HEAD) +commit 3ab231afc9f69d14493908c53c85a84c5fba58aa Author: Field G. Van Zee Date: Tue Dec 18 14:53:37 2018 -0600 @@ -53,6 +1746,55 @@ Date: Mon Dec 17 19:17:30 2018 -0600 OpenMP. - CREDITS file update. +commit 76016691e2c514fcb59f940c092475eda968daa2 +Author: Field G. Van Zee +Date: Thu Dec 13 17:23:09 2018 -0600 + + Improvements to bli_pool; malloc()/free() tracing. + + Details: + - Added malloc_ft and free_ft fields to pool_t, which are provided when + the pool is initialized, to allow bli_pool_alloc_block() and + bli_pool_free_block() to call bli_fmalloc_align()/bli_ffree_align() + with arbitrary align_size values (according to how the pool_t was + initialized). (A sketch of this function-pointer pattern appears just + after this commit entry.) + - Added a block_ptrs_len argument to bli_pool_init(), which allows the + caller to specify an initial length for the block_ptrs array, which + previously suffered the cost of being reallocated, copied, and freed + each time a new block was added to the pool. + - Consolidated the "buf_sys" and "buf_align" pointer fields in pblk_t + into a single "buf" field. Consolidated the bli_pblk API accordingly + and also updated the bli_mem API implementation. This was done + because I had already implemented opaque alignment via + bli_malloc_align(), which allocates extra space and stores the + original pointer returned by malloc() one element before the element + whose address is aligned. + - Tweaked bli_membrk_acquire_m() and bli_membrk_release() to call + bli_fmalloc_align() and bli_ffree_align(), which required adding an + align_size field to the membrk_t struct. + - Pass the pack schemas directly into bli_l3_cntl_create_if() rather + than transmit them via objects for A and B. + - Simplified bli_l3_cntl_free_if() and renamed it to bli_l3_cntl_free(). + The function had not been conditionally freeing control trees for + quite some time. Also, removed obj_t* parameters since they aren't + needed anymore (or never were). + - Spun off OpenMP nesting code in bli_l3_thread_decorator() into a + separate function, bli_l3_thread_decorator_thread_check(). + - Renamed: + bli_malloc_align() -> bli_fmalloc_align() + bli_free_align() -> bli_ffree_align() + bli_malloc_noalign() -> bli_fmalloc_noalign() + bli_free_noalign() -> bli_ffree_noalign() + The 'f' is for "function" since they each take a malloc_ft or free_ft + function pointer argument. + - Inserted various printf() calls for the purposes of tracing memory + allocation and freeing, guarded by cpp macro ENABLE_MEM_DEBUG, which, + for now, is intended to be a "hidden" feature rather than one hooked + up to a configure-time option. + - Defined bli_rntm_equals(), which compares two rntm_t for equality. + (There are no use cases for this function yet, but there may be soon.) + - Whitespace changes to function parameter lists in bli_pool.c, .h.
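As a rough illustration of the malloc_ft/free_ft idea described in the commit entry above, here is a minimal sketch of a pool that carries its own allocation functions. The demo_* names are hypothetical simplifications for illustration only, not BLIS's actual pool_t definitions:

    #include <stdlib.h>

    /* Function-pointer types for a pluggable allocator/deallocator. */
    typedef void* (*malloc_ft)( size_t size );
    typedef void  (*free_ft)( void* p );

    typedef struct demo_pool_s
    {
        malloc_ft malloc_fp;   /* used to allocate new blocks */
        free_ft   free_fp;     /* used to free blocks later on */
        size_t    align_size;  /* alignment recorded at init time */
    } demo_pool_t;

    /* The allocation functions are chosen once, when the pool is
       initialized, and reused for every block thereafter. */
    void demo_pool_init( demo_pool_t* pool, size_t align_size,
                         malloc_ft mf, free_ft ff )
    {
        pool->malloc_fp  = mf;
        pool->free_fp    = ff;
        pool->align_size = align_size;
    }

    void* demo_pool_alloc_block( demo_pool_t* pool, size_t size )
    {
        /* A faithful version would route through an aligned allocator
           (cf. bli_fmalloc_align()) using pool->align_size; a direct
           call is shown here for brevity. */
        return pool->malloc_fp( size );
    }

    void demo_pool_free_block( demo_pool_t* pool, void* block )
    {
        pool->free_fp( block );
    }

For example, demo_pool_init( &p, 64, malloc, free ) yields a pool that allocates through the standard library, while a tracing build could instead pass wrappers that log each allocation and release.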
+ commit f808d829c58dc4194cc3ebc3825fbdde12cd3f93 Author: Field G. Van Zee Date: Wed Dec 12 15:22:59 2018 -0600 @@ -105,6 +1847,13 @@ Date: Wed Dec 12 15:22:59 2018 -0600 - Fixed a minor bug in the testsuite that prevented non-1m-based induced method implementations of trsm from executing. +commit 02ec0be3ba0b0d6b4186386ae140906a96de919b +Merge: e275def3 c534da62 +Author: Field G. Van Zee +Date: Wed Dec 5 19:33:53 2018 -0600 + + Merge branch 'master' into amd + commit c534da62c0015f91391983da5376c9e091378010 Author: Field G. Van Zee Date: Wed Dec 5 15:51:05 2018 -0600 @@ -149,7 +1898,7 @@ Date: Wed Dec 5 20:06:32 2018 +0000 (That is, when native complex microkernels are missing, we usually want to test performance of 1m.) -commit 0645f239fbdf37ee9d2096ee3bb0e76b3302cfff (origin/dev, dev) +commit 0645f239fbdf37ee9d2096ee3bb0e76b3302cfff Author: Field G. Van Zee Date: Tue Dec 4 14:31:06 2018 -0600 @@ -238,6 +1987,13 @@ Date: Mon Dec 3 17:49:52 2018 -0600 frame/3/gemm/ind/bli_gemm_ind_opt.h. - Various whitespace/comment updates. +commit e275def30ac41cadce296560fa67282704f20a02 +Merge: 8091998b dc184095 +Author: Field G. Van Zee +Date: Fri Nov 30 15:39:50 2018 -0600 + + Merge branch 'master' into amd + commit dc18409551f341125169fe8d4d43ac45e81bdf28 Author: Field G. Van Zee Date: Wed Nov 28 11:58:40 2018 -0600 @@ -489,6 +2245,13 @@ Date: Wed Nov 14 13:47:45 2018 -0600 Isuru Fernando for suggesting this fix, and also to Costas Yamin for originally reporting the issue (#277). +commit 8091998b6500e343c2024561c2b1aa73c3bafb0b +Merge: 333d8562 7b5ba731 +Author: Field G. Van Zee +Date: Wed Nov 14 12:36:35 2018 -0600 + + Merge branch 'master' into amd + commit 7b5ba7319b3901ad0e6c6b4fa3c1d96b579efbe9 Merge: ce719f81 52392932 Author: Field G. Van Zee @@ -548,6 +2311,18 @@ Date: Tue Nov 13 13:03:15 2018 -0600 datatype contains a different value. Thanks to Devangi Parikh for helping in isolating this bug. +commit 333d8562f04eea0676139a10cb80a97f107b45b0 +Author: Field G. Van Zee +Date: Sun Nov 11 14:28:53 2018 -0600 + + Added debug output to bli_malloc.c. + + Details: + - Added debug output to bli_malloc.c in order to debug certain kinds of + memory behavior in BLIS. The printf() statements are disabled and must + be enabled manually. + - Whitespace/comment updates in bli_membrk.c. + commit ce719f816d1237f5277527d7f61123e77180be54 Author: Field G. Van Zee Date: Sat Nov 10 14:48:43 2018 -0600 @@ -5671,7 +7446,7 @@ Date: Fri Feb 23 17:42:48 2018 -0600 Version file update (0.3.0) -commit 3defc7265c12cf85e9de2d7a1f243c5e090a6f9d +commit 3defc7265c12cf85e9de2d7a1f243c5e090a6f9d (origin/master, origin/HEAD) Author: Field G. Van Zee Date: Fri Feb 23 17:38:19 2018 -0600 @@ -5707,7 +7482,7 @@ Date: Fri Feb 23 16:33:32 2018 -0600 contained. To remedy this situation, we now selectively use movss to load any element that could be the last element in the matrix. -commit 5112e1859e7f8888f5555eb7bc02bd9fab9b4442 (origin/rt) +commit 5112e1859e7f8888f5555eb7bc02bd9fab9b4442 (origin/rt, rt) Author: Field G. Van Zee Date: Fri Feb 23 14:31:26 2018 -0600 @@ -5939,7 +7714,7 @@ Date: Thu Jan 4 20:51:35 2018 -0600 time hardware detection (when clang is selected). - Added some missing (but mostly-optional) quotes to configure script. 
-commit 5a7005dd44ed3174abbe360981e367fd41c99b4b +commit 5a7005dd44ed3174abbe360981e367fd41c99b4b (origin/amd, amd) Merge: 7be88705 3bc99a96 Author: Nisanth M P Date: Wed Jan 3 12:05:12 2018 +0530 @@ -5988,7 +7763,7 @@ Date: Sat Dec 23 15:32:03 2017 -0600 is used by the auto-detection script to printf() the name of the sub-configuration corresponding to the detected hardware. -commit 9804adfd405056ec332bb8e13d68c7b52bd3a6c1 (origin/selfinit) +commit 9804adfd405056ec332bb8e13d68c7b52bd3a6c1 (origin/selfinit, selfinit) Author: Field G. Van Zee Date: Thu Dec 21 19:22:57 2017 -0600 diff --git a/CREDITS b/CREDITS index 6c2eaa990..0d33f534f 100644 --- a/CREDITS +++ b/CREDITS @@ -9,18 +9,22 @@ The BLIS framework was primarily authored by but many others have contributed code and feedback, including + Sameer Agarwal @sandwichmaker (Google) Murtaza Ali (Texas Instruments) Sajid Ali @s-sajid-ali (Northwestern University) Erling Andersen @erling-d-andersen Alex Arslan @ararslan Vernon Austel (IBM, T.J. Watson Research Center) + Matthew Brett @matthew-brett (University of Birmingham) Jed Brown @jedbrown (Argonne National Laboratory) Robin Christ @robinchrist Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany) Jeff Diamond (Oracle) Johannes Dieterich @iotamudelta Krzysztof Drewniak @krzysz00 + Marat Dukhan @Maratyszcza (Google) Victor Eijkhout @VictorEijkhout (Texas Advanced Computing Center) + Evgeny Epifanovsky @epifanovsky (Q-Chem) Isuru Fernando @isuruf Roman Gareev @gareevroman Richard Goldschmidt @SuperFluffy @@ -30,7 +34,7 @@ but many others have contributed code and feedback, including Jeff Hammond @jeffhammond (Intel) Jacob Gorm Hansen @jacobgorm Jean-Michel Hautbois @jhautbois - Ian Henriksen @insertinterestingnamehere + Ian Henriksen @insertinterestingnamehere (The University of Texas at Austin) Minh Quan Ho @hominhquan Matthew Honnibal @honnibal Stefan Husmann @stefanhusmann @@ -53,6 +57,7 @@ but many others have contributed code and feedback, including Ilya Polkovnichenko Jack Poulson @poulson (Stanford) Mathieu Poumeyrol @kali + Christos Psarras @ChrisPsa (RWTH-Aachen) @qnerd Michael Rader @mrader1248 Pradeep Rao @pradeeptrgit (AMD) @@ -63,11 +68,13 @@ but many others have contributed code and feedback, including Rene Sitt Tony Skjellum @tonyskjellum (The University of Tennessee at Chattanooga) Mikhail Smelyanskiy (Intel, Parallel Computing Lab) + Nathaniel Smith @njsmith Shaden Smith @ShadenSmith Tyler Smith @tlrmchlsmth (The University of Texas at Austin) Paul Springer @springer13 (RWTH-Aachen) Vladimir Sukarev Santanu Thangaraj (AMD) + Nicholai Tukanov @nicholaiTukanov (The University of Texas at Austin) Rhys Ulerich @RhysU (The University of Texas at Austin) Robert van de Geijn @rvdg (The University of Texas at Austin) Kiran Varaganti @kvaragan (AMD) @@ -83,8 +90,10 @@ partners, including AMD Hewlett Packard Enterprise + Huawei Intel Microsoft + Oracle Texas Instruments as well as the National Science Foundation (NSF Awards CCF-0917167, diff --git a/Makefile b/Makefile index 6483b0de4..5f0c86034 100644 --- a/Makefile +++ b/Makefile @@ -386,23 +386,22 @@ ifeq ($(IS_CONFIGURED),yes) # named with three .so version numbers. UNINSTALL_OLD_LIBS := -UNINSTALL_OLD_LIBS += $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS_SO).?.?.?" 
2> /dev/null | $(GREP) -v "$(LIBBLIS).$(LIBBLIS_SO_MMB_EXT)") +UNINSTALL_OLD_LIBS += $(filter-out $(INSTALL_LIBDIR)/$(LIBBLIS).$(LIBBLIS_SO_MMB_EXT),$(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS_SO).?.?.?)) # These shell commands gather the filepaths to any library symlink in the # current LIBDIR that might be left over from an old installation. We start # with symlinks named using the .so major version number. -UNINSTALL_OLD_SYML := $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS_SO).?" 2> /dev/null | $(GREP) -v "$(LIBBLIS_SO).$(SO_MAJOR)") +UNINSTALL_OLD_SYML := $(filter-out $(INSTALL_LIBDIR)/$(LIBBLIS_SO).$(SO_MAJOR),$(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS_SO).?)) # We also prepare to uninstall older-style symlinks whose names contain the # BLIS version number and configuration family. -UNINSTALL_OLD_SYML += $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS)-*.a" 2> /dev/null | $(GREP) -v "$(LIBBLIS)-$(VERS_CONF).a") - -UNINSTALL_OLD_SYML += $(shell $(FIND) $(INSTALL_LIBDIR)/ -name "$(LIBBLIS)-*.$(SHLIB_EXT)" 2> /dev/null | $(GREP) -v "$(LIBBLIS)-$(VERS_CONF).$(SHLIB_EXT)") +UNINSTALL_OLD_SYML += $(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS)-*.a) +UNINSTALL_OLD_SYML += $(wildcard $(INSTALL_LIBDIR)/$(LIBBLIS)-*.$(SHLIB_EXT)) # This shell command grabs all files named "*.h" that are not blis.h or cblas.h # in the installation directory. We consider this set of headers to be "old" and # eligible for removal upon running of the uninstall-old-headers target. -UNINSTALL_OLD_HEADERS := $(shell $(FIND) $(INSTALL_INCDIR)/blis/ -name "*.h" 2> /dev/null | $(GREP) -v "$(BLIS_H)" | $(GREP) -v "$(CBLAS_H)") +UNINSTALL_OLD_HEADERS := $(filter-out $(BLIS_H),$(filter-out $(CBLAS_H),$(wildcard $(INSTALL_INCDIR)/blis/*.h))) endif # IS_CONFIGURED @@ -1027,23 +1026,24 @@ endif # ifeq ($(IS_WIN),no) # --- Query current configuration --- showconfig: check-env - @echo "configuration family: $(CONFIG_NAME)" - @echo "sub-configurations: $(CONFIG_LIST)" - @echo "requisite kernels: $(KERNEL_LIST)" - @echo "kernel-to-config map: $(KCONFIG_MAP)" - @echo "-----------------------" - @echo "BLIS version string: $(VERSION)" - @echo ".so major version: $(SO_MAJOR)" - @echo ".so minor.build vers: $(SO_MINORB)" - @echo "install libdir: $(INSTALL_LIBDIR)" - @echo "install includedir: $(INSTALL_INCDIR)" - @echo "debugging status: $(DEBUG_TYPE)" - @echo "multithreading status: $(THREADING_MODEL)" - @echo "enable BLAS API? $(MK_ENABLE_BLAS)" - @echo "enable CBLAS API? $(MK_ENABLE_CBLAS)" - @echo "build static library? $(MK_ENABLE_STATIC)" - @echo "build shared library? $(MK_ENABLE_SHARED)" - @echo "ARG_MAX hack enabled? $(ARG_MAX_HACK)" + @echo "configuration family: $(CONFIG_NAME)" + @echo "sub-configurations: $(CONFIG_LIST)" + @echo "requisite kernel sets: $(KERNEL_LIST)" + @echo "kernel-to-config map: $(KCONFIG_MAP)" + @echo "-------------------------" + @echo "BLIS version string: $(VERSION)" + @echo ".so major version: $(SO_MAJOR)" + @echo ".so minor.build vers: $(SO_MINORB)" + @echo "install libdir: $(INSTALL_LIBDIR)" + @echo "install includedir: $(INSTALL_INCDIR)" + @echo "install sharedir: $(INSTALL_SHAREDIR)" + @echo "debugging status: $(DEBUG_TYPE)" + @echo "multithreading status: $(THREADING_MODEL)" + @echo "enable BLAS API? $(MK_ENABLE_BLAS)" + @echo "enable CBLAS API? $(MK_ENABLE_CBLAS)" + @echo "build static library? $(MK_ENABLE_STATIC)" + @echo "build shared library? $(MK_ENABLE_SHARED)" + @echo "ARG_MAX hack enabled?
$(ARG_MAX_HACK)" # --- Clean rules --- @@ -1059,16 +1059,16 @@ ifneq ($(SANDBOX),) - $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) endif else - @echo "Removing makefile fragments from $(CONFIG_FRAG_PATH)." + @echo "Removing makefile fragments from $(CONFIG_FRAG_PATH)" @- $(FIND) $(CONFIG_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - @echo "Removing makefile fragments from $(FRAME_FRAG_PATH)." + @echo "Removing makefile fragments from $(FRAME_FRAG_PATH)" @- $(FIND) $(FRAME_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - @echo "Removing makefile fragments from $(REFKERN_FRAG_PATH)." + @echo "Removing makefile fragments from $(REFKERN_FRAG_PATH)" @- $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - @echo "Removing makefile fragments from $(KERNELS_FRAG_PATH)." + @echo "Removing makefile fragments from $(KERNELS_FRAG_PATH)" @- $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) ifneq ($(SANDBOX),) - @echo "Removing makefile fragments from $(SANDBOX_FRAG_PATH)." + @echo "Removing makefile fragments from $(SANDBOX_FRAG_PATH)" @- $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) endif endif @@ -1080,7 +1080,7 @@ ifeq ($(ENABLE_VERBOSE),yes) $(RM_F) $(BLIS_H_FLAT) $(RM_F) $(CBLAS_H_FLAT) else - @echo "Removing flattened header files from $(BASE_INC_PATH)." + @echo "Removing flattened header files from $(BASE_INC_PATH)" @$(RM_F) $(BLIS_H_FLAT) @$(RM_F) $(CBLAS_H_FLAT) endif @@ -1093,9 +1093,9 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(LIBBLIS_A_PATH) - $(RM_F) $(LIBBLIS_SO_PATH) else - @echo "Removing object files from $(BASE_OBJ_PATH)." + @echo "Removing object files from $(BASE_OBJ_PATH)" @- $(FIND) $(BASE_OBJ_PATH) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing libraries from $(BASE_LIB_PATH)." + @echo "Removing libraries from $(BASE_LIB_PATH)" @- $(RM_F) $(LIBBLIS_A_PATH) @- $(RM_F) $(LIBBLIS_SO_PATH) endif @@ -1117,13 +1117,13 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(BLASTEST_DRV_BIN_PATHS) - $(RM_F) $(addprefix out.,$(BLASTEST_DRV_BASES)) else - @echo "Removing object files from $(BASE_OBJ_BLASTEST_PATH)." + @echo "Removing object files from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_F2C_OBJS) $(BLASTEST_DRV_OBJS) - @echo "Removing libf2c.a from $(BASE_OBJ_BLASTEST_PATH)." + @echo "Removing libf2c.a from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_F2C_LIB) - @echo "Removing binaries from $(BASE_OBJ_BLASTEST_PATH)." + @echo "Removing binaries from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_DRV_BIN_PATHS) - @echo "Removing driver output files 'out.*'." + @echo "Removing driver output files 'out.*'" @- $(RM_F) $(addprefix out.,$(BLASTEST_DRV_BASES)) endif # ENABLE_VERBOSE endif # IS_CONFIGURED @@ -1136,13 +1136,13 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(BLASTEST_DIR)/$(BLASTEST_F2C_LIB_NAME) - $(RM_F) $(addprefix $(BLASTEST_DIR)/out.,$(BLASTEST_DRV_BASES)) else - @echo "Removing object files from ./$(BLASTEST_DIR)/$(OBJ_DIR)." + @echo "Removing object files from ./$(BLASTEST_DIR)/$(OBJ_DIR)" @- $(FIND) $(BLASTEST_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing libf2c.a from ./$(BLASTEST_DIR)." + @echo "Removing libf2c.a from ./$(BLASTEST_DIR)" @- $(RM_F) $(BLASTEST_DIR)/$(BLASTEST_F2C_LIB_NAME) - @echo "Removing binaries from ./$(BLASTEST_DIR)." + @echo "Removing binaries from ./$(BLASTEST_DIR)" @- $(FIND) $(BLASTEST_DIR) -name "*.x" | $(XARGS) $(RM_F) - @echo "Removing driver output files 'out.*' from ./$(BLASTEST_DIR)." 
+ @echo "Removing driver output files 'out.*' from ./$(BLASTEST_DIR)" @- $(RM_F) $(addprefix $(BLASTEST_DIR)/out.,$(BLASTEST_DRV_BASES)) endif # ENABLE_VERBOSE endif # IS_CONFIGURED @@ -1160,11 +1160,11 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(TESTSUITE_BIN) - $(RM_F) $(TESTSUITE_OUT_FILE) else - @echo "Removing object files from $(BASE_OBJ_TESTSUITE_PATH)." + @echo "Removing object files from $(BASE_OBJ_TESTSUITE_PATH)" @- $(RM_F) $(MK_TESTSUITE_OBJS) - @echo "Removing binary $(TESTSUITE_BIN)." + @echo "Removing binary $(TESTSUITE_BIN)" @- $(RM_F) $(TESTSUITE_BIN) - @echo "Removing $(TESTSUITE_OUT_FILE)." + @echo "Removing $(TESTSUITE_OUT_FILE)" @- $(RM_F) $(TESTSUITE_OUT_FILE) endif # ENABLE_VERBOSE endif # IS_CONFIGURED @@ -1176,9 +1176,9 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) - $(MAKE) -C $(CPP_TEST_DIR) clean else - @echo "Removing object files from $(TESTSUITE_DIR)/$(OBJ_DIR)." + @echo "Removing object files from $(TESTSUITE_DIR)/$(OBJ_DIR)" @- $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing binary $(TESTSUITE_DIR)/$(TESTSUITE_BIN)." + @echo "Removing binary $(TESTSUITE_DIR)/$(TESTSUITE_BIN)" @- $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) @$(MAKE) -C $(CPP_TEST_DIR) clean endif # ENABLE_VERBOSE @@ -1193,15 +1193,15 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(LIB_DIR) - $(RM_RF) $(INCLUDE_DIR) else - @echo "Removing $(BLIS_CONFIG_H)." + @echo "Removing $(BLIS_CONFIG_H)" @$(RM_F) $(BLIS_CONFIG_H) - @echo "Removing $(CONFIG_MK_FILE)." + @echo "Removing $(CONFIG_MK_FILE)" @- $(RM_F) $(CONFIG_MK_FILE) - @echo "Removing $(OBJ_DIR)." + @echo "Removing $(OBJ_DIR)" @- $(RM_RF) $(OBJ_DIR) - @echo "Removing $(LIB_DIR)." + @echo "Removing $(LIB_DIR)" @- $(RM_RF) $(LIB_DIR) - @echo "Removing $(INCLUDE_DIR)." + @echo "Removing $(INCLUDE_DIR)" @- $(RM_RF) $(INCLUDE_DIR) endif endif @@ -1210,7 +1210,7 @@ endif # --- CHANGELOG rules --- changelog: - @echo "Updating '$(DIST_PATH)/$(CHANGELOG)' via '$(GIT_LOG)'." + @echo "Updating '$(DIST_PATH)/$(CHANGELOG)' via '$(GIT_LOG)'" @$(GIT_LOG) > $(DIST_PATH)/$(CHANGELOG) @@ -1225,7 +1225,7 @@ uninstall-libs: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(MK_LIBS_INST) else - @echo "Uninstalling libraries $(notdir $(MK_LIBS_INST)) from $(dir $(firstword $(MK_LIBS_INST)))." + @echo "Uninstalling libraries $(notdir $(MK_LIBS_INST)) from $(dir $(firstword $(MK_LIBS_INST)))" @- $(RM_F) $(MK_LIBS_INST) endif @@ -1233,7 +1233,7 @@ uninstall-lib-symlinks: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(MK_LIBS_SYML) else - @echo "Uninstalling symlinks $(notdir $(MK_LIBS_SYML)) from $(dir $(firstword $(MK_LIBS_SYML)))." + @echo "Uninstalling symlinks $(notdir $(MK_LIBS_SYML)) from $(dir $(firstword $(MK_LIBS_SYML)))" @- $(RM_F) $(MK_LIBS_SYML) endif @@ -1241,7 +1241,7 @@ uninstall-headers: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(MK_INCL_DIR_INST) else - @echo "Uninstalling directory '$(notdir $(MK_INCL_DIR_INST))' from $(dir $(MK_INCL_DIR_INST))." + @echo "Uninstalling directory '$(notdir $(MK_INCL_DIR_INST))' from $(dir $(MK_INCL_DIR_INST))" @- $(RM_RF) $(MK_INCL_DIR_INST) endif @@ -1249,7 +1249,7 @@ uninstall-share: check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(MK_SHARE_DIR_INST) else - @echo "Uninstalling directory '$(notdir $(MK_SHARE_DIR_INST))' from $(dir $(MK_SHARE_DIR_INST))." 
+ @echo "Uninstalling directory '$(notdir $(MK_SHARE_DIR_INST))' from $(dir $(MK_SHARE_DIR_INST))" @- $(RM_RF) $(MK_SHARE_DIR_INST) endif @@ -1265,7 +1265,7 @@ $(UNINSTALL_OLD_LIBS) $(UNINSTALL_OLD_SYML) $(UNINSTALL_OLD_HEADERS): check-env ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $@ else - @echo "Uninstalling $(@F) from $(@D)/." + @echo "Uninstalling $(@F) from $(@D)/" @- $(RM_F) $@ endif diff --git a/README.md b/README.md index 13acd96ec..60ac20b2c 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ Contents -------- * **[Introduction](#introduction)** +* **[Education and Learning](#education-and-learning)** * **[What's New](#whats-new)** * **[What People Are Saying About BLIS](#what-people-are-saying-about-blis)** * **[Key Features](#key-features)** @@ -76,9 +77,38 @@ and [collaborators](http://shpc.ices.utexas.edu/collaborators.html), [publications](http://shpc.ices.utexas.edu/publications.html), and [other educational projects](http://www.ulaff.net/) (such as MOOCs). +Education and Learning +---------------------- + +Want to understand what's under the hood? +Many of the same concepts and principles employed when developing BLIS are +introduced and taught in a basic pedagogical setting as part of +[LAFF-On Programming for High Performance (LAFF-On-PfHP)](http://www.ulaff.net/), +one of several massive open online courses (MOOCs) in the +[Linear Algebra: Foundations to Frontiers](http://www.ulaff.net/) series, +all of which are available for free via the [edX platform](http://www.edx.org/). + What's New ---------- + * **Small/skinny matrix support for dgemm now available!** Thanks to +contributions made possible by our partnership with AMD, we have dramatically +accelerated `gemm` for double-precision real matrix problems where one or two +dimensions is exceedingly small. A natural byproduct of this optimization is +that the traditional case of small _m = n = k_ (i.e. square matrices) is also +accelerated, even though it was not targeted specifically. And though only +`dgemm` was optimized for now, support for other datatypes, other operations, +and/or multithreading may be implemented in the future. We've also added a new +[PerformanceSmall](docs/PerformanceSmall.md) document to showcase the +improvement in performance when some matrix dimensions are small. + + * **Performance comparisons now available!** We recently measured the +performance of various level-3 operations on a variety of hardware architectures, +as implemented within BLIS and other BLAS libraries for all four of the standard +floating-point datatypes. The results speak for themselves! Check out our +extensive performance graphs and background info in our new +[Performance](docs/Performance.md) document. + * **BLIS is now in Debian Unstable!** Thanks to Debian developer-maintainers [M. Zhou](https://github.com/cdluminate) and [Nico Schlömer](https://github.com/nschloe) for sponsoring our package in Debian. @@ -87,7 +117,7 @@ the second-most popular Linux distribution (behind Ubuntu, which Debian packages feed into). The Debian tracker page may be found [here](https://tracker.debian.org/pkg/blis). - * **BLIS now supports mixed-datatype gemm.** The `gemm` operation may now be + * **BLIS now supports mixed-datatype gemm!** The `gemm` operation may now be executed on operands of mixed domains and/or mixed precisions. Any combination of storage datatype for A, B, and C is now supported, along with a separate computation precision that can differ from the storage precision of A and B. 
@@ -313,10 +343,20 @@ table of supported microarchitectures. * **[Multithreading](docs/Multithreading.md).** This document describes how to use the multithreading features of BLIS. - * **[Mixed-Datatype](docs/MixedDatatype.md).** This document provides an + * **[Mixed-Datatypes](docs/MixedDatatypes.md).** This document provides an overview of BLIS's mixed-datatype functionality and provides a brief example of how to take advantage of this new code. + * **[Performance](docs/Performance.md).** This document reports empirically +measured performance of a representative set of level-3 operations on a variety +of hardware architectures, as implemented within BLIS and other BLAS libraries +for all four of the standard floating-point datatypes. + + * **[PerformanceSmall](docs/PerformanceSmall.md).** This document reports +empirically measured performance of `gemm` on select hardware architectures +within BLIS and other BLAS libraries when performing matrix problems where one +or two dimensions are exceedingly small. + * **[Release Notes](docs/ReleaseNotes.md).** This document tracks a summary of changes included with each new version of BLIS, along with contributor credits for key features. diff --git a/windows/Makefile b/attic/windows/Makefile similarity index 100% rename from windows/Makefile rename to attic/windows/Makefile diff --git a/windows/build/bli_kernel.h b/attic/windows/build/bli_kernel.h similarity index 100% rename from windows/build/bli_kernel.h rename to attic/windows/build/bli_kernel.h diff --git a/windows/build/config.mk.in b/attic/windows/build/config.mk.in similarity index 100% rename from windows/build/config.mk.in rename to attic/windows/build/config.mk.in diff --git a/windows/build/defs.mk b/attic/windows/build/defs.mk similarity index 100% rename from windows/build/defs.mk rename to attic/windows/build/defs.mk diff --git a/windows/build/gather-src-for-windows.py b/attic/windows/build/gather-src-for-windows.py similarity index 100% rename from windows/build/gather-src-for-windows.py rename to attic/windows/build/gather-src-for-windows.py diff --git a/windows/build/gen-check-rev-file.py b/attic/windows/build/gen-check-rev-file.py similarity index 100% rename from windows/build/gen-check-rev-file.py rename to attic/windows/build/gen-check-rev-file.py diff --git a/windows/build/gen-config-file.py b/attic/windows/build/gen-config-file.py similarity index 100% rename from windows/build/gen-config-file.py rename to attic/windows/build/gen-config-file.py diff --git a/windows/build/ignore_list b/attic/windows/build/ignore_list similarity index 100% rename from windows/build/ignore_list rename to attic/windows/build/ignore_list diff --git a/windows/build/ignore_list.windows b/attic/windows/build/ignore_list.windows similarity index 100% rename from windows/build/ignore_list.windows rename to attic/windows/build/ignore_list.windows diff --git a/windows/build/leaf_list b/attic/windows/build/leaf_list similarity index 100% rename from windows/build/leaf_list rename to attic/windows/build/leaf_list diff --git a/windows/build/nmake-help.cmd b/attic/windows/build/nmake-help.cmd similarity index 100% rename from windows/build/nmake-help.cmd rename to attic/windows/build/nmake-help.cmd diff --git a/windows/configure.cmd b/attic/windows/configure.cmd similarity index 100% rename from windows/configure.cmd rename to attic/windows/configure.cmd diff --git a/windows/gendll.cmd b/attic/windows/gendll.cmd similarity index 100% rename from windows/gendll.cmd rename to attic/windows/gendll.cmd diff
--git a/windows/linkargs.txt b/attic/windows/linkargs.txt similarity index 100% rename from windows/linkargs.txt rename to attic/windows/linkargs.txt diff --git a/windows/linkargs64.txt b/attic/windows/linkargs64.txt similarity index 100% rename from windows/linkargs64.txt rename to attic/windows/linkargs64.txt diff --git a/windows/revision b/attic/windows/revision similarity index 100% rename from windows/revision rename to attic/windows/revision diff --git a/windows/vc110.pdb b/attic/windows/vc110.pdb similarity index 100% rename from windows/vc110.pdb rename to attic/windows/vc110.pdb diff --git a/blastest/Makefile b/blastest/Makefile index 4659fcfee..b4b40a714 100644 --- a/blastest/Makefile +++ b/blastest/Makefile @@ -136,7 +136,7 @@ CFLAGS += -Wno-maybe-uninitialized -Wno-parentheses -Wfatal-errors \ -I$(INC_PATH) -DHAVE_BLIS_H # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Override the location of the check-blastest.sh script. #BLASTEST_CHECK := ./check-blastest.sh diff --git a/build/bli_config.h.in b/build/bli_config.h.in index f940090ad..1bb2ef28b 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -135,6 +135,12 @@ #endif #endif +#if @enable_sup_handling@ +#define BLIS_ENABLE_SUP_HANDLING +#else +#define BLIS_DISABLE_SUP_HANDLING +#endif + #if @enable_memkind@ #define BLIS_ENABLE_MEMKIND #else @@ -159,4 +165,5 @@ #define BLIS_DISABLE_SHARED #endif + #endif diff --git a/build/config.mk.in b/build/config.mk.in index c68601c67..0516ec97b 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -115,13 +115,33 @@ THREADING_MODEL := @threading_model@ # Whether the compiler supports "#pragma omp simd" via the -fopenmp-simd option. PRAGMA_OMP_SIMD := @pragma_omp_simd@ -# The install libdir, includedir, and shareddir values from configure tell -# us where to install the libraries, header files, and public makefile -# fragments, respectively. Notice that we support the use of DESTDIR so that -# advanced users may install to a temporary location. -INSTALL_LIBDIR := $(DESTDIR)@install_libdir@ -INSTALL_INCDIR := $(DESTDIR)@install_incdir@ -INSTALL_SHAREDIR := $(DESTDIR)@install_sharedir@ +# The installation prefix, exec_prefix, libdir, includedir, and sharedir +# values from configure tell us where to install the libraries, header files, +# and public makefile fragments. We must first assign each substituted +# @anchor@ to its own variable. Why? Because the substitutions may contain +# unevaluated variable expressions. For example, '@libdir@' may be replaced +# with '${exec_prefix}/lib'. By assigning the anchors to variables first, and +# then assigning them to their final INSTALL_* variables, we allow prefix and +# exec_prefix to be used in the definitions of exec_prefix, libdir, +# includedir, and sharedir. +prefix := @prefix@ +exec_prefix := @exec_prefix@ +libdir := @libdir@ +includedir := @includedir@ +sharedir := @sharedir@ + +# Notice that we support the use of DESTDIR so that advanced users may install +# to a temporary location. +INSTALL_LIBDIR := $(DESTDIR)$(libdir) +INSTALL_INCDIR := $(DESTDIR)$(includedir) +INSTALL_SHAREDIR := $(DESTDIR)$(sharedir) + +#$(info prefix = $(prefix) ) +#$(info exec_prefix = $(exec_prefix) ) +#$(info libdir = $(libdir) ) +#$(info includedir = $(includedir) ) +#$(info sharedir = $(sharedir) ) +#$(error .) # Whether to output verbose command-line feedback as the Makefile is # processed.
@@ -135,11 +155,15 @@ BUILDING_OOT := @configured_oot@ ARG_MAX_HACK := @enable_arg_max_hack@ # Whether to build the static and shared libraries. -# Note the "MK_" prefix, which helps differentiate these variables from +# NOTE: The "MK_" prefix helps differentiate these variables from # their corresponding cpp macros that use the BLIS_ prefix. MK_ENABLE_STATIC := @enable_static@ MK_ENABLE_SHARED := @enable_shared@ +# Whether to export all symbols within the shared library, even those symbols +# that are considered to be for internal use only. +EXPORT_SHARED := @export_shared@ + # Whether to enable either the BLAS or CBLAS compatibility layers. MK_ENABLE_BLAS := @enable_blas@ MK_ENABLE_CBLAS := @enable_cblas@ diff --git a/build/detect/config/config_detect.c b/build/detect/config/config_detect.c index 7d0e95522..6d59d6625 100644 --- a/build/detect/config/config_detect.c +++ b/build/detect/config/config_detect.c @@ -33,6 +33,7 @@ */ +#define BLIS_EXPORT_BLIS #include "bli_system.h" #include "bli_type_defs.h" #include "bli_arch.h" diff --git a/build/flatten-headers.py b/build/flatten-headers.py index 9599278d2..563725a7e 100755 --- a/build/flatten-headers.py +++ b/build/flatten-headers.py @@ -244,10 +244,24 @@ def flatten_header( inputfile, header_dirpaths, cursp ): # directive. header_path = get_header_path( header, header_dirpaths ) - # If the header was found, we recurse. Otherwise, we output - # the #include directive with a comment indicating that it - # was skipped. - if header_path: + # First, check if the header is our root header (and if so, ignore it). + # Otherwise, if the header was found, we recurse. Otherwise, we output + # the #include directive with a comment indicating that it was skipped. + if header == root_inputfile: + + markl = result.group(1) + markr = result.group(3) + + echov2( "%sthis is the root header '%s'; commenting out / skipping." \ % ( cursp, header ) ) + + # If the header found is our root header, then we cannot + # recurse into it lest we enter an infinite loop. Output the + # line but make sure it's commented out entirely. + ostring += "%s #include %c%s%c %c" \ % ( skipstr, markl, header, markr, '\n' ) + + elif header_path: echov2( "%slocated file '%s'; recursing." \ % ( cursp, header_path ) ) @@ -327,6 +341,7 @@ strip_comments = None recursive_flag = None verbose_flag = None regex = None +root_inputfile = None def main(): @@ -336,6 +351,7 @@ def main(): global recursive_flag global verbose_flag global regex + global root_inputfile # Obtain the script name. path, script_name = os.path.split(sys.argv[0]) @@ -397,6 +413,10 @@ def main(): temp_dir = args[2] dir_list = args[3] + # Save the filename (basename) part of the input file (or root file) into a + # global variable that we can access later from within flatten_header(). + root_inputfile = os.path.basename( inputfile ) + # Separate the directories into distinct strings. dir_list = dir_list.split() diff --git a/build/gen-make-frags/gen-make-frag.sh b/build/gen-make-frags/gen-make-frag.sh index 4d8cb408d..e411fa8d9 100755 --- a/build/gen-make-frags/gen-make-frag.sh +++ b/build/gen-make-frags/gen-make-frag.sh @@ -417,8 +417,9 @@ main() # The arguments to this function. They'll get assigned meaningful # values after getopts.
- mkfile_frag_tmpl_path="" root_dir="" + frag_dir="" + mkfile_frag_tmpl_path="" suffix_file="" ignore_file="" diff --git a/build/libblis-symbols.def b/build/libblis-symbols.def index f4db5f98f..e1bfce807 100644 --- a/build/libblis-symbols.def +++ b/build/libblis-symbols.def @@ -183,13 +183,11 @@ bli_cgemm4mb bli_cgemm4mb_ker_var2 bli_cgemm4mh bli_cgemm_ex -bli_cgemm_haswell_asm_3x8 -bli_cgemm_haswell_asm_8x3 bli_cgemm_ker_var2 bli_cgemm_md_c2r_ref -bli_cgemm_ukernel bli_cgemmtrsm_l_ukernel bli_cgemmtrsm_u_ukernel +bli_cgemm_ukernel bli_cgemv bli_cgemv_ex bli_cgemv_unb_var1 @@ -285,12 +283,6 @@ bli_chemv_unf_var3a bli_cher bli_cher2 bli_cher2_ex -bli_cher2_unb_var1 -bli_cher2_unb_var2 -bli_cher2_unb_var3 -bli_cher2_unb_var4 -bli_cher2_unf_var1 -bli_cher2_unf_var4 bli_cher2k bli_cher2k1m bli_cher2k3m1 @@ -298,9 +290,13 @@ bli_cher2k3mh bli_cher2k4m1 bli_cher2k4mh bli_cher2k_ex +bli_cher2_unb_var1 +bli_cher2_unb_var2 +bli_cher2_unb_var3 +bli_cher2_unb_var4 +bli_cher2_unf_var1 +bli_cher2_unf_var4 bli_cher_ex -bli_cher_unb_var1 -bli_cher_unb_var2 bli_cherk bli_cherk1m bli_cherk3m1 @@ -310,6 +306,8 @@ bli_cherk4mh bli_cherk_ex bli_cherk_l_ker_var2 bli_cherk_u_ker_var2 +bli_cher_unb_var1 +bli_cher_unb_var2 bli_cinvertd bli_cinvertd_ex bli_cinvertsc @@ -354,8 +352,8 @@ bli_cntl_copy bli_cntl_create_node bli_cntl_free bli_cntl_free_node -bli_cntl_free_w_thrinfo bli_cntl_free_wo_thrinfo +bli_cntl_free_w_thrinfo bli_cntl_mark_family bli_cntx_1m_stage bli_cntx_3m1_stage @@ -544,8 +542,8 @@ bli_ctrsm1m bli_ctrsm3m1 bli_ctrsm4m1 bli_ctrsm_ex -bli_ctrsm_l_ukernel bli_ctrsm_ll_ker_var2 +bli_ctrsm_l_ukernel bli_ctrsm_lu_ker_var2 bli_ctrsm_rl_ker_var2 bli_ctrsm_ru_ker_var2 @@ -591,7 +589,6 @@ bli_daddv bli_daddv_ex bli_damaxv bli_damaxv_ex -bli_damaxv_zen_int bli_dasumv bli_dasumv_ex bli_dasumv_unb_var1 @@ -603,14 +600,11 @@ bli_daxpyd bli_daxpyd_ex bli_daxpyf bli_daxpyf_ex -bli_daxpyf_zen_int_8 bli_daxpym bli_daxpym_ex bli_daxpym_unb_var1 bli_daxpyv bli_daxpyv_ex -bli_daxpyv_zen_int -bli_daxpyv_zen_int10 bli_dccastm bli_dccastnzm bli_dccastv @@ -640,16 +634,12 @@ bli_ddotaxpyv bli_ddotaxpyv_ex bli_ddotv bli_ddotv_ex -bli_ddotv_zen_int -bli_ddotv_zen_int10 bli_ddotxaxpyf bli_ddotxaxpyf_ex bli_ddotxf bli_ddotxf_ex -bli_ddotxf_zen_int_8 bli_ddotxv bli_ddotxv_ex -bli_ddotxv_zen_int bli_ddpackm_blk_var1_md bli_ddpackm_cxk_1e_md bli_ddpackm_cxk_1r_md @@ -673,14 +663,10 @@ bli_dgemm4mb bli_dgemm4mb_ker_var2 bli_dgemm4mh bli_dgemm_ex -bli_dgemm_haswell_asm_6x8 -bli_dgemm_haswell_asm_8x6 bli_dgemm_ker_var2 -bli_dgemm_ukernel -bli_dgemmtrsm_l_haswell_asm_6x8 bli_dgemmtrsm_l_ukernel -bli_dgemmtrsm_u_haswell_asm_6x8 bli_dgemmtrsm_u_ukernel +bli_dgemm_ukernel bli_dgemv bli_dgemv_ex bli_dgemv_unb_var1 @@ -713,12 +699,6 @@ bli_dhemv_unf_var3a bli_dher bli_dher2 bli_dher2_ex -bli_dher2_unb_var1 -bli_dher2_unb_var2 -bli_dher2_unb_var3 -bli_dher2_unb_var4 -bli_dher2_unf_var1 -bli_dher2_unf_var4 bli_dher2k bli_dher2k1m bli_dher2k3m1 @@ -726,9 +706,13 @@ bli_dher2k3mh bli_dher2k4m1 bli_dher2k4mh bli_dher2k_ex +bli_dher2_unb_var1 +bli_dher2_unb_var2 +bli_dher2_unb_var3 +bli_dher2_unb_var4 +bli_dher2_unf_var1 +bli_dher2_unf_var4 bli_dher_ex -bli_dher_unb_var1 -bli_dher_unb_var2 bli_dherk bli_dherk1m bli_dherk3m1 @@ -738,6 +722,8 @@ bli_dherk4mh bli_dherk_ex bli_dherk_l_ker_var2 bli_dherk_u_ker_var2 +bli_dher_unb_var1 +bli_dher_unb_var2 bli_dinvertd bli_dinvertd_ex bli_dinvertsc @@ -746,11 +732,6 @@ bli_dinvertv_ex bli_divsc bli_divsc_check bli_divsc_qfp -bli_dlamc1 -bli_dlamc2 -bli_dlamc3 -bli_dlamc4 -bli_dlamc5 bli_dlamch 
bli_dmachval bli_dmkherm @@ -838,8 +819,6 @@ bli_dscalm_ex bli_dscalm_unb_var1 bli_dscalv bli_dscalv_ex -bli_dscalv_zen_int -bli_dscalv_zen_int10 bli_dscastm bli_dscastnzm bli_dscastv @@ -906,11 +885,6 @@ bli_dsyrk3mh bli_dsyrk4m1 bli_dsyrk4mh bli_dsyrk_ex -bli_dt_size -bli_dt_size_check -bli_dt_string -bli_dt_string_check -bli_dt_union_check bli_dtrmm bli_dtrmm1m bli_dtrmm3 @@ -938,8 +912,8 @@ bli_dtrsm1m bli_dtrsm3m1 bli_dtrsm4m1 bli_dtrsm_ex -bli_dtrsm_l_ukernel bli_dtrsm_ll_ker_var2 +bli_dtrsm_l_ukernel bli_dtrsm_lu_ker_var2 bli_dtrsm_rl_ker_var2 bli_dtrsm_ru_ker_var2 @@ -950,6 +924,11 @@ bli_dtrsv_unb_var1 bli_dtrsv_unb_var2 bli_dtrsv_unf_var1 bli_dtrsv_unf_var2 +bli_dt_size +bli_dt_size_check +bli_dt_string +bli_dt_string_check +bli_dt_union_check bli_dunpackm_blk_var1 bli_dunpackm_cxk bli_dunpackm_unb_var1 @@ -1018,6 +997,7 @@ bli_gemm_basic_check bli_gemm_blk_var1 bli_gemm_blk_var2 bli_gemm_blk_var3 +bli_gemmbp_cntl_create bli_gemm_check bli_gemm_cntl_create bli_gemm_cntl_create_node @@ -1028,6 +1008,8 @@ bli_gemm_determine_kc_f bli_gemm_direct bli_gemm_ex bli_gemm_front +bli_gemmind +bli_gemmind_get_avail bli_gemm_int bli_gemm_ker_var2 bli_gemm_ker_var2_md @@ -1040,20 +1022,17 @@ bli_gemm_md_rcc bli_gemm_md_rcr bli_gemm_md_rrc bli_gemm_md_rrr +bli_gemmnat bli_gemm_packa bli_gemm_packb bli_gemm_prune_unref_mparts_k bli_gemm_prune_unref_mparts_m bli_gemm_prune_unref_mparts_n +bli_gemmtrsm_l_ukernel_qfp +bli_gemmtrsm_ukernel +bli_gemmtrsm_u_ukernel_qfp bli_gemm_ukernel bli_gemm_ukernel_qfp -bli_gemmbp_cntl_create -bli_gemmind -bli_gemmind_get_avail -bli_gemmnat -bli_gemmtrsm_l_ukernel_qfp -bli_gemmtrsm_u_ukernel_qfp -bli_gemmtrsm_ukernel bli_gemv bli_gemv_check bli_gemv_ex @@ -1120,30 +1099,18 @@ bli_hemv_unb_var3_qfp bli_hemv_unb_var4 bli_hemv_unb_var4_qfp bli_hemv_unf_var1 -bli_hemv_unf_var1_qfp bli_hemv_unf_var1a bli_hemv_unf_var1a_qfp +bli_hemv_unf_var1_qfp bli_hemv_unf_var3 -bli_hemv_unf_var3_qfp bli_hemv_unf_var3a bli_hemv_unf_var3a_qfp +bli_hemv_unf_var3_qfp bli_her bli_her2 bli_her2_check bli_her2_ex bli_her2_ex_qfp -bli_her2_unb_var1 -bli_her2_unb_var1_qfp -bli_her2_unb_var2 -bli_her2_unb_var2_qfp -bli_her2_unb_var3 -bli_her2_unb_var3_qfp -bli_her2_unb_var4 -bli_her2_unb_var4_qfp -bli_her2_unf_var1 -bli_her2_unf_var1_qfp -bli_her2_unf_var4 -bli_her2_unf_var4_qfp bli_her2k bli_her2k1m bli_her2k3m1 @@ -1157,13 +1124,21 @@ bli_her2k_front bli_her2kind bli_her2kind_get_avail bli_her2knat +bli_her2_unb_var1 +bli_her2_unb_var1_qfp +bli_her2_unb_var2 +bli_her2_unb_var2_qfp +bli_her2_unb_var3 +bli_her2_unb_var3_qfp +bli_her2_unb_var4 +bli_her2_unb_var4_qfp +bli_her2_unf_var1 +bli_her2_unf_var1_qfp +bli_her2_unf_var4 +bli_her2_unf_var4_qfp bli_her_check bli_her_ex bli_her_ex_qfp -bli_her_unb_var1 -bli_her_unb_var1_qfp -bli_her_unb_var2 -bli_her_unb_var2_qfp bli_herk bli_herk1m bli_herk3m1 @@ -1178,15 +1153,19 @@ bli_herk_determine_kc_f bli_herk_direct bli_herk_ex bli_herk_front +bli_herkind +bli_herkind_get_avail bli_herk_l_ker_var2 +bli_herknat bli_herk_prune_unref_mparts_k bli_herk_prune_unref_mparts_m bli_herk_prune_unref_mparts_n bli_herk_u_ker_var2 bli_herk_x_ker_var2 -bli_herkind -bli_herkind_get_avail -bli_herknat +bli_her_unb_var1 +bli_her_unb_var1_qfp +bli_her_unb_var2 +bli_her_unb_var2_qfp bli_ifprintm bli_ifprintv bli_igetsc @@ -1217,9 +1196,9 @@ bli_info_get_enable_sba_pools bli_info_get_enable_stay_auto_init bli_info_get_enable_threading bli_info_get_gemm_impl_string -bli_info_get_gemm_ukr_impl_string bli_info_get_gemmtrsm_l_ukr_impl_string 
bli_info_get_gemmtrsm_u_ukr_impl_string +bli_info_get_gemm_ukr_impl_string bli_info_get_heap_addr_align_size bli_info_get_heap_stride_align_size bli_info_get_hemm_impl_string @@ -1278,12 +1257,12 @@ bli_l1d_xy_check bli_l1m_ax_check bli_l1m_axy_check bli_l1m_xy_check -bli_l1v_ax_check bli_l1v_axby_check +bli_l1v_ax_check bli_l1v_axy_check bli_l1v_dot_check -bli_l1v_x_check bli_l1v_xby_check +bli_l1v_x_check bli_l1v_xi_check bli_l1v_xy_check bli_l3_basic_check @@ -1452,12 +1431,10 @@ bli_pool_init bli_pool_print bli_pool_reinit bli_pool_shrink -bli_pow_di -bli_pow_ri bli_prime_factorization -bli_print_msg bli_printm bli_printm_ex +bli_print_msg bli_printv bli_printv_ex bli_projm @@ -1510,7 +1487,6 @@ bli_saddv bli_saddv_ex bli_samaxv bli_samaxv_ex -bli_samaxv_zen_int bli_sasumv bli_sasumv_ex bli_sasumv_unb_var1 @@ -1522,14 +1498,11 @@ bli_saxpyd bli_saxpyd_ex bli_saxpyf bli_saxpyf_ex -bli_saxpyf_zen_int_8 bli_saxpym bli_saxpym_ex bli_saxpym_unb_var1 bli_saxpyv bli_saxpyv_ex -bli_saxpyv_zen_int -bli_saxpyv_zen_int10 bli_sba_acquire bli_sba_checkin_array bli_sba_checkout_array @@ -1591,16 +1564,12 @@ bli_sdotaxpyv bli_sdotaxpyv_ex bli_sdotv bli_sdotv_ex -bli_sdotv_zen_int -bli_sdotv_zen_int10 bli_sdotxaxpyf bli_sdotxaxpyf_ex bli_sdotxf bli_sdotxf_ex -bli_sdotxf_zen_int_8 bli_sdotxv bli_sdotxv_ex -bli_sdotxv_zen_int bli_sdpackm_blk_var1_md bli_sdpackm_cxk_1e_md bli_sdpackm_cxk_1r_md @@ -1643,14 +1612,10 @@ bli_sgemm4mb bli_sgemm4mb_ker_var2 bli_sgemm4mh bli_sgemm_ex -bli_sgemm_haswell_asm_16x6 -bli_sgemm_haswell_asm_6x16 bli_sgemm_ker_var2 -bli_sgemm_ukernel -bli_sgemmtrsm_l_haswell_asm_6x16 bli_sgemmtrsm_l_ukernel -bli_sgemmtrsm_u_haswell_asm_6x16 bli_sgemmtrsm_u_ukernel +bli_sgemm_ukernel bli_sgemv bli_sgemv_ex bli_sgemv_unb_var1 @@ -1683,12 +1648,6 @@ bli_shemv_unf_var3a bli_sher bli_sher2 bli_sher2_ex -bli_sher2_unb_var1 -bli_sher2_unb_var2 -bli_sher2_unb_var3 -bli_sher2_unb_var4 -bli_sher2_unf_var1 -bli_sher2_unf_var4 bli_sher2k bli_sher2k1m bli_sher2k3m1 @@ -1696,9 +1655,13 @@ bli_sher2k3mh bli_sher2k4m1 bli_sher2k4mh bli_sher2k_ex +bli_sher2_unb_var1 +bli_sher2_unb_var2 +bli_sher2_unb_var3 +bli_sher2_unb_var4 +bli_sher2_unf_var1 +bli_sher2_unf_var4 bli_sher_ex -bli_sher_unb_var1 -bli_sher_unb_var2 bli_sherk bli_sherk1m bli_sherk3m1 @@ -1708,6 +1671,8 @@ bli_sherk4mh bli_sherk_ex bli_sherk_l_ker_var2 bli_sherk_u_ker_var2 +bli_sher_unb_var1 +bli_sher_unb_var2 bli_shiftd bli_shiftd_check bli_shiftd_ex @@ -1717,11 +1682,6 @@ bli_sinvertd_ex bli_sinvertsc bli_sinvertv bli_sinvertv_ex -bli_slamc1 -bli_slamc2 -bli_slamc3 -bli_slamc4 -bli_slamc5 bli_slamch bli_sleep bli_smachval @@ -1793,8 +1753,6 @@ bli_sscalm_ex bli_sscalm_unb_var1 bli_sscalv bli_sscalv_ex -bli_sscalv_zen_int -bli_sscalv_zen_int10 bli_sscastm bli_sscastnzm bli_sscastv @@ -1889,8 +1847,8 @@ bli_strsm1m bli_strsm3m1 bli_strsm4m1 bli_strsm_ex -bli_strsm_l_ukernel bli_strsm_ll_ker_var2 +bli_strsm_l_ukernel bli_strsm_lu_ker_var2 bli_strsm_rl_ker_var2 bli_strsm_ru_ker_var2 @@ -2062,17 +2020,17 @@ bli_trmm_determine_kc_f bli_trmm_direct bli_trmm_ex bli_trmm_front +bli_trmmind +bli_trmmind_get_avail bli_trmm_ll_ker_var2 bli_trmm_lu_ker_var2 +bli_trmmnat bli_trmm_prune_unref_mparts_k bli_trmm_prune_unref_mparts_m bli_trmm_prune_unref_mparts_n bli_trmm_rl_ker_var2 bli_trmm_ru_ker_var2 bli_trmm_xx_ker_var2 -bli_trmmind -bli_trmmind_get_avail -bli_trmmnat bli_trmv bli_trmv_check bli_trmv_ex @@ -2102,11 +2060,14 @@ bli_trsm_determine_kc_f bli_trsm_direct bli_trsm_ex bli_trsm_front +bli_trsmind +bli_trsmind_get_avail bli_trsm_int 
bli_trsm_l_cntl_create -bli_trsm_l_ukernel_qfp bli_trsm_ll_ker_var2 +bli_trsm_l_ukernel_qfp bli_trsm_lu_ker_var2 +bli_trsmnat bli_trsm_packa bli_trsm_packb bli_trsm_prune_unref_mparts_k @@ -2115,12 +2076,9 @@ bli_trsm_prune_unref_mparts_n bli_trsm_r_cntl_create bli_trsm_rl_ker_var2 bli_trsm_ru_ker_var2 -bli_trsm_u_ukernel_qfp bli_trsm_ukernel +bli_trsm_u_ukernel_qfp bli_trsm_xx_ker_var2 -bli_trsmind -bli_trsmind_get_avail -bli_trsmnat bli_trsv bli_trsv_check bli_trsv_ex @@ -2245,13 +2203,11 @@ bli_zgemm4mb bli_zgemm4mb_ker_var2 bli_zgemm4mh bli_zgemm_ex -bli_zgemm_haswell_asm_3x4 -bli_zgemm_haswell_asm_4x3 bli_zgemm_ker_var2 bli_zgemm_md_c2r_ref -bli_zgemm_ukernel bli_zgemmtrsm_l_ukernel bli_zgemmtrsm_u_ukernel +bli_zgemm_ukernel bli_zgemv bli_zgemv_ex bli_zgemv_unb_var1 @@ -2284,12 +2240,6 @@ bli_zhemv_unf_var3a bli_zher bli_zher2 bli_zher2_ex -bli_zher2_unb_var1 -bli_zher2_unb_var2 -bli_zher2_unb_var3 -bli_zher2_unb_var4 -bli_zher2_unf_var1 -bli_zher2_unf_var4 bli_zher2k bli_zher2k1m bli_zher2k3m1 @@ -2297,9 +2247,13 @@ bli_zher2k3mh bli_zher2k4m1 bli_zher2k4mh bli_zher2k_ex +bli_zher2_unb_var1 +bli_zher2_unb_var2 +bli_zher2_unb_var3 +bli_zher2_unb_var4 +bli_zher2_unf_var1 +bli_zher2_unf_var4 bli_zher_ex -bli_zher_unb_var1 -bli_zher_unb_var2 bli_zherk bli_zherk1m bli_zherk3m1 @@ -2309,6 +2263,8 @@ bli_zherk4mh bli_zherk_ex bli_zherk_l_ker_var2 bli_zherk_u_ker_var2 +bli_zher_unb_var1 +bli_zher_unb_var2 bli_zinvertd bli_zinvertd_ex bli_zinvertsc @@ -2492,8 +2448,8 @@ bli_ztrsm1m bli_ztrsm3m1 bli_ztrsm4m1 bli_ztrsm_ex -bli_ztrsm_l_ukernel bli_ztrsm_ll_ker_var2 +bli_ztrsm_l_ukernel bli_ztrsm_lu_ker_var2 bli_ztrsm_rl_ker_var2 bli_ztrsm_ru_ker_var2 @@ -2528,19 +2484,6 @@ bli_zzpackm_struc_cxk_md bli_zzxpbym_md bli_zzxpbym_md_ex bli_zzxpbym_md_unb_var1 -bla_c_abs -bla_c_div -bla_d_abs -bla_d_cnjg -bla_d_imag -bla_d_sign -bla_f__cabs -bla_r_abs -bla_r_cnjg -bla_r_imag -bla_r_sign -bla_z_abs -bla_z_div sasum_ sasumsub_ saxpy_ @@ -2567,14 +2510,14 @@ srotmg_ ssbmv_ sscal_ sspmv_ -sspr2_ sspr_ +sspr2_ sswap_ ssymm_ ssymv_ +ssyr_ ssyr2_ ssyr2k_ -ssyr_ ssyrk_ stbmv_ stbsv_ @@ -2606,14 +2549,14 @@ dscal_ dsdot_ dsdotsub_ dspmv_ -dspr2_ dspr_ +dspr2_ dswap_ dsymm_ dsymv_ +dsyr_ dsyr2_ dsyr2k_ -dsyr_ dsyrk_ dtbmv_ dtbsv_ @@ -2641,13 +2584,13 @@ cgeru_ chbmv_ chemm_ chemv_ +cher_ cher2_ cher2k_ -cher_ cherk_ chpmv_ -chpr2_ chpr_ +chpr2_ crotg_ cscal_ csrot_ @@ -2680,13 +2623,13 @@ zgeru_ zhbmv_ zhemm_ zhemv_ +zher_ zher2_ zher2k_ -zher_ zherk_ zhpmv_ -zhpr2_ zhpr_ +zhpr2_ zrotg_ zscal_ zswap_ diff --git a/build/templates/license.c b/build/templates/license.c index bc0abc656..6505a70ff 100644 --- a/build/templates/license.c +++ b/build/templates/license.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2019, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/templates/license.h b/build/templates/license.h index bc0abc656..6505a70ff 100644 --- a/build/templates/license.h +++ b/build/templates/license.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2019, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/templates/license.sh b/build/templates/license.sh index ad5965c79..b9c51e289 100644 --- a/build/templates/license.sh +++ b/build/templates/license.sh @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2019, The University of Texas at Austin +# Copyright (C) 2018, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/common.mk b/common.mk index 4c13fd918..c6e52bb8e 100644 --- a/common.mk +++ b/common.mk @@ -118,7 +118,8 @@ get-noopt-cxxflags-for = $(strip $(CFLAGS_PRESET) \ get-refinit-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ -DBLIS_CNAME=$(1) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-refkern-cflags-for = $(strip $(call load-var-for,CROPTFLAGS,$(1)) \ @@ -126,23 +127,27 @@ get-refkern-cflags-for = $(strip $(call load-var-for,CROPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ $(COMPSIMDFLAGS) \ -DBLIS_CNAME=$(1) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-config-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-frame-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ $(call load-var-for,CKVECFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) # When compiling sandboxes, we use flags similar to those of general framework @@ -153,19 +158,24 @@ get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ get-sandbox-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ $(CSBOXINCFLAGS) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) get-sandbox-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cxxflags-for,$(1)) \ $(CSBOXINCFLAGS) \ - $(BUILD_FLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ ) # Define a separate function that will return appropriate flags for use by # applications that want to use the same basic flags as those used when BLIS -# was compiled. (This is the same as get-frame-cflags-for(), except that it -# omits the BUILD_FLAGS, which are exclusively for use when BLIS is being -# compiled.) +# was compiled. (NOTE: This is the same as the $(get-frame-cflags-for ...) +# function, except that it omits two variables that contain flags exclusively +# for use when BLIS is being compiled/built: BUILD_CPPFLAGS, which contains a +# cpp macro that confirms that BLIS is being built; and BUILD_SYMFLAGS, which +# contains symbol export flags that are only needed when a shared library is +# being compiled/linked.) get-user-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ ) @@ -508,9 +518,9 @@ SOFLAGS := -shared ifeq ($(IS_WIN),yes) # Windows shared library link flags. 
ifeq ($(CC_VENDOR),clang) -SOFLAGS += -Wl,-def:build/libblis-symbols.def -Wl,-implib:$(BASE_LIB_PATH)/$(LIBBLIS).lib +SOFLAGS += -Wl,-implib:$(BASE_LIB_PATH)/$(LIBBLIS).lib else -SOFLAGS += -Wl,--export-all-symbols -Wl,--out-implib,$(BASE_LIB_PATH)/$(LIBBLIS).dll.a +SOFLAGS += -Wl,--out-implib,$(BASE_LIB_PATH)/$(LIBBLIS).dll.a endif else # Linux shared library link flags. @@ -532,6 +542,11 @@ ifeq ($(IS_WIN),no) LDFLAGS += -Wl,-rpath,$(BASE_LIB_PATH) endif endif +# On windows, use the shared library even if static is created. +ifeq ($(IS_WIN),yes) +LIBBLIS_L := $(LIBBLIS_SO) +LIBBLIS_LINK := $(LIBBLIS_SO_PATH) +endif endif @@ -610,7 +625,7 @@ endif $(foreach c, $(CONFIG_LIST_FAM), $(eval $(call append-var-for,CWARNFLAGS,$(c)))) -# --- Shared library (position-independent code) flags --- +# --- Position-independent code flags (shared libraries only) --- # Emit position-independent code for dynamic linking. ifeq ($(IS_WIN),yes) @@ -622,6 +637,71 @@ CPICFLAGS := -fPIC endif $(foreach c, $(CONFIG_LIST_FAM), $(eval $(call append-var-for,CPICFLAGS,$(c)))) +# --- Symbol exporting flags (shared libraries only) --- + +# NOTE: These flags are only applied when building BLIS and not used by +# applications that import BLIS compilation flags via the +# $(get-user-cflags-for ...) function. + +# Determine default export behavior / visibility of symbols for gcc. +ifeq ($(CC_VENDOR),gcc) +ifeq ($(IS_WIN),yes) +ifeq ($(EXPORT_SHARED),all) +BUILD_SYMFLAGS := -Wl,--export-all-symbols, -Wl,--enable-auto-import +else # ifeq ($(EXPORT_SHARED),public) +BUILD_SYMFLAGS := -Wl,--exclude-all-symbols +endif +else # ifeq ($(IS_WIN),no) +ifeq ($(EXPORT_SHARED),all) +# Export all symbols by default. +BUILD_SYMFLAGS := -fvisibility=default +else # ifeq ($(EXPORT_SHARED),public) +# Hide all symbols by default and export only those that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := -fvisibility=hidden +endif +endif +endif + +# Determine default export behavior / visibility of symbols for icc. +# NOTE: The Windows branches have been omitted since we currently make no +# effort to support Windows builds via icc (only gcc/clang via AppVeyor). +ifeq ($(CC_VENDOR),icc) +ifeq ($(EXPORT_SHARED),all) +# Export all symbols by default. +BUILD_SYMFLAGS := -fvisibility=default +else # ifeq ($(EXPORT_SHARED),public) +# Hide all symbols by default and export only those that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := -fvisibility=hidden +endif +endif + +# Determine default export behavior / visibility of symbols for clang. +ifeq ($(CC_VENDOR),clang) +ifeq ($(IS_WIN),yes) +ifeq ($(EXPORT_SHARED),all) +# NOTE: clang on Windows does not appear to support exporting all symbols +# by default, and therefore we ignore the value of EXPORT_SHARED. +BUILD_SYMFLAGS := +else # ifeq ($(EXPORT_SHARED),public) +# NOTE: The default behavior of clang on Windows is to hide all symbols +# and only export functions and other declarations that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := +endif +else # ifeq ($(IS_WIN),no) +ifeq ($(EXPORT_SHARED),all) +# Export all symbols by default. +BUILD_SYMFLAGS := -fvisibility=default +else # ifeq ($(EXPORT_SHARED),public) +# Hide all symbols by default and export only those that have been annotated +# as needing to be exported. +BUILD_SYMFLAGS := -fvisibility=hidden +endif +endif +endif + # --- Language flags --- # Enable C99.
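To make the EXPORT_SHARED behavior above concrete, here is a minimal C sketch (an illustration only, not part of the patch; both function names are hypothetical) of how per-symbol annotations interact with the -fvisibility defaults that BUILD_SYMFLAGS selects under gcc/clang:

```c
/* Sketch of gcc/clang symbol visibility in a shared-library build.
   With --export-shared=public (BUILD_SYMFLAGS = -fvisibility=hidden), only
   symbols explicitly annotated with default visibility appear in the shared
   library's dynamic symbol table; with --export-shared=all
   (-fvisibility=default), both functions below would be exported. */

__attribute__((visibility("default")))
void bli_example_public_fn( void );   /* hypothetical; exported even under -fvisibility=hidden */

void bli_example_internal_fn( void ); /* hypothetical; hidden under -fvisibility=hidden */
```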
@@ -685,8 +765,18 @@ endif # --- #pragma omp simd flags (used for reference kernels only) --- ifeq ($(PRAGMA_OMP_SIMD),yes) +ifeq ($(CC_VENDOR),gcc) COMPSIMDFLAGS := -fopenmp-simd else +ifeq ($(CC_VENDOR),clang) +COMPSIMDFLAGS := -fopenmp-simd +else +ifeq ($(CC_VENDOR),icc) +COMPSIMDFLAGS := -qopenmp-simd +endif +endif +endif +else # ifeq ($(PRAGMA_OMP_SIMD),no) COMPSIMDFLAGS := endif @@ -960,7 +1050,7 @@ VERS_DEF := -DBLIS_VERSION_STRING=\"$(VERSION)\" # Define a C preprocessor flag that is *only* defined when BLIS is being # compiled. (In other words, an application that #includes blis.h will not # get this cpp macro.) -BUILD_FLAGS := -DBLIS_IS_BUILDING_LIBRARY +BUILD_CPPFLAGS := -DBLIS_IS_BUILDING_LIBRARY diff --git a/config/amd64/make_defs.mk b/config/amd64/make_defs.mk index afea69558..70c0b692b 100644 --- a/config/amd64/make_defs.mk +++ b/config/amd64/make_defs.mk @@ -57,7 +57,7 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 -fomit-frame-pointer +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. @@ -74,7 +74,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) CRVECFLAGS := $(CKVECFLAGS) +else +CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/bulldozer/make_defs.mk b/config/bulldozer/make_defs.mk index 15870c4cb..dec89a4c3 100644 --- a/config/bulldozer/make_defs.mk +++ b/config/bulldozer/make_defs.mk @@ -57,16 +57,16 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 -funroll-all-loops +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mfpmath=sse -mavx -mfma4 -march=bdver1 +CKVECFLAGS := -mfpmath=sse -mavx -mfma4 -march=bdver1 -mno-tbm -mno-xop -mno-lwp else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mfpmath=sse -mavx -mfma4 -march=bdver1 +CKVECFLAGS := -mfpmath=sse -mavx -mfma4 -march=bdver1 -mno-tbm -mno-xop -mno-lwp else $(error gcc or clang are required for this configuration.) endif @@ -74,7 +74,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/excavator/make_defs.mk b/config/excavator/make_defs.mk index 45fff9690..deb85c79b 100644 --- a/config/excavator/make_defs.mk +++ b/config/excavator/make_defs.mk @@ -57,16 +57,16 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 -fomit-frame-pointer +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -mno-fma4 -march=bdver4 +CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -mno-fma4 -march=bdver4 +CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else $(error gcc or clang are required for this configuration.) endif @@ -74,7 +74,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. 
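A note on the `-funsafe-math-optimizations` flag that several of the CRVECFLAGS updates above add under gcc (the rationale is inferred, not stated in the patch): reference kernels rely on compiler auto-vectorization, and gcc will not vectorize floating-point reductions unless it is allowed to reassociate them, which this flag (via -fassociative-math) permits. A minimal sketch of the kind of loop affected:

```c
/* Illustrative reduction loop, not BLIS source. Under plain -O3, gcc must
   preserve the sequential order of the floating-point adds and emits a
   scalar loop; with -funsafe-math-optimizations it may reassociate the sum
   and accumulate in SIMD registers instead. */
double example_dot( int n, const double* x, const double* y )
{
    double rho = 0.0;
    for ( int i = 0; i < n; ++i )
        rho += x[ i ] * y[ i ];
    return rho;
}
```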
diff --git a/config/generic/make_defs.mk b/config/generic/make_defs.mk index d491d072e..3388291da 100644 --- a/config/generic/make_defs.mk +++ b/config/generic/make_defs.mk @@ -78,7 +78,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) CRVECFLAGS := $(CKVECFLAGS) +else +CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/haswell/bli_cntx_init_haswell.c b/config/haswell/bli_cntx_init_haswell.c index 0682e6933..7f222415a 100644 --- a/config/haswell/bli_cntx_init_haswell.c +++ b/config/haswell/bli_cntx_init_haswell.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,9 +35,12 @@ #include "blis.h" +//GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + void bli_cntx_init_haswell( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; // Set default kernel blocksizes and functions. bli_cntx_init_haswell_ref( cntx ); @@ -69,6 +73,7 @@ void bli_cntx_init_haswell( cntx_t* cntx ) cntx ); + // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( 4, @@ -118,12 +123,18 @@ void bli_cntx_init_haswell( cntx_t* cntx ) #if 1 bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 1008, 1008, 1008, 1008 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 168, 72, 75, 192 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); #else bli_blksz_init_easy( &blkszs[ BLIS_MR ], 16, 8, 8, 4 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 6, 6, 3, 3 ); -#endif - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 1024, 1024, 1024, 1024 ); + //bli_blksz_init_easy( &blkszs[ BLIS_MC ], 112, 64, 56, 32 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 112, 72, 56, 44 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); +#endif bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, 8, 8 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, 8, 8 ); @@ -144,5 +155,62 @@ void bli_cntx_init_haswell( cntx_t* cntx ) BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 201, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 100, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 120, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
+ bli_cntx_set_l3_sup_kers + ( + 8, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], -1, 6, -1, -1, + -1, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/haswell/bli_family_haswell.h b/config/haswell/bli_family_haswell.h index dc75f01b2..58154692a 100644 --- a/config/haswell/bli_family_haswell.h +++ b/config/haswell/bli_family_haswell.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,7 +37,6 @@ //#define BLIS_FAMILY_H - #if 0 // -- LEVEL-3 MICRO-KERNEL CONSTANTS AND DEFINITIONS --------------------------- diff --git a/config/haswell/make_defs.mk b/config/haswell/make_defs.mk index 5d2f0a73b..102b34237 100644 --- a/config/haswell/make_defs.mk +++ b/config/haswell/make_defs.mk @@ -63,13 +63,13 @@ endif # Flags specific to optimized kernels. CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=core-avx2 +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=haswell else ifeq ($(CC_VENDOR),icc) CKVECFLAGS := -xCORE-AVX2 else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=core-avx2 +CKVECFLAGS := -mavx2 -mfma -mfpmath=sse -march=haswell else $(error gcc, icc, or clang is required for this configuration.) endif @@ -78,7 +78,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) #-funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/intel64/make_defs.mk b/config/intel64/make_defs.mk index 442b81e3a..af462fdc3 100644 --- a/config/intel64/make_defs.mk +++ b/config/intel64/make_defs.mk @@ -78,7 +78,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) CRVECFLAGS := $(CKVECFLAGS) +else +CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. 
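For context on the BLIS_MT/NT/KT sup thresholds registered in bli_cntx_init_haswell.c above, here is a hedged sketch of the dispatch idea. The helper name is made up, and the exact way the three tests are combined lives in the BLIS sup framework, not in this patch:

```c
#include <stdbool.h>

/* Assumption: the sup (small/skinny/unpacked) path is attempted when a
   problem dimension falls below its per-datatype threshold; otherwise the
   conventional packed implementation runs. The values shown are the haswell
   double-precision thresholds from the diff above. */
static bool example_use_sup_path_d( long m, long n, long k )
{
    const long mt = 201, nt = 100, kt = 120;
    return ( m < mt ) || ( n < nt ) || ( k < kt );
}
```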
diff --git a/config/knc/make_defs.mk b/config/knc/make_defs.mk index 367b64b27..be3c9019d 100644 --- a/config/knc/make_defs.mk +++ b/config/knc/make_defs.mk @@ -70,7 +70,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Override the default value for LDFLAGS. LDFLAGS := -mmic diff --git a/config/knl/make_defs.mk b/config/knl/make_defs.mk index f4165f788..b08cf1e4d 100644 --- a/config/knl/make_defs.mk +++ b/config/knl/make_defs.mk @@ -99,7 +99,7 @@ endif # Note: We use AVX2 for reference kernels instead of AVX-512. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := -march=knl -mno-avx512f -mno-avx512pf -mno-avx512er -mno-avx512cd +CRVECFLAGS := -march=knl -mno-avx512f -mno-avx512pf -mno-avx512er -mno-avx512cd -funsafe-math-optimizations else ifeq ($(CC_VENDOR),icc) CRVECFLAGS := -xMIC-AVX512 diff --git a/config/penryn/make_defs.mk b/config/penryn/make_defs.mk index 294dd616a..41d2d939f 100644 --- a/config/penryn/make_defs.mk +++ b/config/penryn/make_defs.mk @@ -57,7 +57,7 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 -fomit-frame-pointer +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. @@ -78,7 +78,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/piledriver/make_defs.mk b/config/piledriver/make_defs.mk index 155b0c002..bb23fbece 100644 --- a/config/piledriver/make_defs.mk +++ b/config/piledriver/make_defs.mk @@ -57,16 +57,16 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 -fomit-frame-pointer +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -mno-fma4 -march=bdver2 +CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver2 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -mno-fma4 -march=bdver2 +CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver2 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else $(error gcc or clang are required for this configuration.) endif @@ -74,7 +74,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/sandybridge/make_defs.mk b/config/sandybridge/make_defs.mk index f0d694f8c..23acc1708 100644 --- a/config/sandybridge/make_defs.mk +++ b/config/sandybridge/make_defs.mk @@ -63,13 +63,13 @@ endif # Flags specific to optimized kernels. CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mavx -mfpmath=sse -march=corei7-avx +CKVECFLAGS := -mavx -mfpmath=sse -march=sandybridge else ifeq ($(CC_VENDOR),icc) CKVECFLAGS := -xAVX else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mavx -mfpmath=sse -march=corei7-avx +CKVECFLAGS := -mavx -mfpmath=sse -march=sandybridge else $(error gcc, icc, or clang is required for this configuration.) endif @@ -78,7 +78,11 @@ endif # Flags specific to reference kernels. 
CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/skx/make_defs.mk b/config/skx/make_defs.mk index e9319e476..27bea5ef5 100644 --- a/config/skx/make_defs.mk +++ b/config/skx/make_defs.mk @@ -89,7 +89,7 @@ endif # to overcome the AVX-512 frequency drop". (Issue #187) CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd +CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations else ifeq ($(CC_VENDOR),icc) CRVECFLAGS := -xCORE-AVX2 diff --git a/config/steamroller/make_defs.mk b/config/steamroller/make_defs.mk index 6c093d244..a5b670704 100644 --- a/config/steamroller/make_defs.mk +++ b/config/steamroller/make_defs.mk @@ -57,16 +57,16 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 -fomit-frame-pointer +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -mno-fma4 -march=bdver3 +CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver3 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else ifeq ($(CC_VENDOR),clang) -CKVECFLAGS := -mfpmath=sse -mavx -mfma -mno-fma4 -march=bdver3 +CKVECFLAGS := -mfpmath=sse -mavx -mfma -march=bdver3 -mno-fma4 -mno-tbm -mno-xop -mno-lwp else $(error gcc or clang are required for this configuration.) endif @@ -74,7 +74,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/template/make_defs.mk b/config/template/make_defs.mk index ff89757c7..35edf71a1 100644 --- a/config/template/make_defs.mk +++ b/config/template/make_defs.mk @@ -57,7 +57,7 @@ endif ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else -COPTFLAGS := -O2 +COPTFLAGS := -O3 endif # Flags specific to optimized kernels. diff --git a/config/x86_64/make_defs.mk b/config/x86_64/make_defs.mk index 375ea7dec..4d038ff04 100644 --- a/config/x86_64/make_defs.mk +++ b/config/x86_64/make_defs.mk @@ -78,7 +78,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) CRVECFLAGS := $(CKVECFLAGS) +else +CRVECFLAGS := $(CKVECFLAGS) +endif # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 258d4e92d..09ca2dee0 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -35,9 +35,12 @@ #include "blis.h" +//GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + void bli_cntx_init_zen( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; // Set default kernel blocksizes and functions. bli_cntx_init_zen_ref( cntx ); @@ -114,23 +117,27 @@ void bli_cntx_init_zen( cntx_t* cntx ) bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); /* - Multi Instance performance improvement of DGEMM when binded to a CCX - In Multi instance each thread runs a sequential DGEMM. 
- - a) If BLIS is run in a multi instance mode with - CPU freq 2.6/2.2 Ghz - DDR4 clock frequency 2400Mhz + Multi Instance performance degradation on different cores + a) CPU freq 2.6 Ghz + DDR4 2400 + Multi instance mode mc = 240, kc = 512, and nc = 2040 - has better performance on EPYC server, over the default block sizes. + + b) CPU freq 2.4Ghz + DDR4 2400 + Multi Instance mode + either + mc = 240, kc = 512 and nc = 2040 + (or) + mc = 390, kc = 512 and nc = 4080 - b) If BLIS is run in Single Instance mode + c) Higher frequency (3.1Ghz), single instance mode: choose default values mc = 510, kc = 1024 and nc = 4080 */ // Zen optimized level 3 cache block sizes #ifdef BLIS_ENABLE_ZEN_BLOCK_SIZES - #if BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 510, 144, 72 ); @@ -138,7 +145,6 @@ bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); #else - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); @@ -150,9 +156,7 @@ bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); #endif - - - + //bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); @@ -172,5 +176,62 @@ BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 100, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 120, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 8, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], -1, 6, -1, -1, + -1, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems.
+ bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index 3f41a53bb..526e3a8b0 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -39,14 +39,13 @@ // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops // to be not parallelized. -#define BLIS_THREAD_MAX_IR 1 -#define BLIS_THREAD_MAX_JR 1 +#define BLIS_DEFAULT_MR_THREAD_MAX 1 +#define BLIS_DEFAULT_NR_THREAD_MAX 1 #define BLIS_ENABLE_ZEN_BLOCK_SIZES #define BLIS_ENABLE_SMALL_MATRIX #define BLIS_ENABLE_SMALL_MATRIX_TRSM - // This will select the threshold below which small matrix code will be called. #define BLIS_SMALL_MATRIX_THRES 700 #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 @@ -64,6 +63,15 @@ #define D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES 90 #define D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO 22 +// Allow the sup implementation to combine some small edge case iterations in +// the 2nd loop of the panel-block algorithm (MR) and/or the 2nd loop of the +// block-panel algorithm (NR) with the last full iteration that precedes it. +// NOTE: These cpp macros need to be explicitly set to an integer since they +// are used at compile-time to create unconditional branches or dead code +// regions. +#define BLIS_ENABLE_SUP_MR_EXT 1 +#define BLIS_ENABLE_SUP_NR_EXT 0 + //#endif diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index ea5f0802c..f7f84aee2 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -46,10 +46,27 @@ AMD_CONFIG_FILE := amd_config.mk AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen -include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) CKVECFLAGS += -march=znver1 endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +else +CRVECFLAGS := $(CKVECFLAGS) +endif + # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 25eae4866..452e2426f 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -38,7 +38,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - + blksz_t thresh[ BLIS_NUM_THRESH ]; // Set default kernel blocksizes and functions. bli_cntx_init_zen2_ref( cntx ); @@ -135,5 +135,61 @@ BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); +// ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 100, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 120, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels.
+ bli_cntx_set_l3_sup_kers + ( + 8, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], -1, 6, -1, -1, + -1, 9, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 72, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index ac1a0545b..832681ca6 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -33,36 +33,56 @@ # # -# FLAGS that are specific to 'zen2' architecture are added here. -# FLAGS that are common for all the AMD architectures are present in config/zen/amd_config.mk -# # Declare the name of the current configuration and add it to the # running list of configurations included by common.mk. THIS_CONFIG := zen2 #CONFIGS_INCL += $(THIS_CONFIG) -# Include file containing common flags for all AMD architectures -AMD_CONFIG_FILE := amd_config.mk -AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen --include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) # # --- Determine the C compiler and related flags --- # + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 -fomit-frame-pointer +endif + # Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) ifeq ($(CC_VENDOR),gcc) # gcc 9.0 (clang ?) or later: -GCC_VERSION := $(strip $(shell gcc -dumpversion)) -ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 +#CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=znver2 # gcc 6.0 (clang 4.0) or later: -else -CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -endif +CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=znver1 -mno-avx256-split-unaligned-store # gcc 4.9 (clang 3.5) or later: # possibly add zen-specific instructions: -mclzero -madx -mrdseed -mmwaitx -msha -mxsavec -mxsaves -mclflushopt -mpopcnt #CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=bdver4 -mno-fma4 -mno-tbm -mno-xop -mno-lwp +else +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -march=znver1 -mno-fma4 -mno-tbm -mno-xop -mno-lwp +else +$(error gcc or clang are required for this configuration.) 
endif +endif + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +CRVECFLAGS := $(CKVECFLAGS) + # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/configure b/configure index e72831d8f..1db3dfa0c 100755 --- a/configure +++ b/configure @@ -51,8 +51,6 @@ print_usage() #echo " " #echo " BLIS ${version}" echo " " - echo " Field G. Van Zee" - echo " " echo " Configure BLIS's build system for compilation using a specified" echo " configuration directory." echo " " @@ -72,30 +70,37 @@ print_usage() echo " " echo " -p PREFIX, --prefix=PREFIX" echo " " - echo " The path to which make will install all build products." - echo " If given, this option implies the following options:" - echo " --libdir=PREFIX/lib" - echo " --incdir=PREFIX/include" + echo " The common installation prefix for all files. If given," + echo " this option effectively implies:" + echo " --libdir=EXECPREFIX/lib" + echo " --includedir=PREFIX/include" echo " --sharedir=PREFIX/share" - echo " If not given, PREFIX defaults to \$(HOME)/blis. If PREFIX" + echo " where EXECPREFIX defaults to PREFIX. If this option is" + echo " not given, PREFIX defaults to '${prefix_def}'. If PREFIX" echo " refers to a directory that does not exist, it will be" echo " created." echo " " + echo " --exec-prefix=EXECPREFIX" + echo " " + echo " The installation prefix for libraries. Specifically, if" + echo " given, this option effectively implies:" + echo " --libdir=EXECPREFIX/lib" + echo " If not given, EXECPREFIX defaults to PREFIX, which may be" + echo " modified by the --prefix option. If EXECPREFIX refers to" + echo " a directory that does not exist, it will be created." + echo " " echo " --libdir=LIBDIR" echo " " - echo " The path to which make will install libraries. If given," - echo " LIBDIR will override the corresponding directory implied" - echo " by --prefix; if not not given, LIBDIR defaults to" - echo " PREFIX/lib. If LIBDIR refers to a directory that does" - echo " not exist, it will be created." + echo " The path to which make will install libraries. If not" + echo " given, LIBDIR defaults to PREFIX/lib. If LIBDIR refers to" + echo " a directory that does not exist, it will be created." echo " " echo " --includedir=INCDIR" echo " " echo " The path to which make will install development header" - echo " files. If given, INCDIR will override the corresponding" - echo " directory implied by --prefix; if not given, INCDIR" - echo " defaults to PREFIX/include. If INCDIR refers to a" - echo " directory that does not exist, it will be created." + echo " files. If not given, INCDIR defaults to PREFIX/include." + echo " If INCDIR refers to a directory that does not exist, it" + echo " will be created." echo " " echo " --sharedir=SHAREDIR" echo " " @@ -104,18 +109,9 @@ print_usage() echo " and LDFLAGS). These files allow certain BLIS makefiles," echo " such as those in the examples or testsuite directories, to" echo " operate on an installed copy of BLIS rather than a local" - echo " (and possibly uninstalled) copy. If given, SHAREDIR will" - echo " override the corresponding directory implied by --prefix;" - echo " if not given, SHAREDIR defaults to PREFIX/share. If" - echo " SHAREDIR refers to a directory that does not exist, it" - echo " will be created." - echo " " - echo " -d DEBUG, --enable-debug[=DEBUG]" - echo " " - echo " Enable debugging symbols in the library. 
If argument" - echo " DEBUG is given as 'opt', then optimization flags are" - echo " kept in the framework, otherwise optimization is" - echo " turned off." + echo " (and possibly uninstalled) copy. If not given, SHAREDIR" + echo " defaults to PREFIX/share. If SHAREDIR refers to a" + echo " directory that does not exist, it will be created." echo " " echo " --enable-verbose-make, --disable-verbose-make" echo " " @@ -129,6 +125,13 @@ print_usage() echo " even if the command plus command line arguments exceeds" echo " the operating system limit (ARG_MAX)." echo " " + echo " -d DEBUG, --enable-debug[=DEBUG]" + echo " " + echo " Enable debugging symbols in the library. If argument" + echo " DEBUG is given as 'opt', then optimization flags are" + echo " kept in the framework, otherwise optimization is" + echo " turned off." + echo " " echo " --disable-static, --enable-static" echo " " echo " Disable (enabled by default) building BLIS as a static" @@ -141,6 +144,23 @@ print_usage() echo " library. If the shared library build is disabled, the" echo " static library build must remain enabled." echo " " + echo " -e SYMBOLS, --export-shared[=SYMBOLS]" + echo " " + echo " Specify the subset of library symbols that are exported" + echo " within a shared library. Valid values for SYMBOLS are:" + echo " 'public' (the default) and 'all'. By default, only" + echo " functions and variables that belong to public APIs are" + echo " exported in shared libraries. However, the user may" + echo " instead export all symbols in BLIS, even those that were" + echo " intended for internal use only. Note that the public APIs" + echo " encompass all functions that almost any user would ever" + echo " want to call, including the BLAS/CBLAS compatibility APIs" + echo " as well as the basic and expert interfaces to the typed" + echo " and object APIs that are unique to BLIS. Also note that" + echo " changing this option to 'all' will have no effect in some" + echo " environments, such as when compiling with clang on" + echo " Windows." + echo " " echo " -t MODEL, --enable-threading[=MODEL], --disable-threading" echo " " echo " Enable threading in the library, using threading model" @@ -222,6 +242,16 @@ print_usage() echo " only be enabled when mixed domain/precision support is" echo " enabled." echo " " + echo " --disable-sup-handling, --enable-sup-handling" + echo " " + echo " Disable (enabled by default) handling of small/skinny" + echo " matrix problems via separate code branches. When disabled," + echo " these small/skinny level-3 operations will be performed by" + echo " the conventional implementation, which is optimized for" + echo " medium and large problems. Note that what qualifies as" + echo " \"small\" depends on thresholds that may vary by sub-" + echo " configuration." + echo " " echo " -s NAME --enable-sandbox=NAME" echo " " echo " Enable a separate sandbox implementation of gemm. This" @@ -278,6 +308,7 @@ print_usage() echo " Environment Variables:" echo " " echo " CC Specifies the C compiler to use." + echo " CXX Specifies the C++ compiler to use (sandbox only)." echo " RANLIB Specifies the ranlib executable to use." echo " AR Specifies the archiver to use." echo " CFLAGS Specifies additional compiler flags to use (prepended)." @@ -1016,7 +1047,7 @@ auto_detect() # Set the linker flags. We need pthreads because it is needed for # parts of bli_arch.c unrelated to bli_arch_string(), which is called # by the main() function in ${main_c}. 
- if [ $is_win = no ]; then + if [[ $is_win == no || "$cc_vendor" != "clang" ]]; then ldflags="${LIBPTHREAD--lpthread}" fi @@ -1294,8 +1325,7 @@ get_compiler_version() # to OS X's egrep only returning the first match. cc_vendor=$(echo "${vendor_string}" | egrep -o 'icc|gcc|clang|emcc|pnacl|IBM' | { read first rest ; echo $first ; }) if [ "${cc_vendor}" = "icc" -o \ - "${cc_vendor}" = "gcc" -o \ - "${cc_vendor}" = "clang" ]; then + "${cc_vendor}" = "gcc" ]; then cc_version=$(${cc} -dumpversion) else cc_version=$(echo "${vendor_string}" | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*' | { read first rest ; echo ${first} ; }) @@ -1343,7 +1373,7 @@ check_compiler() # Specific: # # skx: icc 15.0.1+, gcc 6.0+, clang 3.9+ - # knl: icc 14.0.1+, gcc 5.0+, clang 3.5+ + # knl: icc 14.0.1+, gcc 5.0+, clang 3.9+ # haswell: any # sandybridge: any # penryn: any @@ -1418,27 +1448,42 @@ check_compiler() # clang if [ "x${cc_vendor}" = "xclang" ]; then - - if [ ${cc_major} -lt 3 ]; then - echoerr_unsupportedcc - fi - if [ ${cc_major} -eq 3 ]; then - if [ ${cc_minor} -lt 3 ]; then + if [ "$(echo ${vendor_string} | grep -o Apple)" = "Apple" ]; then + if [ ${cc_major} -lt 5 ]; then echoerr_unsupportedcc fi - if [ ${cc_minor} -lt 5 ]; then + # See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions + if [ ${cc_major} -eq 5 ]; then + # Apple clang 5.0 is clang 3.4svn blacklistcc_add "excavator" blacklistcc_add "zen" - blacklistcc_add "knl" fi - if [ ${cc_minor} -lt 9 ]; then + if [ ${cc_major} -lt 7 ]; then + blacklistcc_add "knl" blacklistcc_add "skx" fi - fi - if [ ${cc_major} -lt 4 ]; then - # See comment above regarding zen support. - #blacklistcc_add "zen" - : # explicit no-op since bash can't handle empty loop bodies. + else + if [ ${cc_major} -lt 3 ]; then + echoerr_unsupportedcc + fi + if [ ${cc_major} -eq 3 ]; then + if [ ${cc_minor} -lt 3 ]; then + echoerr_unsupportedcc + fi + if [ ${cc_minor} -lt 5 ]; then + blacklistcc_add "excavator" + blacklistcc_add "zen" + fi + if [ ${cc_minor} -lt 9 ]; then + blacklistcc_add "knl" + blacklistcc_add "skx" + fi + fi + if [ ${cc_major} -lt 4 ]; then + # See comment above regarding zen support. + #blacklistcc_add "zen" + : # explicit no-op since bash can't handle empty loop bodies. + fi fi fi } @@ -1496,8 +1541,8 @@ check_assembler() # # The assembler on OS X won't recognize AVX-512 without help. - if [ "$(uname -s)" == "Darwin" ]; then - cflags="-Wa,-march=knl" + if [ "${cc_vendor}" == "clang" ]; then + cflags="-march=knl" fi asm_fp=$(find ${asm_dir} -name "avx512f.s") @@ -1513,8 +1558,8 @@ check_assembler() # # The assembler on OS X won't recognize AVX-512 without help. - if [ "$(uname -s)" == "Darwin" ]; then - cflags="-Wa,-march=skylake-avx512" + if [ "${cc_vendor}" == "clang" ]; then + cflags="-march=skylake-avx512" fi asm_fp=$(find ${asm_dir} -name "avx512dq.s") @@ -1731,21 +1776,33 @@ main() # -- configure options -- - # The user-given install prefix and a flag indicating it was given. - #install_prefix_def="${HOME}/blis" - install_prefix_user=${HOME}/blis # default to this directory. + # Define the default prefix so that the print_usage() function can + # output it in the --help text. + prefix_def='/usr/local' + + # The installation prefix, assigned its default value, and a flag to + # track whether or not it was given by the user. + prefix=${prefix_def} prefix_flag='' - # The user-given install libdir and a flag indicating it was given. 
- install_libdir_user='' + # The installation exec_prefix, assigned its default value, and a flag to + # track whether or not it was given by the user. + exec_prefix='${prefix}' + exec_prefix_flag='' + + # The installation libdir, assigned its default value, and a flag to + # track whether or not it was given by the user. + libdir='${exec_prefix}/lib' libdir_flag='' - # The user-given install includedir and a flag indicating it was given. - install_incdir_user='' - incdir_flag='' + # The installation includedir, assigned its default value, and a flag to + # track whether or not it was given by the user. + includedir='${prefix}/include' + includedir_flag='' - # The user-given install sharedir and a flag indicating it was given. - install_sharedir_user='' + # The installation sharedir, assigned its default value, and a flag to + # track whether or not it was given by the user. + sharedir='${prefix}/share' sharedir_flag='' # The preset value of CFLAGS and LDFLAGS (ie: compiler and linker flags @@ -1758,7 +1815,7 @@ debug_flag='' # The threading flag. - threading_model='no' + threading_model='off' # The method of assigning micropanels to threads in the JR and IR loops. thread_part_jrir='slab' @@ -1772,6 +1829,7 @@ enable_arg_max_hack='no' enable_static='yes' enable_shared='yes' + export_shared='public' enable_pba_pools='yes' enable_sba_pools='yes' enable_mem_tracing='no' @@ -1781,6 +1839,7 @@ enable_cblas='no' enable_mixed_dt='yes' enable_mixed_dt_extra_mem='yes' + enable_sup_handling='yes' enable_memkind='' # The default memkind value is determined later on. force_version='no' @@ -1821,7 +1880,7 @@ # Process our command line options. unset OPTIND - while getopts ":hp:d:s:t:r:qci:b:-:" opt; do + while getopts ":hp:d:e:s:t:r:qci:b:-:" opt; do case $opt in -) case "$OPTARG" in @@ -1833,19 +1892,23 @@ ;; prefix=*) prefix_flag=1 - install_prefix_user=${OPTARG#*=} + prefix=${OPTARG#*=} + ;; + exec-prefix=*) + exec_prefix_flag=1 + exec_prefix=${OPTARG#*=} ;; libdir=*) libdir_flag=1 - install_libdir_user=${OPTARG#*=} + libdir=${OPTARG#*=} ;; includedir=*) - incdir_flag=1 - install_incdir_user=${OPTARG#*=} + includedir_flag=1 + includedir=${OPTARG#*=} ;; sharedir=*) sharedir_flag=1 - install_sharedir_user=${OPTARG#*=} + sharedir=${OPTARG#*=} ;; enable-debug) debug_flag=1 @@ -1882,15 +1945,18 @@ disable-shared) enable_shared='no' ;; + export-shared=*) + export_shared=${OPTARG#*=} + ;; enable-threading=*) threading_model=${OPTARG#*=} ;; + disable-threading) + threading_model='off' + ;; thread-part-jrir=*) thread_part_jrir=${OPTARG#*=} ;; - disable-threading) - threading_model='no' - ;; enable-pba-pools) enable_pba_pools='yes' ;; @@ -1946,6 +2012,12 @@ disable-mixed-dt-extra-mem) enable_mixed_dt_extra_mem='no' ;; + enable-sup-handling) + enable_sup_handling='yes' + ;; + disable-sup-handling) + enable_sup_handling='no' + ;; with-memkind) enable_memkind='yes' ;; @@ -1967,12 +2039,15 @@ ;; p) prefix_flag=1 - install_prefix_user=$OPTARG + prefix=$OPTARG ;; d) debug_flag=1 debug_type=$OPTARG ;; + e) + export_shared=$OPTARG + ;; s) sandbox_flag=1 sandbox=$OPTARG @@ -2459,54 +2534,49 @@ # -- Prepare variables for substitution into template files ----------------- - # Parse the status of the install prefix and echo feedback. + # Parse the status of the prefix option and echo feedback. if [ -n "${prefix_flag}" ]; then - echo "${script_name}: detected --prefix='${install_prefix_user}'." + echo "${script_name}: detected --prefix='${prefix}'."
else - echo "${script_name}: no install prefix option given; defaulting to '${install_prefix_user}'." + echo "${script_name}: no install prefix option given; defaulting to '${prefix}'." fi - # Set initial (candidate) values for the libdir and includedir using the - # install prefix that was determined above. - install_libdir=${install_prefix_user}/lib - install_incdir=${install_prefix_user}/include - install_sharedir=${install_prefix_user}/share + # Parse the status of the exec_prefix option and echo feedback. + if [ -n "${exec_prefix_flag}" ]; then + echo "${script_name}: detected --exec-prefix='${exec_prefix}'." + else + echo "${script_name}: no install exec_prefix option given; defaulting to PREFIX." + fi - # Set the install libdir, if it was specified. Note that this will override - # the default libdir implied by the install prefix, even if both options - # were given. + # Parse the status of the libdir option and echo feedback. if [ -n "${libdir_flag}" ]; then - echo "${script_name}: detected --libdir='${install_libdir_user}'." - install_libdir=${install_libdir_user} + echo "${script_name}: detected --libdir='${libdir}'." else - echo "${script_name}: no install libdir option given; defaulting to PREFIX/lib." + echo "${script_name}: no install libdir option given; defaulting to EXECPREFIX/lib." fi - # Set the install includedir, if it was specified. Note that this will - # override the default includedir implied by the install prefix, even if - # both options were given. - if [ -n "${incdir_flag}" ]; then - echo "${script_name}: detected --includedir='${install_incdir_user}'." - install_incdir=${install_incdir_user} + # Parse the status of the includedir option and echo feedback. + if [ -n "${includedir_flag}" ]; then + echo "${script_name}: detected --includedir='${includedir}'." else echo "${script_name}: no install includedir option given; defaulting to PREFIX/include." fi - # Set the install sharedir, if it was specified. Note that this will - # override the default sharedir implied by the install prefix, even if - # both options were given. + # Parse the status of the sharedir option and echo feedback. if [ -n "${sharedir_flag}" ]; then - echo "${script_name}: detected --sharedir='${install_sharedir_user}'." - install_sharedir=${install_sharedir_user} + echo "${script_name}: detected --sharedir='${sharedir}'." else echo "${script_name}: no install sharedir option given; defaulting to PREFIX/share." fi # Echo the installation directories that we settled on. echo "${script_name}: final installation directories:" - echo "${script_name}: libdir: ${install_libdir}" - echo "${script_name}: includedir: ${install_incdir}" - echo "${script_name}: sharedir: ${install_sharedir}" + echo "${script_name}: prefix: "${prefix} + echo "${script_name}: exec_prefix: "${exec_prefix} + echo "${script_name}: libdir: "${libdir} + echo "${script_name}: includedir: "${includedir} + echo "${script_name}: sharedir: "${sharedir} + echo "${script_name}: NOTE: the variables above can be overridden when running make." # Check if CFLAGS is non-empty. if [ -n "${CFLAGS}" ]; then @@ -2573,6 +2643,23 @@ main() exit 1 fi + # Check if the "export shared" flag was specified. + if [ "x${export_shared}" = "xall" ]; then + if [ "x${enable_shared}" = "xyes" ]; then + echo "${script_name}: exporting all symbols within shared library." + else + echo "${script_name}: ignoring request to export all symbols within shared library." 
+ fi + elif [ "x${export_shared}" = "xpublic" ]; then + if [ "x${enable_shared}" = "xyes" ]; then + echo "${script_name}: exporting only public symbols within shared library." + fi + else + echo "${script_name}: *** Invalid argument '${export_shared}' to --export-shared option given." + echo "${script_name}: *** Please use 'public' or 'all'." + exit 1 + fi + # Check the threading model flag and standardize its value, if needed. # NOTE: 'omp' is deprecated but still supported; 'openmp' is preferred. enable_openmp='no' @@ -2594,9 +2681,11 @@ main() enable_pthreads='yes' enable_pthreads_01=1 threading_model="pthreads" # Standardize the value. - elif [ "x${threading_model}" = "xno" ] || + elif [ "x${threading_model}" = "xoff" ] || + [ "x${threading_model}" = "xno" ] || [ "x${threading_model}" = "xnone" ]; then echo "${script_name}: threading is disabled." + threading_model="off" else echo "${script_name}: *** Unsupported threading model: ${threading_model}." exit 1 @@ -2707,6 +2796,13 @@ main() enable_mixed_dt_extra_mem_01=0 enable_mixed_dt_01=0 fi + if [ "x${enable_sup_handling}" = "xyes" ]; then + echo "${script_name}: small matrix handling is enabled." + enable_sup_handling_01=1 + else + echo "${script_name}: small matrix handling is disabled." + enable_sup_handling_01=0 + fi # Report integer sizes. if [ "x${int_type_size}" = "x32" ]; then @@ -2758,13 +2854,15 @@ main() # Variables that may contain forward slashes, such as paths, need extra # escaping when used in sed commands. We insert those extra escape # characters here so that the sed commands below do the right thing. - os_name_esc=$(echo "${os_name}" | sed 's/\//\\\//g') - install_libdir_esc=$(echo "${install_libdir}" | sed 's/\//\\\//g') - install_incdir_esc=$(echo "${install_incdir}" | sed 's/\//\\\//g') - install_sharedir_esc=$(echo "${install_sharedir}" | sed 's/\//\\\//g') - dist_path_esc=$(echo "${dist_path}" | sed 's/\//\\\//g') - cc_esc=$(echo "${found_cc}" | sed 's/\//\\\//g') - cxx_esc=$(echo "${found_cxx}" | sed 's/\//\\\//g') + os_name_esc=$(echo "${os_name}" | sed 's/\//\\\//g') + prefix_esc=$(echo "${prefix}" | sed 's/\//\\\//g') + exec_prefix_esc=$(echo "${exec_prefix}" | sed 's/\//\\\//g') + libdir_esc=$(echo "${libdir}" | sed 's/\//\\\//g') + includedir_esc=$(echo "${includedir}" | sed 's/\//\\\//g') + sharedir_esc=$(echo "${sharedir}" | sed 's/\//\\\//g') + dist_path_esc=$(echo "${dist_path}" | sed 's/\//\\\//g') + cc_esc=$(echo "${found_cc}" | sed 's/\//\\\//g') + cxx_esc=$(echo "${found_cxx}" | sed 's/\//\\\//g') #sandbox_relpath_esc=$(echo "${sandbox_relpath}" | sed 's/\//\\\//g') # For RANLIB, if the variable is not set, we use a default value of @@ -2779,7 +2877,7 @@ main() # For Windows builds, clear the libpthread_esc variable so that # no pthreads library is substituted into config.mk. (Windows builds # employ an implementation of pthreads that is internal to BLIS.) - if [ $is_win = yes ]; then + if [[ $is_win == yes && "$cc_vendor" == "clang" ]]; then libpthread_esc= fi @@ -2821,13 +2919,13 @@ main() # -- Determine whether we are performing an out-of-tree build -------------- - if [ ${dist_path} != "./" ]; then + if [ "${dist_path}" != "./" ]; then # At this point, we know the user did not run "./configure". But we # have not yet ruled out "/configure" or some # equivalent # that uses relative paths. To further rule out these possibilities, # we create a dummy file in the current build directory. 
- touch ./${dummy_file} + touch "./${dummy_file}" # If the dummy file we just created in the current directory does not # appear in the source distribution path, then we are in a different @@ -2871,14 +2969,17 @@ main() | sed -e "s/@ldflags_preset@/${ldflags_preset_esc}/g" \ | sed -e "s/@debug_type@/${debug_type}/g" \ | sed -e "s/@threading_model@/${threading_model}/g" \ - | sed -e "s/@install_libdir@/${install_libdir_esc}/g" \ - | sed -e "s/@install_incdir@/${install_incdir_esc}/g" \ - | sed -e "s/@install_sharedir@/${install_sharedir_esc}/g" \ + | sed -e "s/@prefix@/${prefix_esc}/g" \ + | sed -e "s/@exec_prefix@/${exec_prefix_esc}/g" \ + | sed -e "s/@libdir@/${libdir_esc}/g" \ + | sed -e "s/@includedir@/${includedir_esc}/g" \ + | sed -e "s/@sharedir@/${sharedir_esc}/g" \ | sed -e "s/@enable_verbose@/${enable_verbose}/g" \ | sed -e "s/@configured_oot@/${configured_oot}/g" \ | sed -e "s/@enable_arg_max_hack@/${enable_arg_max_hack}/g" \ | sed -e "s/@enable_static@/${enable_static}/g" \ | sed -e "s/@enable_shared@/${enable_shared}/g" \ + | sed -e "s/@export_shared@/${export_shared}/g" \ | sed -e "s/@enable_blas@/${enable_blas}/g" \ | sed -e "s/@enable_cblas@/${enable_cblas}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind}/g" \ @@ -2910,6 +3011,7 @@ main() | sed -e "s/@enable_cblas@/${enable_cblas_01}/g" \ | sed -e "s/@enable_mixed_dt@/${enable_mixed_dt_01}/g" \ | sed -e "s/@enable_mixed_dt_extra_mem@/${enable_mixed_dt_extra_mem_01}/g" \ + | sed -e "s/@enable_sup_handling@/${enable_sup_handling_01}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind_01}/g" \ | sed -e "s/@enable_pragma_omp_simd@/${enable_pragma_omp_simd_01}/g" \ | sed -e "s/@enable_sandbox@/${enable_sandbox_01}/g" \ diff --git a/docs/BuildSystem.md b/docs/BuildSystem.md index 84004f886..28137d8b3 100644 --- a/docs/BuildSystem.md +++ b/docs/BuildSystem.md @@ -9,6 +9,9 @@ * **[Step 3b: Testing (optional)](BuildSystem.md#step-3b-testing-optional)** * **[Step 4: Installation](BuildSystem.md#step-4-installation)** * **[Cleaning out build products](BuildSystem.md#cleaning-out-build-products)** +* **[Compiling with BLIS](BuildSystem.md#compiling-with-blis)** + * [Disabling BLAS prototypes](BuildSystem.md#disabling-blas-prototypes) + * [CBLAS](BuildSystem.md#cblas) * **[Linking against BLIS](BuildSystem.md#linking-against-blis)** * **[Uninstalling](BuildSystem.md#uninstalling)** * **[make targets](BuildSystem.md#make-targets)** @@ -83,11 +86,11 @@ Alternatively, `configure` can automatically select a configuration based on you ``` $ ./configure auto ``` -However, as of this writing, only a limited number of architectures are detected. If the `configure` script is not able to detect your architecture, the `generic` configuration will be used. +However, as of this writing, only a limited number of architectures are detected. If the `configure` script is not able to detect your architecture, the `generic` configuration will be used. Upon running configure, you will get output similar to the following. The exact output will depend on whether you cloned BLIS from a `git` repository or whether you obtained BLIS via a downloadable tarball from the [releases](https://github.com/flame/blis/releases) page. ``` -$ ./configure haswell +$ ./configure --prefix=$HOME/blis haswell configure: using 'gcc' compiler. configure: found gcc version 5.4.0 (maj: 5, min: 4, rev: 0). configure: checking for blacklisted configurations due to gcc 5.4.0. 
@@ -166,17 +169,11 @@ The installation prefix can be specified via the `--prefix=PREFIX` option:
```
$ ./configure --prefix=/usr
```
-This will cause libraries to eventually be installed (via `make install`) to `PREFIX/lib` and development headers to be installed to `PREFIX/include`. (The default value of `PREFIX` is `$(HOME)/blis`.) You can also specify the library install directory separately from the development header install directory with the `--libdir=LIBDIR` and `--includedir=INCDIR` options, respectively:
+This will cause libraries to eventually be installed (via `make install`) to `PREFIX/lib` and development headers to be installed to `PREFIX/include`. (The default value of `PREFIX` is `/usr/local`.) You can also specify the library install directory separately from the development header install directory with the `--libdir=LIBDIR` and `--includedir=INCDIR` options, respectively:
```
$ ./configure --libdir=/usr/lib --includedir=/usr/include
```
-The `--libdir=LIBDIR` and `--includedir=INCDIR` options will override any `PREFIX` path, whether it was specified explicitly via `--prefix` or implicitly (via the default). That is, `LIBDIR` defaults to `PREFIX/lib` and `INCDIR` defaults to `PREFIX/include`, but each will be overriden by their respective `--libdir`/`--includedir` options. So,
-```
-$ ./configure --libdir=/usr/lib
-
-```
-will configure BLIS to install libraries to `/usr/lib` and header files to the default location (`$HOME/blis/include`).
-Also, note that `configure` will create any installation directories that do not already exist.
+The `--libdir=LIBDIR` and `--includedir=INCDIR` options will override any path implied by `PREFIX`, whether it was specified explicitly via `--prefix` or implicitly (via the default). That is, `LIBDIR` defaults to `EXECPREFIX/lib` (where `EXECPREFIX`, set via `--exec-prefix=EXECPREFIX`, defaults to `PREFIX`) and `INCDIR` defaults to `PREFIX/include`, but `LIBDIR` and `INCDIR` will each be overridden by their respective `--libdir`/`--includedir` options. There is a third related option, `--sharedir=SHAREDIR`, where `SHAREDIR` defaults to `PREFIX/share`. This option specifies the installation directory for certain makefile fragments that contain variables determined by `configure` (e.g. `CC`, `CFLAGS`, `LDFLAGS`, etc.). These files allow certain BLIS makefiles, such as those in the `examples` or `testsuite` directories, to operate on an installed copy of BLIS rather than a local (and possibly uninstalled) copy.

For a complete list of supported `configure` options and arguments, run `configure` with the `-h` option:
```
@@ -338,6 +335,47 @@ Removing include.

Running the `distclean` target is like saying, "Remove anything ever created by the build system."

+## Compiling with BLIS
+
+All BLIS definitions and prototypes may be included in your C source file by including a single header file, `blis.h`:
+```c
+#include "stdio.h"
+#include "stdlib.h"
+#include "otherstuff.h"
+#include "blis.h"
+```
+If the BLAS compatibility layer was enabled at configure-time (as it is by default), then `blis.h` will also provide BLAS prototypes to your source code.
+
+
+### Disabling BLAS prototypes
+
+Some applications already `#include` a header that contains BLAS prototypes. This can cause problems if those applications also try to `#include` the BLIS header file, as shown above. Suppose for a moment that `otherstuff.h` in the example above already provides BLAS prototypes.
+```
+$ gcc -I/path/to/blis -I/path/to/otherstuff -c main.c -o main.o
+In file included from main.c:41:0:
+/path/to/blis/blis.h:36900:111: error: conflicting declaration of C function ‘int xerbla_(const bla_character*, const bla_integer*, ftnlen)’
+ TEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len);
+```
+If your application is already declaring (prototyping) BLAS functions, then you may disable those prototypes from being defined within `blis.h`. This prevents `blis.h` from re-declaring those prototypes, or allows your other header to declare those functions for the first time, depending on the order in which you `#include` the headers.
+```c
+#include "stdio.h"
+#include "stdlib.h"
+#include "otherstuff.h"
+#define BLIS_DISABLE_BLAS_DEFS // disable BLAS prototypes within BLIS.
+#include "blis.h"
+```
+By `#defining` the `BLIS_DISABLE_BLAS_DEFS` macro, we signal to `blis.h` that it should skip over the BLAS prototypes, but otherwise `#include` everything else as it normally would. Note that `BLIS_DISABLE_BLAS_DEFS` must be `#defined` *prior* to the `#include "blis.h"` directive in order for it to have any effect.
+
+
+### CBLAS
+
+If you build BLIS with CBLAS enabled and you wish to access CBLAS function prototypes from within your application, you will have to `#include` the `cblas.h` header separately from `blis.h`.
+```c
+#include "blis.h"
+#include "cblas.h"
+```
+
+
## Linking against BLIS

Once you have instantiated (configured and compiled, and perhaps installed) a BLIS library, you can link to it in your application's makefile as you would any other library. The following is an abbreviated makefile for a small hypothetical application that has just two external dependencies: BLIS and the standard C math library. We also link against libpthread since that library has been a runtime dependency of BLIS since 70640a3 (December 2017).
@@ -357,7 +395,7 @@ OBJS = main.o util.o other.o

%.o: %.c
	$(CC) $(CFLAGS) -c $< -o $@

-all: $(OBJS)
+all: $(OBJS)
	$(LINKER) $(OBJS) $(BLIS_LIB) $(OTHER_LIBS) -o my_program.x
```
The above example assumes you will want to include BLIS definitions and function prototypes into your application via `#include "blis.h"`. (If you are only using BLIS via the BLAS compatibility layer, including `blis.h` is not necessary.) Since BLIS headers are installed into a `blis` subdirectory of `PREFIX/include`, you must make sure that the compiler knows where to find the `blis.h` header file. This is typically accomplished by inserting `#include "blis.h"` into your application's source code files and compiling the code with `-I PREFIX/include/blis`.
diff --git a/docs/HardwareSupport.md b/docs/HardwareSupport.md
index 41036d51c..adba02f19 100644
--- a/docs/HardwareSupport.md
+++ b/docs/HardwareSupport.md
@@ -12,8 +12,8 @@ The following table lists architectures for which there exist optimized level-3

A few remarks / reminders:
 * Optimizing only the [gemm microkernel](KernelsHowTo.md#gemm-microkernel) will result in optimal performance for all [level-3 operations](BLISTypedAPI.md#level-3-operations) except `trsm` (which will typically achieve 60 - 80% of attainable peak performance).
 * The [trsm](BLISTypedAPI.md#trsm) operation needs the [gemmtrsm microkernel(s)](KernelsHowTo.md#gemmtrsm-microkernels), in addition to the aforementioned [gemm microkernel](KernelsHowTo.md#gemm-microkernel), in order to reach optimal performance.
- * Induced complex (1m) implementations are employed in all situations where the real domain [gemm microkernel](KernelsHowTo.md#gemm-microkernel) of the corresponding precision is available. Please see our [ACM TOMS article on the 1m method](https://github.com/flame/blis#citations) for more info on this topic.
- * Some microarchitectures use the same sub-configuration. This is not a typo. For example, Haswell and Broadwell systems as well as "desktop" (non-server) versions of Skylake, Kabylake, and Coffeelake all use the `haswell` sub-configuration and the kernels registered therein.
+ * Induced complex (1m) implementations are employed in all situations where the real domain [gemm microkernel](KernelsHowTo.md#gemm-microkernel) of the corresponding precision is available, but the "native" complex domain gemm microkernel is unavailable. Note that the table below lists native kernels, so if a microarchitecture lists only `sd`, support for both `c` and `z` datatypes will be provided via the 1m method. (Note: most people cannot tell the difference between native and 1m-based performance.) Please see our [ACM TOMS article on the 1m method](https://github.com/flame/blis#citations) for more info on this topic.
+ * Some microarchitectures use the same sub-configuration. *This is not a typo.* For example, Haswell and Broadwell systems as well as "desktop" (non-server) versions of Skylake, Kaby Lake, and Coffee Lake all use the `haswell` sub-configuration and the kernels registered therein. Microkernels can be recycled in this manner because the key detail that determines level-3 performance outcomes is actually the vector ISA, not the microarchitecture. In the previous example, all of the microarchitectures listed support AVX2 (but not AVX-512), and therefore they can reuse the same microkernels.
 * Remember that you (usually) don't have to choose your sub-configuration manually! Instead, you can always request configure-time hardware detection via `./configure auto`. This will defer to internal logic (based on CPUID for x86_64 systems) that will attempt to choose the appropriate sub-configuration automatically.

| Vendor/Microarchitecture | BLIS sub-configuration | `gemm` | `gemmtrsm` |
@@ -26,7 +26,7 @@ A few remarks / reminders:
| Intel Core2 (SSE3) | `penryn` | `sd` | `d` |
| Intel Sandy/Ivy Bridge (AVX/FMA3) | `sandybridge` | `sdcz` | |
| Intel Haswell, Broadwell (AVX/FMA3) | `haswell` | `sdcz` | `sd` |
-| Intel Sky/Kaby/Coffeelake (AVX/FMA3) | `haswell` | `sdcz` | `sd` |
+| Intel Sky/Kaby/Coffee Lake (AVX/FMA3) | `haswell` | `sdcz` | `sd` |
| Intel Knights Landing (AVX-512/FMA3) | `knl` | `sd` | |
| Intel SkylakeX (AVX-512/FMA3) | `skx` | `sd` | |
| ARMv7 Cortex-A9 (NEON) | `cortex-a9` | `sd` | |
diff --git a/docs/Multithreading.md b/docs/Multithreading.md
index 7fff7357f..a2630b18d 100644
--- a/docs/Multithreading.md
+++ b/docs/Multithreading.md
@@ -23,11 +23,17 @@

# Introduction

-Our paper [Anatomy of High-Performance Many-Threaded Matrix Multiplication](https://github.com/flame/blis#citations), presented at IPDPS'14, identified 5 loops around the microkernel as opportunities for parallelization within level-3 operations such as `gemm`. Within BLIS, we have enabled parallelism for 4 of those loops and have extended it to the rest of the level-3 operations except for `trsm`.
+Our paper [Anatomy of High-Performance Many-Threaded Matrix Multiplication](https://github.com/flame/blis#citations), presented at IPDPS'14, identified five loops around the microkernel as opportunities for parallelization within level-3 operations such as `gemm`. Within BLIS, we have enabled parallelism for four of those loops, with the fifth planned for future work. This software architecture extends naturally to all level-3 operations except for `trsm`, where its application is necessarily limited to three of the five loops due to inter-iteration dependencies.
+
+**IMPORTANT**: Multithreading in BLIS is disabled by default. Furthermore, even when multithreading is enabled, BLIS will default to single-threaded execution at runtime. In order to both *allow* and *invoke* parallelism from within BLIS operations, you must both *enable* multithreading at configure-time and *specify* multithreading at runtime.
+
+To summarize: In order to observe multithreaded parallelism within a BLIS operation, you must do *both* of the following:
+1. Enable multithreading at configure-time. This is discussed in the [next section](Multithreading.md#enabling-multithreading).
+2. Specify multithreading at runtime. This is also discussed [later on](Multithreading.md#specifying-multithreading).

# Enabling multithreading

-Note that BLIS disables multithreading by default. In order to extract multithreaded parallelism from BLIS, you must first enable multithreading explicitly at configure-time.
+BLIS disables multithreading by default. In order to allow multithreaded parallelism from BLIS, you must first enable multithreading explicitly at configure-time.

As of this writing, BLIS optionally supports multithreading via either OpenMP or POSIX threads.

@@ -101,7 +107,7 @@ This pattern--automatic or manual--holds regardless of which of the three method

Regardless of which method is employed, and which specific way within each method, after setting the number of threads, the application may call the desired level-3 operation (via either the [typed API](BLISTypedAPI.md) or the [object API](BLISObjectAPI.md)) and the operation will execute in a multithreaded manner. (When calling BLIS via the BLAS API, only the first two (global) methods are available.)

-NOTE: Please be aware of what happens if you try to specify both the automatic and manual ways, as it could otherwise confuse new users. Regardless of which broad method is used, **if multithreading is specified via both the automatic and manual ways, the manual way will always take precedence.** Also, specifying parallelism for even *one* loop counts as specifying the manual way (in which case the ways of parallelism for the remaining loops will be assumed to be 1).
+**Note**: Please be aware of what happens if you try to specify both the automatic and manual ways, as it could otherwise confuse new users. Regardless of which broad method is used, **if multithreading is specified via both the automatic and manual ways, the manual way will always take precedence.** Also, specifying parallelism for even *one* loop counts as specifying the manual way (in which case the ways of parallelism for the remaining loops will be assumed to be 1).
## Globally via environment variables

@@ -109,6 +115,8 @@ The most common method of specifying multithreading in BLIS is globally via environ

Regardless of whether you end up using the automatic or manual way of expressing a request for multithreading, note that the environment variables are read (via `getenv()`) by BLIS **only once**, when the library is initialized. Subsequent to library initialization, the global settings for parallelization may only be changed via the [global runtime API](Multithreading.md#globally-at-runtime). If this constraint is not a problem, then environment variables may work fine for you. Otherwise, please consider [local settings](Multithreading.md#locally-at-runtime). (Local settings may be used at any time, regardless of whether global settings were explicitly specified, and local settings always override global settings.)

+**Note**: Regardless of which way ([automatic](Multithreading.md#environment-variables-the-automatic-way) or [manual](Multithreading.md#environment-variables-the-manual-way)) environment variables are used to specify multithreading, that specification will affect operation of BLIS through **both** the BLAS compatibility layer and the native [typed](BLISTypedAPI.md) and [object](BLISObjectAPI.md) APIs that are unique to BLIS.
+
### Environment variables: the automatic way

The automatic way of specifying parallelism entails simply setting the total number of threads you wish BLIS to employ in its parallelization. This total number of threads is captured by the `BLIS_NUM_THREADS` environment variable. You can set this variable prior to executing your BLIS-linked executable:
@@ -119,7 +127,7 @@ $ ./my_blis_program
```
This causes BLIS to automatically determine a reasonable threading strategy based on what is known about the operation and problem size. If `BLIS_NUM_THREADS` is not set, BLIS will attempt to query the value of `OMP_NUM_THREADS`. If neither variable is set, the default number of threads is 1.

-**Note:** We *highly* discourage use of the `OMP_NUM_THREADS` environment variable and may remove support for it in the future. If you wish to set parallelism globally via environment variables, please use `BLIS_NUM_THREADS`.
+**Note**: We *highly* discourage use of the `OMP_NUM_THREADS` environment variable and may remove support for it in the future. If you wish to set parallelism globally via environment variables, please use `BLIS_NUM_THREADS`.

### Environment variables: the manual way

@@ -127,7 +135,7 @@ The manual way of specifying parallelism involves communicating which loops with

The below chart describes the five loops used in BLIS's matrix multiplication operations.

-| Loop around microkernel | Environment variable | Direction | Notes |
+| Loop around microkernel | Environment variable | Direction | Notes |
|:-------------------------|:---------------------|:----------|:------------|
| 5th loop | `BLIS_JC_NT` | `n` | |
| 4th loop | _N/A_ | `k` | Not enabled |
@@ -154,6 +162,8 @@ Next, which combinations of loops to parallelize depends on which caches are shared

If you still wish to set the parallelization scheme globally, but you want to do so at runtime, BLIS provides a thread-safe API for specifying multithreading. Think of these functions as a way to modify the same internal data structure into which the environment variables are read. (Recall that the environment variables are only read once, when BLIS is initialized).
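+
+For a quick sense of how this runtime API is used, the following is a
+minimal sketch, assuming the `bli_thread_set_num_threads()` (automatic
+way) and `bli_thread_set_ways()` (manual way) functions covered in the
+remainder of this section:
+```c
+#include "blis.h"
+
+int main( void )
+{
+	// The automatic way: request eight total threads and let BLIS
+	// choose how to factor them among the loops.
+	bli_thread_set_num_threads( 8 );
+
+	// ... any level-3 operation called here may use up to 8 threads ...
+
+	// The manual way: assign ways of parallelism to the JC, PC, IC,
+	// JR, and IR loops, respectively. (The PC loop is not enabled
+	// and must be given 1.)
+	bli_thread_set_ways( 2, 1, 4, 1, 1 );
+
+	// ... subsequent level-3 operations may use 2 x 4 = 8 threads ...
+
+	return 0;
+}
+```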
+**Note**: Regardless of which way ([automatic](Multithreading.md#globally-at-runtime-the-automatic-way) or [manual](Multithreading.md#globally-at-runtime-the-manual-way)) the global runtime API is used to specify multithreading, that specification will affect operation of BLIS through **both** the BLAS compatibility layer and the native [typed](BLISTypedAPI.md) and [object](BLISObjectAPI.md) APIs that are unique to BLIS.
+
### Globally at runtime: the automatic way

If you simply want to specify an overall number of threads and let BLIS choose a thread factorization automatically, use the following function:
@@ -193,6 +203,8 @@ In addition to the global methods based on environment variables and runtime fun

As with environment variables and the global runtime API, there are two ways to specify parallelism: the automatic way and the manual way. Both ways involve allocating a BLIS-specific object, initializing the object and encoding the desired parallelization, and then passing a pointer to the object into one of the expert interfaces of either the [typed](BLISTypedAPI.md) or [object](BLISObjectAPI.md) APIs. We provide examples of utilizing this threading object below.

+**Note**: Neither way ([automatic](Multithreading.md#locally-at-runtime-the-automatic-way) nor [manual](Multithreading.md#locally-at-runtime-the-manual-way)) of specifying multithreading via the local runtime API can be used via the BLAS interfaces. The local runtime API may *only* be used via the native [typed](BLISTypedAPI.md) and [object](BLISObjectAPI.md) APIs, which are unique to BLIS. (Furthermore, the expert interfaces of each API must be used. This is demonstrated later on in this section.)
+
### Initializing a rntm_t

Before specifying the parallelism (automatically or manually), you must first allocate a special BLIS object called a `rntm_t` (runtime). The object is quite small (about 64 bytes), and so we recommend allocating it statically on the function stack:
diff --git a/docs/Performance.md b/docs/Performance.md
new file mode 100644
index 000000000..e51028c49
--- /dev/null
+++ b/docs/Performance.md
@@ -0,0 +1,394 @@
+# Contents
+
+* **[Contents](Performance.md#contents)**
+* **[Introduction](Performance.md#introduction)**
+* **[General information](Performance.md#general-information)**
+* **[Level-3 performance](Performance.md#level-3-performance)**
+ * **[ThunderX2](Performance.md#thunderx2)**
+ * **[Experiment details](Performance.md#thunderx2-experiment-details)**
+ * **[Results](Performance.md#thunderx2-results)**
+ * **[SkylakeX](Performance.md#skylakex)**
+ * **[Experiment details](Performance.md#skylakex-experiment-details)**
+ * **[Results](Performance.md#skylakex-results)**
+ * **[Haswell](Performance.md#haswell)**
+ * **[Experiment details](Performance.md#haswell-experiment-details)**
+ * **[Results](Performance.md#haswell-results)**
+ * **[Epyc](Performance.md#epyc)**
+ * **[Experiment details](Performance.md#epyc-experiment-details)**
+ * **[Results](Performance.md#epyc-results)**
+* **[Feedback](Performance.md#feedback)**
+
+# Introduction
+
+This document showcases performance results for a representative sample of
+level-3 operations on large matrices with BLIS and other BLAS libraries for
+several hardware architectures.
+
+# General information
+
+Generally speaking, for level-3 operations on large matrices, we publish three
+"panels" for each type of hardware,
+each of which reports one of: single-threaded performance, multithreaded
+performance on a single socket, or multithreaded performance on two sockets.
+Each panel will consist of a 4x5 grid of graphs, with each row representing
+a different datatype (single real, double real, single complex, and double
+complex) and each column representing a different operation (`gemm`,
+`hemm`/`symm`, `herk`/`syrk`, `trmm`, and `trsm`).
+Each of the 20 graphs within a panel will contain an x-axis that reports
+problem size, with all matrix dimensions equal to the problem size (i.e.,
+_m_ = _n_ = _k_), resulting in square matrices.
+The y-axis will report in units of GFLOPS (billions of floating-point
+operations per second) in the case of single-threaded performance, or
+GFLOPS/core in the case of single- or dual-socket multithreaded performance,
+where GFLOPS/core is simply the total GFLOPS observed divided by the number
+of threads utilized.
+This normalization is done intentionally in order to facilitate a visual
+assessment of the drop in efficiency of multithreaded performance relative
+to the single-threaded baseline.
+
+It's also worth pointing out that the top of each graph (i.e., the maximum
+y-axis value depicted) _always_ corresponds to the theoretical peak performance
+under the conditions associated with that graph.
+Theoretical peak performance, in units of GFLOPS/core, is calculated as the
+product of:
+1. the maximum sustainable clock rate in GHz; and
+2. the maximum number of floating-point operations (flops) that can be
+executed per cycle (per core).
+
+Note that the maximum sustainable clock rate may change depending on the
+conditions.
+For example, on some systems the maximum clock rate is higher when only one
+core is active (e.g. single-threaded performance) versus when all cores are
+active (e.g. multithreaded performance).
+The maximum number of flops executable per cycle (per core) is generally
+computed as the product of:
+1. the maximum number of fused multiply-add (FMA) vector instructions that
+can be issued per cycle (per core);
+2. the maximum number of elements that can be stored within a single vector
+register (for the datatype in question); and
+3. 2.0, since an FMA instruction fuses two operations (a multiply and an add).
+
+For example, for the Haswell system detailed below, this works out to
+3.5 GHz x 2 FMA instructions per cycle x 4 doubles per 256-bit vector x
+2 flops per FMA = 56 GFLOPS of single-core double-precision peak.
+
+The problem size range, represented on the x-axis, is usually sampled with 50
+equally-spaced problem sizes.
+For example, for single-threaded execution, we might choose to execute with
+problem sizes of 48 to 2400 in increments of 48, or 56 to 2800 in increments
+of 56.
+These values are almost never chosen for any particular (read: sneaky) reason;
+rather, we start with a "good" maximum problem size, such as 2400 or 2800, and
+then divide it by 50 to obtain the appropriate starting point and increment.
+
+Finally, each point along each curve represents the best of three trials.
+
+# Interpretation
+
+In general, the curves associated with higher-performing implementations
+will appear higher in the graphs than those of lower-performing
+implementations.
+Ideally, an implementation will climb in performance (as a function of problem
+size) as quickly as possible and asymptotically approach some high fraction of
+peak performance.
+ +Occasionally, we may publish graphs with incomplete curves--for example, +only the first 25 data points in a typical 50-point series--usually because +the implementation being tested was slow enough that it was not practical to +allow it to finish. + +Where along the x-axis you focus your attention will depend on the segment of +the problem size range that you care about most. Some people's applications +depend heavily on smaller problems, where "small" can mean anything from 10 +to 1000 or even higher. Some people consider 1000 to be quite large, while +others insist that 5000 is merely "medium." What each of us considers to be +small, medium, or large (naturally) depends heavily on the kinds of dense +linear algebra problems we tend to encounter. No one is "right" or "wrong" +about their characterization of matrix smallness or bigness since each person's +relative frame of reference can vary greatly. That said, the +[Science of High-Performance Computing](http://shpc.ices.utexas.edu/) group at +[The University of Texas at Austin](https://www.utexas.edu/) tends to target +matrices that it classifies as "medium-to-large", and so most of the graphs +presented in this document will reflect that targeting in their x-axis range. + +When corresponding with us, via email or when opening an +[issue](https://github.com/flame/blis/issues) on github, we kindly ask that +you specify as closely as possible (though a range is fine) your problem +size of interest so that we can better assist you. + +# Level-3 performance + +## ThunderX2 + +### ThunderX2 experiment details + +* Location: Unknown +* Processor model: Marvell ThunderX2 CN9975 +* Core topology: two sockets, 28 cores per socket, 56 cores total +* SMT status: disabled at boot-time +* Max clock rate: 2.2GHz (single-core and multicore) +* Max vector register length: 128 bits (NEON) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 17.6 GFLOPS (double-precision), 35.2 GFLOPS (single-precision) + * multicore: 17.6 GFLOPS/core (double-precision), 35.2 GFLOPS/core (single-precision) +* Operating system: Ubuntu 16.04 (Linux kernel 4.15.0) +* Compiler: gcc 7.3.0 +* Results gathered: 14 February 2019 +* Implementations tested: + * BLIS 075143df (0.5.1-39) + * configured with `./configure -t openmp thunderx2` (single- and multithreaded) + * sub-configuration exercised: `thunderx2` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (28 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=7` + * Multithreaded (56 core) execution requested via `export BLIS_JC_NT=8 BLIS_IC_NT=7` + * OpenBLAS 52d3f7a + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=56` (multithreaded, 56 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (28 core) execution requested via `export OPENBLAS_NUM_THREADS=28` + * Multithreaded (56 core) execution requested via `export OPENBLAS_NUM_THREADS=56` + * ARMPL 18.4 + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (28 core) execution requested via `export OMP_NUM_THREADS=28` + * Multithreaded (56 core) execution requested via `export OMP_NUM_THREADS=56` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 55"`. 
However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * No changes made. +* Comments: + * ARMPL performance is remarkably uneven across datatypes and operations, though it would appear their "base" consists of OpenBLAS, which they then optimize for select, targeted routines. Unfortunately, we were unable to test the absolute latest versions of OpenBLAS and ARMPL on this hardware before we lost access. We will rerun these experiments once we gain access to a similar system. + +### ThunderX2 results + +#### pdf + +* [ThunderX2 single-threaded](graphs/large/l3_perf_tx2_nt1.pdf) +* [ThunderX2 multithreaded (28 cores)](graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf) +* [ThunderX2 multithreaded (56 cores)](graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf) + +#### png (inline) + +* **ThunderX2 single-threaded** +![single-threaded](graphs/large/l3_perf_tx2_nt1.png) +* **ThunderX2 multithreaded (28 cores)** +![multithreaded (28 cores)](graphs/large/l3_perf_tx2_jc4ic7_nt28.png) +* **ThunderX2 multithreaded (56 cores)** +![multithreaded (56 cores)](graphs/large/l3_perf_tx2_jc8ic7_nt56.png) + +--- + +## SkylakeX + +### SkylakeX experiment details + +* Location: Oracle cloud +* Processor model: Intel Xeon Platinum 8167M (SkylakeX/AVX-512) +* Core topology: two sockets, 26 cores per socket, 52 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 2.0GHz (single-core and multicore) +* Max vector register length: 512 bits (AVX-512) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 64 GFLOPS (double-precision), 128 GFLOPS (single-precision) + * multicore: 64 GFLOPS/core (double-precision), 128 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Compiler: gcc 7.3.0 +* Results gathered: 6 March 2019, 27 March 2019 +* Implementations tested: + * BLIS 9f1dbe5 (0.5.1-54) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `skx` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (26 core) execution requested via `export BLIS_JC_NT=2 BLIS_IC_NT=13` + * Multithreaded (52 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=13` + * OpenBLAS 0.3.5 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=52` (multithreaded, 52 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (26 core) execution requested via `export OPENBLAS_NUM_THREADS=26` + * Multithreaded (52 core) execution requested via `export OPENBLAS_NUM_THREADS=52` + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (March 27, 2019) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal). + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. 
+ * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (26 core) execution requested via `export OMP_NUM_THREADS=26` + * Multithreaded (52 core) execution requested via `export OMP_NUM_THREADS=52` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * MKL 2019 update 1 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (26 core) execution requested via `export MKL_NUM_THREADS=26` + * Multithreaded (52 core) execution requested via `export MKL_NUM_THREADS=52` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 51"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits: 1.0GHz - 2.0GHz + * Adjusted minimum: 2.0GHz +* Comments: + * MKL yields superb performance for most operations, though BLIS is not far behind except for `trsm`. (We understand the `trsm` underperformance and hope to address it in the future.) OpenBLAS lags far behind MKL and BLIS due to lack of full support for AVX-512, and possibly other reasons related to software architecture and register/cache blocksizes. + +### SkylakeX results + +#### pdf + +* [SkylakeX single-threaded](graphs/large/l3_perf_skx_nt1.pdf) +* [SkylakeX multithreaded (26 cores)](graphs/large/l3_perf_skx_jc2ic13_nt26.pdf) +* [SkylakeX multithreaded (52 cores)](graphs/large/l3_perf_skx_jc4ic13_nt52.pdf) + +#### png (inline) + +* **SkylakeX single-threaded** +![single-threaded](graphs/large/l3_perf_skx_nt1.png) +* **SkylakeX multithreaded (26 cores)** +![multithreaded (26 cores)](graphs/large/l3_perf_skx_jc2ic13_nt26.png) +* **SkylakeX multithreaded (52 cores)** +![multithreaded (52 cores)](graphs/large/l3_perf_skx_jc4ic13_nt52.png) + +--- + +## Haswell + +### Haswell experiment details + +* Location: TACC (Lonestar5) +* Processor model: Intel Xeon E5-2690 v3 (Haswell) +* Core topology: two sockets, 12 cores per socket, 24 cores total +* SMT status: enabled, but not utilized +* Max clock rate: 3.5GHz (single-core), 3.1GHz (multicore) +* Max vector register length: 256 bits (AVX2) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 56 GFLOPS (double-precision), 112 GFLOPS (single-precision) + * multicore: 49.6 GFLOPS/core (double-precision), 99.2 GFLOPS/core (single-precision) +* Operating system: Cray Linux Environment 6 (Linux kernel 4.4.103) +* Compiler: gcc 6.3.0 +* Results gathered: 25-26 February 2019, 27 March 2019 +* Implementations tested: + * BLIS 075143df (0.5.1-39) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `haswell` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (12 core) execution requested via `export BLIS_JC_NT=2 BLIS_IC_NT=3 BLIS_JR_NT=2` + * Multithreaded (24 core) execution requested via `export BLIS_JC_NT=4 BLIS_IC_NT=3 BLIS_JR_NT=2` + * OpenBLAS 0.3.5 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 
USE_THREAD=1 NUM_THREADS=24` (multithreaded, 24 cores)
+ * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1`
+ * Multithreaded (12 core) execution requested via `export OPENBLAS_NUM_THREADS=12`
+ * Multithreaded (24 core) execution requested via `export OPENBLAS_NUM_THREADS=24`
+ * Eigen 3.3.90
+ * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (March 27, 2019)
+ * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal).
+ * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas`
+ * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library.
+ * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1`
+ * Multithreaded (12 core) execution requested via `export OMP_NUM_THREADS=12`
+ * Multithreaded (24 core) execution requested via `export OMP_NUM_THREADS=24`
+ * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs.
+ * MKL 2018 update 2
+ * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1`
+ * Multithreaded (12 core) execution requested via `export MKL_NUM_THREADS=12`
+ * Multithreaded (24 core) execution requested via `export MKL_NUM_THREADS=24`
+* Affinity:
+ * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 23"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset.
+* Frequency throttling (via `cpupower`):
+ * No changes made.
+* Comments:
+ * We were pleasantly surprised by how competitively BLIS performs relative to MKL on this multicore Haswell system, which is a _very_ common microarchitecture, and _very_ similar to the more recent Broadwells, Skylakes (desktop), Kaby Lakes, and Coffee Lakes that succeeded it.
+
+### Haswell results
+
+#### pdf
+
+* [Haswell single-threaded](graphs/large/l3_perf_has_nt1.pdf)
+* [Haswell multithreaded (12 cores)](graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf)
+* [Haswell multithreaded (24 cores)](graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf)
+
+#### png (inline)
+
+* **Haswell single-threaded**
+![single-threaded](graphs/large/l3_perf_has_nt1.png)
+* **Haswell multithreaded (12 cores)**
+![multithreaded (12 cores)](graphs/large/l3_perf_has_jc2ic3jr2_nt12.png)
+* **Haswell multithreaded (24 cores)**
+![multithreaded (24 cores)](graphs/large/l3_perf_has_jc4ic3jr2_nt24.png)
+
+---
+
+## Epyc
+
+### Epyc experiment details
+
+* Location: Oracle cloud
+* Processor model: AMD Epyc 7551 (Zen1)
+* Core topology: two sockets, 4 dies per socket, 2 core complexes (CCX) per die, 4 cores per CCX, 64 cores total
+* SMT status: enabled, but not utilized
+* Max clock rate: 3.0GHz (single-core), 2.55GHz (multicore)
+* Max vector register length: 256 bits (AVX2)
+* Max FMA vector IPC: 1
+ * Alternatively, FMA vector IPC is 2 when vectors are limited to 128 bits each.
+* Peak performance: + * single-core: 24 GFLOPS (double-precision), 48 GFLOPS (single-precision) + * multicore: 20.4 GFLOPS/core (double-precision), 40.8 GFLOPS/core (single-precision) +* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0) +* Compiler: gcc 7.3.0 +* Results gathered: 6 March 2019, 19 March 2019, 27 March 2019 +* Implementations tested: + * BLIS 9f1dbe5 (0.5.1-54) + * configured with `./configure -t openmp auto` (single- and multithreaded) + * sub-configuration exercised: `zen` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (32 core) execution requested via `export BLIS_JC_NT=1 BLIS_IC_NT=8 BLIS_JR_NT=4` + * Multithreaded (64 core) execution requested via `export BLIS_JC_NT=2 BLIS_IC_NT=8 BLIS_JR_NT=4` + * OpenBLAS 0.3.5 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=64` (multithreaded, 64 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (32 core) execution requested via `export OPENBLAS_NUM_THREADS=32` + * Multithreaded (64 core) execution requested via `export OPENBLAS_NUM_THREADS=64` + * Eigen 3.3.90 + * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (March 27, 2019) + * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal). + * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` + * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library. + * Single-threaded (1 core) execution requested via `export OMP_NUM_THREADS=1` + * Multithreaded (32 core) execution requested via `export OMP_NUM_THREADS=32` + * Multithreaded (64 core) execution requested via `export OMP_NUM_THREADS=64` + * **NOTE**: This version of Eigen does not provide multithreaded implementations of `symm`/`hemm`, `syrk`/`herk`, `trmm`, or `trsm`, and therefore those curves are omitted from the multithreaded graphs. + * MKL 2019 update 1 + * Single-threaded (1 core) execution requested via `export MKL_NUM_THREADS=1` + * Multithreaded (32 core) execution requested via `export MKL_NUM_THREADS=32` + * Multithreaded (64 core) execution requested via `export MKL_NUM_THREADS=64` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0 1 2 3 ... 63"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * Driver: acpi-cpufreq + * Governor: performance + * Hardware limits: 1.2GHz - 2.0GHz + * Adjusted minimum: 2.0GHz +* Comments: + * MKL performance is dismal, despite being linked in the same manner as on the Xeon Platinum. It's not clear what is causing the slowdown. It could be that MKL's runtime kernel/blocksize selection logic is falling back to some older, more basic implementation because CPUID is not returning Intel as the hardware vendor. 
Alternatively, it's possible that MKL is trying to use kernels for the closest Intel architectures--say, Haswell/Broadwell--but its implementations use Haswell-specific optimizations that, due to microarchitectural differences, degrade performance on Zen.
+
+### Epyc results
+
+#### pdf
+
+* [Epyc single-threaded](graphs/large/l3_perf_epyc_nt1.pdf)
+* [Epyc multithreaded (32 cores)](graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.pdf)
+* [Epyc multithreaded (64 cores)](graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.pdf)
+
+#### png (inline)
+
+* **Epyc single-threaded**
+![single-threaded](graphs/large/l3_perf_epyc_nt1.png)
+* **Epyc multithreaded (32 cores)**
+![multithreaded (32 cores)](graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.png)
+* **Epyc multithreaded (64 cores)**
+![multithreaded (64 cores)](graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.png)
+
+---
+
+# Feedback
+
+Please let us know what you think of these performance results! Similarly, if you have any questions or concerns, or are interested in reproducing these performance experiments on your own hardware, we invite you to [open an issue](https://github.com/flame/blis/issues) and start a conversation with BLIS developers.
+
+Thanks for your interest in BLIS!
+
diff --git a/docs/PerformanceSmall.md b/docs/PerformanceSmall.md
new file mode 100644
index 000000000..51f0498b2
--- /dev/null
+++ b/docs/PerformanceSmall.md
@@ -0,0 +1,224 @@
+# Contents
+
+* **[Contents](PerformanceSmall.md#contents)**
+* **[Introduction](PerformanceSmall.md#introduction)**
+* **[General information](PerformanceSmall.md#general-information)**
+* **[Level-3 performance](PerformanceSmall.md#level-3-performance)**
+ * **[Kaby Lake](PerformanceSmall.md#kaby-lake)**
+ * **[Experiment details](PerformanceSmall.md#kaby-lake-experiment-details)**
+ * **[Results](PerformanceSmall.md#kaby-lake-results)**
+ * **[Epyc](PerformanceSmall.md#epyc)**
+ * **[Experiment details](PerformanceSmall.md#epyc-experiment-details)**
+ * **[Results](PerformanceSmall.md#epyc-results)**
+* **[Feedback](PerformanceSmall.md#feedback)**
+
+# Introduction
+
+This document showcases performance results for the level-3 `gemm` operation
+on small matrices with BLIS and other BLAS libraries for select hardware
+architectures.
+
+# General information
+
+Generally speaking, for level-3 operations on small matrices, we publish
+two "panels" for each type of hardware, one that reflects performance on
+row-stored matrices and another for column-stored matrices.
+Each panel will consist of a 4x7 grid of graphs, with each row representing
+a different transposition case (`nn`, `nt`, `tn`, `tt`)
+and each column representing a different shape scenario, usually
+with one or two matrix dimensions bound to a fixed size for all problem
+sizes tested.
+Each of the 28 graphs within a panel will contain an x-axis that reports
+problem size, with one, two, or all three matrix dimensions equal to the
+problem size (e.g. _m_ = 6; _n_ = _k_, also encoded as `m6npkp`).
+The y-axis will report in units of GFLOPS (billions of floating-point
+operations per second) on a single core.
+
+It's also worth pointing out that the top of each graph (i.e., the maximum
+y-axis value depicted) _always_ corresponds to the theoretical peak performance
+under the conditions associated with that graph.
+Theoretical peak performance, in units of GFLOPS, is calculated as the
+product of:
+1. the maximum sustainable clock rate in GHz; and
+2. the maximum number of floating-point operations (flops) that can be
+executed per cycle.
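+
+For example, using the Epyc figures reported later in this document (a
+3.0GHz maximum single-core clock rate; one 256-bit FMA instruction per
+cycle, i.e., 8 double-precision flops per cycle, per the breakdown that
+follows), single-core double-precision peak works out to
+
+$$ 3.0\ \text{GHz} \times 8\ \tfrac{\text{flops}}{\text{cycle}} = 24\ \text{GFLOPS}. $$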
+
+Note that the maximum sustainable clock rate may change depending on the
+conditions.
+For example, on some systems the maximum clock rate is higher when only one
+core is active (e.g. single-threaded performance) versus when all cores are
+active (e.g. multithreaded performance).
+The maximum number of flops executable per cycle (per core) is generally
+computed as the product of:
+1. the maximum number of fused multiply-add (FMA) vector instructions that
+can be issued per cycle (per core);
+2. the maximum number of elements that can be stored within a single vector
+register (for the datatype in question); and
+3. 2.0, since an FMA instruction fuses two operations (a multiply and an add).
+
+The problem size range, represented on the x-axis, is sampled in
+increments of 4 up to 800 for the cases where one or two dimensions are small
+(and constant)
+and up to 400 in the case where all dimensions (e.g. _m_, _n_, and _k_) are
+bound to the problem size (i.e., square matrices).
+
+Note that the constant small matrix dimensions were chosen to be _very_
+small--in the neighborhood of 8--intentionally to showcase what happens when
+at least one of the matrices is abnormally "skinny." Typically, organizations
+and individuals only publish performance with square matrices, which can miss
+the problem sizes of interest to many applications. Here, in addition to square
+matrices (shown in the seventh column), we also show six other scenarios where
+one or two `gemm` dimensions (of _m_, _n_, and _k_) are small.
+
+The legend in each graph contains two entries for BLIS, corresponding to the
+two black lines, one solid and one dotted. The dotted line, **"BLIS conv"**,
+represents the conventional implementation that targets large matrices. This
+was the only implementation available in BLIS prior to the addition of the
+small/skinny matrix support. The solid line, **"BLIS sup"**, makes use of the
+new small/skinny matrix implementation for certain small problems. Whenever
+these results differ by any significant amount (beyond noise), it denotes a
+problem size for which BLIS employed the new small/skinny implementation.
+Put another way, **the delta between these two lines represents the performance
+improvement between BLIS's previous status quo and the new regime.**
+
+Finally, each point along each curve represents the best of three trials.
+
+# Interpretation
+
+In general, the curves associated with higher-performing implementations
+will appear higher in the graphs than those of lower-performing
+implementations.
+Ideally, an implementation will climb in performance (as a function of problem
+size) as quickly as possible and asymptotically approach some high fraction of
+peak performance.
+
+When corresponding with us, via email or when opening an
+[issue](https://github.com/flame/blis/issues) on github, we kindly ask that
+you specify as closely as possible (though a range is fine) your problem
+size of interest so that we can better assist you.
+
+# Level-3 performance
+
+## Kaby Lake
+
+### Kaby Lake experiment details
+
+* Location: undisclosed
+* Processor model: Intel Core i5-7500 (Kaby Lake)
+* Core topology: one socket, 4 cores total
+* SMT status: unavailable
+* Max clock rate: 3.8GHz (single-core)
+* Max vector register length: 256 bits (AVX2)
+* Max FMA vector IPC: 2
+* Peak performance:
+ * single-core: 57.6 GFLOPS (double-precision), 115.2 GFLOPS (single-precision)
+* Operating system: Gentoo Linux (Linux kernel 5.0.7)
+* Compiler: gcc 7.3.0
+* Results gathered: 31 May 2019, 3 June 2019, 19 June 2019
+* Implementations tested:
+ * BLIS 6bf449c (0.5.2-42)
+ * configured with `./configure --enable-cblas auto`
+ * sub-configuration exercised: `haswell`
+ * OpenBLAS 0.3.6
+ * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded)
+ * BLASFEO 2c9f312
+ * configured `Makefile.rule` with: `BLAS_API=1 FORTRAN_BLAS_API=1 CBLAS_API=1`.
+ * Eigen 3.3.90
+ * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (30 May 2019)
+ * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal).
+ * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas`
+ * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library.
+ * Requested threading via `export OMP_NUM_THREADS=1` (single-threaded)
+ * MKL 2018 update 4
+ * Requested threading via `export MKL_NUM_THREADS=1` (single-threaded)
+* Affinity:
+ * N/A.
+* Frequency throttling (via `cpupower`):
+ * Driver: intel_pstate
+ * Governor: performance
+ * Hardware limits: 800MHz - 3.8GHz
+ * Adjusted minimum: 3.7GHz
+* Comments:
+ * For both row- and column-stored matrices, BLIS's new small/skinny matrix implementation is competitive with (or exceeds the performance of) the next highest-performing solution (typically MKL), except for a few cases where the _k_ dimension is very small. It is likely the case that this shape scenario begs a different kernel approach, since the BLIS microkernel is inherently designed to iterate over many _k_ dimension iterations (which leads it to incur considerable overhead for small values of _k_).
+ * For the classic case of `dgemm_nn` on square matrices, BLIS is the fastest implementation for the problem size range of approximately 80 to 180. BLIS is also competitive in this general range for other transpose parameter combinations (`nt`, `tn`, and `tt`).
+
+### Kaby Lake results
+
+#### pdf
+
+* [Kaby Lake row-stored](graphs/sup/dgemm_rrr_kbl_nt1.pdf)
+* [Kaby Lake column-stored](graphs/sup/dgemm_ccc_kbl_nt1.pdf)
+
+#### png (inline)
+
+* **Kaby Lake row-stored**
+![row-stored](graphs/sup/dgemm_rrr_kbl_nt1.png)
+* **Kaby Lake column-stored**
+![column-stored](graphs/sup/dgemm_ccc_kbl_nt1.png)
+
+---
+
+## Epyc
+
+### Epyc experiment details
+
+* Location: Oracle cloud
+* Processor model: AMD Epyc 7551 (Zen1)
+* Core topology: two sockets, 4 dies per socket, 2 core complexes (CCX) per die, 4 cores per CCX, 64 cores total
+* SMT status: enabled, but not utilized
+* Max clock rate: 3.0GHz (single-core), 2.55GHz (multicore)
+* Max vector register length: 256 bits (AVX2)
+* Max FMA vector IPC: 1
+ * Alternatively, FMA vector IPC is 2 when vectors are limited to 128 bits each.
+* Peak performance:
+ * single-core: 24 GFLOPS (double-precision), 48 GFLOPS (single-precision)
+* Operating system: Ubuntu 18.04 (Linux kernel 4.15.0)
+* Compiler: gcc 7.3.0
+* Results gathered: 31 May 2019, 3 June 2019, 19 June 2019
+* Implementations tested:
+ * BLIS 6bf449c (0.5.2-42)
+ * configured with `./configure --enable-cblas auto`
+ * sub-configuration exercised: `zen`
+ * OpenBLAS 0.3.6
+ * configured `Makefile.rule` with `BINARY=64 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded)
+ * BLASFEO 2c9f312
+ * configured `Makefile.rule` with: `BLAS_API=1 FORTRAN_BLAS_API=1 CBLAS_API=1`.
+ * Eigen 3.3.90
+ * Obtained via the [Eigen git mirror](https://github.com/eigenteam/eigen-git-mirror) (30 May 2019)
+ * Prior to compilation, modified top-level `CMakeLists.txt` to ensure that `-march=native` was added to `CXX_FLAGS` variable (h/t Sameer Agarwal).
+ * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas`
+ * The `gemm` implementation was pulled in at compile-time via Eigen headers; other operations were linked to Eigen's BLAS library.
+ * Requested threading via `export OMP_NUM_THREADS=1` (single-threaded)
+ * MKL 2019 update 4
+ * Requested threading via `export MKL_NUM_THREADS=1` (single-threaded)
+* Affinity:
+ * N/A.
+* Frequency throttling (via `cpupower`):
+ * Driver: acpi-cpufreq
+ * Governor: performance
+ * Hardware limits: 1.2GHz - 2.0GHz
+ * Adjusted minimum: 2.0GHz
+* Comments:
+ * As with Kaby Lake, BLIS's new small/skinny matrix implementation is competitive with (or exceeds the performance of) the next highest-performing solution, except for a few cases where the _k_ dimension is very small.
+ * For the classic case of `dgemm_nn` on square matrices, BLIS is the fastest implementation for the problem size range of approximately 12 to 256. BLIS is also competitive in this general range for other transpose parameter combinations (`nt`, `tn`, and `tt`).
+
+### Epyc results
+
+#### pdf
+
+* [Epyc row-stored](graphs/sup/dgemm_rrr_epyc_nt1.pdf)
+* [Epyc column-stored](graphs/sup/dgemm_ccc_epyc_nt1.pdf)
+
+#### png (inline)
+
+* **Epyc row-stored**
+![row-stored](graphs/sup/dgemm_rrr_epyc_nt1.png)
+* **Epyc column-stored**
+![column-stored](graphs/sup/dgemm_ccc_epyc_nt1.png)
+
+---
+
+# Feedback
+
+Please let us know what you think of these performance results! Similarly, if you have any questions or concerns, or are interested in reproducing these performance experiments on your own hardware, we invite you to [open an issue](https://github.com/flame/blis/issues) and start a conversation with BLIS developers.
+
+Thanks for your interest in BLIS!
+
diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md
index 193de9342..d1a6baece 100644
--- a/docs/ReleaseNotes.md
+++ b/docs/ReleaseNotes.md
@@ -4,6 +4,8 @@

## Contents

+* [Changes in 0.6.0](ReleaseNotes.md#changes-in-060)
+* [Changes in 0.5.2](ReleaseNotes.md#changes-in-052)
* [Changes in 0.5.1](ReleaseNotes.md#changes-in-051)
* [Changes in 0.5.0](ReleaseNotes.md#changes-in-050)
* [Changes in 0.4.1](ReleaseNotes.md#changes-in-041)
@@ -33,6 +35,70 @@
* [Changes in 0.0.2](ReleaseNotes.md#changes-in-002)
* [Changes in 0.0.1](ReleaseNotes.md#changes-in-001)

+## Changes in 0.6.0
+June 3, 2019
+
+Improvements present in 0.6.0:
+
+Framework:
+- Implemented small/skinny/unpacked (sup) framework for accelerated level-3 performance when at least one matrix dimension is small (or very small).
For now, only `dgemm` is optimized, and this new implementation currently only targets Intel Haswell through Coffee Lake, and AMD Zen-based Ryzen/Epyc. (The existing kernels should extend without significant modification to Zen2-based Ryzen/Epyc once they are available.) Also, multithreaded parallelism is not yet implemented, though application-level threading should be fine. (AMD)
+- Changed function pointer usages of `void*` to new, typedef'ed type `void_fp`.
+- Allow compile-time disabling of BLAS prototypes in BLIS, in case the application already has access to prototypes.
+- In `bli_system.h`, define `_POSIX_C_SOURCE` to `200809L` if the macro is not already defined. This ensures that things such as pthreads are properly defined by an application that has `#include "blis.h"` but omits the definition of `_POSIX_C_SOURCE` from the command-line compiler options. (Christos Psarras)
+
+Kernels:
+- None.
+
+Build system:
+- Updated the way configure and the top-level Makefile handle installation prefixes (`prefix`, `exec_prefix`, `libdir`, `includedir`, `sharedir`) to better conform with GNU conventions.
+- Improved clang version detection. (Isuru Fernando)
+- Use pthreads on MinGW and Cygwin. (Isuru Fernando)
+
+Testing:
+- Added Eigen support to test drivers in `test/3`.
+- Fix inadvertently hidden `xerbla_()` in blastest drivers when building only shared libraries. (Isuru Fernando, M. Zhou)
+
+Documentation:
+- Added `docs/PerformanceSmall.md` to showcase new BLIS small/skinny `dgemm` performance on Kaby Lake and Epyc.
+- Added Eigen results (3.3.90) to performance graphs showcased in `docs/Performance.md`.
+- Added BLIS thread factorization info to `docs/Performance.md`.
+
+## Changes in 0.5.2
+March 19, 2019
+
+Improvements present in 0.5.2:
+
+Framework:
+- Added support for IC loop parallelism to the `trsm` operation.
+- Implemented a pool-based small block allocator and a corresponding `configure` option (enabled by default), which minimizes the number of calls to `malloc()` and `free()` for the purposes of allocating small blocks (on the order of 100 bytes). These small blocks are used by internal data structures, and the repeated allocation and freeing of these structures could, perhaps, cause memory fragmentation issues in certain application circumstances. This was never reproduced or observed, however, and remains entirely theoretical. Still, the sba should be no slower, and perhaps a little faster, than repeatedly calling `malloc()` and `free()` for these internal data structures. Also, the sba was designed to be thread-safe. (AMD)
+- Refined and extended the output enabled by `--enable-mem-tracing`, which allows a developer to follow memory allocation and release performed by BLIS.
+- Initialize error messages at compile-time rather than at runtime. (Minh Quan Ho)
+- Fixed a potential situation whereby the multithreading parameters in a `rntm_t` object that is passed into an expert interface are ignored.
+- Prevent a redefinition of `ftnlen` in the `f2c_types.h` in blastest. (Jeff Diamond)
+
+Kernels:
+- Adjusted the cache blocksizes in the `zen` sub-configuration for `float`, `scomplex`, and `dcomplex` datatypes. The previous values, taken directly from the `haswell` subconfig, were merely meant to be reasonable placeholders until more suitable values were determined, as had already taken place for the `double` datatype.
(AMD) +- Rewrote reference kernels in terms of simplified indexing annotated by the `#pragma omp simd` directive, which a compiler can use to vectorize certain constant-bounded loops. The `#pragma` is disabled via a preprocessor macro layer if the compiler is found by `configure` to not support `-fopenmp-simd`. (Devin Matthews, Jeff Hammond) + +Build system: +- Added symbol-export annotation macros to all of the function prototypes and global variable declarations for public symbols, and created a new `configure` option, `--export-shared=[public|all]`, that controls which symbols--only those that are meant to be public, or all symbols--are exported to the shared library. (Isuru Fernando) +- Standardized to using `-O3` in various subconfigs, and also `-funsafe-math-optimizations` for reference kernels. (Dave Love, Jeff Hammond) +- Disabled TBM, XOP, LWP instructions in all AMD subconfigs. (Devin Matthews) +- Fixed issues that prevented using BLIS on GNU Hurd. (M. Zhou) +- Relaxed python3 requirements to allow python 3.4 or later. Previously, python 3.5 or later was required if python3 was being used. (Dave Love) +- Added `thunderx2` sub-configuration. (Devangi Parikh) +- Added `power9` sub-configuration. For now, this subconfig only uses reference kernels. (Nicholai Tukanov) +- Fixed an issue with `configure` failing on OSes--including certain flavors of BSD--that contain a slash '/' character in the output of `uname -s`. (Isuru Fernando, M. Zhou) + +Testing: +- Renamed `test/3m4m` directory to `test/3`. +- Lots of updates and improvements to Makefiles, shell scripts, and matlab scripts in `test/3`. + +Documentation: +- Added a new `docs/Performance.md` document that showcases single-threaded, single-socket, and dual-socket performance results of `single`, `double`, `scomplex`, and `dcomplex` level-3 operations in BLIS, OpenBLAS, and MKL/ARMPL for Haswell, SkylakeX, ThunderX2, and Epyc hardware architectures. (Note: Other implementations such as Eigen and ATLAS may be added to these graphs in the future.) +- Updated `README.md` to include new language on external packages. (Dave Love) +- Updated `docs/Multithreading.md` to be more explicit about the fact that multithreading is disabled by default at configure-time, and the fact that BLIS will run executed single-threaded at runtime by default if no multithreaded specification is given. (M. Zhou) + ## Changes in 0.5.1 December 18, 2018 @@ -88,7 +154,7 @@ Kernels: Build system: - Added support for building Windows DLLs via AppVeyor [2], complete with a built-in implementation of pthreads for Windows, as well as an implementation of the `pthread_barrier_*()` APIs for use on OS X. (Isuru Fernando, Devin Matthews, Mathieu Poumeyrol, Matthew Honnibal) - Defined a `cortexa53` sub-configuration, which is similar to `cortexa57` except that it uses slightly different compiler flags. (Mathieu Poumeyrol) -- Added python version checking to configure script. +- Added python version checking to `configure` script. - Added a script to automate the regeneration of the symbols list file (now located in `build/libblis-symbols.def`). - Various tweaks in preparation for BLIS's inclusion within Debian. (M. Zhou) - Various fixes and cleanups. @@ -246,16 +312,16 @@ May 2, 2017 - Implemented the 1m method for inducing complex matrix multiplication. (Please see ACM TOMS publication ["Implementing high-performance complex matrix multiplication via the 1m method"](https://github.com/flame/blis#citations) for more details.) 
- Switched to simpler `trsm_r` implementation. - Relaxed constraints that `MC % NR = 0` and `NC % MR = 0`, as this was only needed for the more sophisticated `trsm_r` implementation. -- Automatic loop thread assignment. (Devin Matthews) -- Updates to `.travis.yml` configuration file. (Devin Matthews) +- Automatic loop thread assignment. (Devin Matthews) +- Updates to `.travis.yml` configuration file. (Devin Matthews) - Updates to non-default haswell microkernels. - Match storage format of the temporary micro-tiles in macrokernels to that of the microkernel storage preference for edge cases. -- Added support for Intel's Knight's Landing. (Devin Matthews) -- Added more flexible options to specify multithreading via the configure script. (Devin Matthews) -- OS X compatibility fixes. (Devin Matthews) -- Other small changes and fixes. +- Added support for Intel's Knight's Landing. (Devin Matthews) +- Added more flexible options to specify multithreading via the configure script. (Devin Matthews) +- OS X compatibility fixes. (Devin Matthews) +- Other small changes and fixes. -Also, thanks to Elmar Peise, Krzysztof Drewniak, and Francisco Igual for their contributions in reporting/fixing certain bugs that were addressed in this version. +Also, thanks to Elmar Peise, Krzysztof Drewniak, and Francisco Igual for their contributions in reporting/fixing certain bugs that were addressed in this version. ## Changes in 0.2.1 October 5, 2016 @@ -439,7 +505,7 @@ While neither `bli_config.h` nor `bli_kernel.h` has changed formats since 0.0.7, ## Changes in 0.0.7 April 30, 2013 -This version incorporates many small fixes and feature enhancements made during our SC13 collaboration. +This version incorporates many small fixes and feature enhancements made during our SC13 collaboration. ## Changes in 0.0.6 April 13, 2013 @@ -478,7 +544,7 @@ The compatibility layer is enabled via a configuration option in `bl2_config.h`. ## Changes in 0.0.2 February 11, 2013 -Most notably, this version contains the new test suite I've been working on for the last month. +Most notably, this version contains the new test suite I've been working on for the last month. What is the test suite? It is a highly configurable test driver that allows one to test an arbitrary set of BLIS operations, with an arbitrary set of parameter combinations, and matrix/vector storage formats, as well as whichever datatypes you are interested in. (For now, only homogeneous datatyping is supported, which is what most people want.) You can also specify an arbitrary problem size range with arbitrary increments, and arbitrary ratios between dimensions (or anchor a dimension to a single value), and you can output directly to files which store the output in matlab syntax, which makes it easy to generate performance graphs. 
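Two items in the release notes above lend themselves to short illustrations. First, the `_POSIX_C_SOURCE` guard that 0.6.0 adds to `bli_system.h`. The block below is a minimal sketch of the pattern the notes describe, not a copy of the header:

```c
/* Define _POSIX_C_SOURCE before any system header is included, unless
   the application (or its build system) has already defined it. This
   exposes POSIX.1-2008 interfaces such as pthreads to any translation
   unit that includes blis.h without passing -D_POSIX_C_SOURCE=200809L
   on the compiler command line. */
#ifndef _POSIX_C_SOURCE
#define _POSIX_C_SOURCE 200809L
#endif

#include <pthread.h> /* pthread types and APIs are now reliably visible */
```

Second, the 0.5.2 rewrite of the reference kernels around `#pragma omp simd`. The sketch below illustrates the general pattern only: `PRAGMA_SIMD` and `HAVE_OPENMP_SIMD` are hypothetical names standing in for the configure-controlled macro layer, and the loop body is a simplified axpy-like update rather than an actual BLIS kernel:

```c
/* Expand to the simd pragma only when the build system has verified
   that the compiler accepts -fopenmp-simd; otherwise expand to nothing. */
#ifdef HAVE_OPENMP_SIMD
  #define PRAGMA_SIMD _Pragma( "omp simd" )
#else
  #define PRAGMA_SIMD
#endif

#define MR 8 /* constant loop bound, as in a register-blocked kernel */

/* A unit-stride, constant-bounded loop in the rewritten style; the
   simplified indexing is what lets the compiler vectorize it. */
void ref_axpy_mr( double alpha, const double* restrict a, double* restrict c )
{
    PRAGMA_SIMD
    for ( int i = 0; i < MR; ++i ) c[ i ] += alpha * a[ i ];
}
```

The macro indirection is the key design point: the annotation lives in one place, so a compiler without OpenMP SIMD support still builds the same kernels, just without the vectorization hint.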
diff --git a/docs/graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.pdf b/docs/graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.pdf new file mode 100644 index 000000000..7fbf4abda Binary files /dev/null and b/docs/graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.pdf differ diff --git a/docs/graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.png b/docs/graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.png new file mode 100644 index 000000000..aa12be210 Binary files /dev/null and b/docs/graphs/large/l3_perf_epyc_jc1ic8jr4_nt32.png differ diff --git a/docs/graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.pdf b/docs/graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.pdf new file mode 100644 index 000000000..d7250eff6 Binary files /dev/null and b/docs/graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.pdf differ diff --git a/docs/graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.png b/docs/graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.png new file mode 100644 index 000000000..168de3538 Binary files /dev/null and b/docs/graphs/large/l3_perf_epyc_jc2ic8jr4_nt64.png differ diff --git a/docs/graphs/large/l3_perf_epyc_nt1.pdf b/docs/graphs/large/l3_perf_epyc_nt1.pdf new file mode 100644 index 000000000..4b34f4d27 Binary files /dev/null and b/docs/graphs/large/l3_perf_epyc_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_epyc_nt1.png b/docs/graphs/large/l3_perf_epyc_nt1.png new file mode 100644 index 000000000..f1a2ef5a6 Binary files /dev/null and b/docs/graphs/large/l3_perf_epyc_nt1.png differ diff --git a/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf new file mode 100644 index 000000000..3b80889a9 Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.pdf differ diff --git a/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.png b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.png new file mode 100644 index 000000000..08e28fd0d Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc2ic3jr2_nt12.png differ diff --git a/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf new file mode 100644 index 000000000..d55f37bdc Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.pdf differ diff --git a/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.png b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.png new file mode 100644 index 000000000..e3fb023af Binary files /dev/null and b/docs/graphs/large/l3_perf_has_jc4ic3jr2_nt24.png differ diff --git a/docs/graphs/large/l3_perf_has_nt1.pdf b/docs/graphs/large/l3_perf_has_nt1.pdf new file mode 100644 index 000000000..64da9d160 Binary files /dev/null and b/docs/graphs/large/l3_perf_has_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_has_nt1.png b/docs/graphs/large/l3_perf_has_nt1.png new file mode 100644 index 000000000..12651513f Binary files /dev/null and b/docs/graphs/large/l3_perf_has_nt1.png differ diff --git a/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.pdf b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.pdf new file mode 100644 index 000000000..089625157 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.pdf differ diff --git a/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.png b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.png new file mode 100644 index 000000000..cf970de36 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc2ic13_nt26.png differ diff --git a/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.pdf b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.pdf new file mode 100644 index 000000000..eca573ccf Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.pdf 
differ diff --git a/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.png b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.png new file mode 100644 index 000000000..561357a71 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_jc4ic13_nt52.png differ diff --git a/docs/graphs/large/l3_perf_skx_nt1.pdf b/docs/graphs/large/l3_perf_skx_nt1.pdf new file mode 100644 index 000000000..e0e4c74b6 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_skx_nt1.png b/docs/graphs/large/l3_perf_skx_nt1.png new file mode 100644 index 000000000..02ad841c0 Binary files /dev/null and b/docs/graphs/large/l3_perf_skx_nt1.png differ diff --git a/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf new file mode 100644 index 000000000..352d0556c Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.pdf differ diff --git a/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.png b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.png new file mode 100644 index 000000000..1b8f23192 Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc4ic7_nt28.png differ diff --git a/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf new file mode 100644 index 000000000..c25ea9eee Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.pdf differ diff --git a/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.png b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.png new file mode 100644 index 000000000..87b039886 Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_jc8ic7_nt56.png differ diff --git a/docs/graphs/large/l3_perf_tx2_nt1.pdf b/docs/graphs/large/l3_perf_tx2_nt1.pdf new file mode 100644 index 000000000..66c808c9c Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_tx2_nt1.png b/docs/graphs/large/l3_perf_tx2_nt1.png new file mode 100644 index 000000000..058bef36b Binary files /dev/null and b/docs/graphs/large/l3_perf_tx2_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_epyc_nt1.pdf b/docs/graphs/sup/dgemm_ccc_epyc_nt1.pdf new file mode 100644 index 000000000..1d272c7b4 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_epyc_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_epyc_nt1.png b/docs/graphs/sup/dgemm_ccc_epyc_nt1.png new file mode 100644 index 000000000..200a29426 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_epyc_nt1.png differ diff --git a/docs/graphs/sup/dgemm_ccc_kbl_nt1.pdf b/docs/graphs/sup/dgemm_ccc_kbl_nt1.pdf new file mode 100644 index 000000000..ea39829f9 Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_kbl_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_ccc_kbl_nt1.png b/docs/graphs/sup/dgemm_ccc_kbl_nt1.png new file mode 100644 index 000000000..6cf1c58de Binary files /dev/null and b/docs/graphs/sup/dgemm_ccc_kbl_nt1.png differ diff --git a/docs/graphs/sup/dgemm_rrr_epyc_nt1.pdf b/docs/graphs/sup/dgemm_rrr_epyc_nt1.pdf new file mode 100644 index 000000000..ff7ea7055 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_epyc_nt1.pdf differ diff --git a/docs/graphs/sup/dgemm_rrr_epyc_nt1.png b/docs/graphs/sup/dgemm_rrr_epyc_nt1.png new file mode 100644 index 000000000..2a7b7a397 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_epyc_nt1.png differ diff --git a/docs/graphs/sup/dgemm_rrr_kbl_nt1.pdf b/docs/graphs/sup/dgemm_rrr_kbl_nt1.pdf new file mode 100644 index 000000000..5715c130a Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_kbl_nt1.pdf differ diff 
--git a/docs/graphs/sup/dgemm_rrr_kbl_nt1.png b/docs/graphs/sup/dgemm_rrr_kbl_nt1.png new file mode 100644 index 000000000..cd781c407 Binary files /dev/null and b/docs/graphs/sup/dgemm_rrr_kbl_nt1.png differ diff --git a/examples/oapi/Makefile b/examples/oapi/Makefile index 64dbf20dd..f12ca227b 100644 --- a/examples/oapi/Makefile +++ b/examples/oapi/Makefile @@ -114,7 +114,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Binary executable name. TEST_BINS := 00obj_basic.x \ diff --git a/examples/tapi/Makefile b/examples/tapi/Makefile index 1de4acc13..83330d38b 100644 --- a/examples/tapi/Makefile +++ b/examples/tapi/Makefile @@ -102,7 +102,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Binary executable name. TEST_BINS := 00level1v.x \ diff --git a/frame/0/bli_l0_oapi.c b/frame/0/bli_l0_oapi.c index 9a5492971..2dc37efd1 100644 --- a/frame/0/bli_l0_oapi.c +++ b/frame/0/bli_l0_oapi.c @@ -64,7 +64,7 @@ void PASTEMAC0(opname) \ bli_obj_scalar_set_dt_buffer( chi, dt_absq_c, &dt_chi, &buf_chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ @@ -100,7 +100,7 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi, psi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt ); \ \ f \ @@ -137,7 +137,7 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt ); \ \ f \ @@ -170,7 +170,7 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi, psi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt ); \ \ f \ @@ -213,7 +213,7 @@ void PASTEMAC0(opname) \ else dt_use = dt_chi; \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_use ); \ \ f \ @@ -247,7 +247,7 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( zeta_r, zeta_i, chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ @@ -290,7 +290,7 @@ void PASTEMAC0(opname) \ bli_obj_scalar_set_dt_buffer( chi, dt_zeta_c, &dt_chi, &buf_chi ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ @@ -327,7 +327,7 @@ void PASTEMAC0(opname) \ PASTEMAC(opname,_check)( chi, zeta_r, zeta_i ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = PASTEMAC(opname,_qfp)( dt_chi ); \ \ f \ diff --git a/frame/0/bli_l0_oapi.h b/frame/0/bli_l0_oapi.h index f73aa08d2..d0b05606f 100644 --- a/frame/0/bli_l0_oapi.h +++ b/frame/0/bli_l0_oapi.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* chi, \ obj_t* absq \ @@ -53,7 +53,7 @@ GENPROT( normfsc ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* chi, \ obj_t* psi \ @@ -69,7 +69,7 @@ GENPROT( subsc ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* chi \ ); @@ -80,7 +80,7 @@ GENPROT( invertsc ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* chi, \ double* zeta_r, \ @@ -93,7 +93,7 @@ GENPROT( getsc ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ double zeta_r, \ double zeta_i, \ @@ -106,7 +106,7 @@ GENPROT( setsc ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* chi, \ obj_t* zeta_r, \ @@ -119,7 +119,7 @@ GENPROT( unzipsc ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* zeta_r, \ obj_t* zeta_i, \ diff --git a/frame/0/bli_l0_tapi.h b/frame/0/bli_l0_tapi.h index 46c43d935..c2d600d66 100644 --- a/frame/0/bli_l0_tapi.h +++ b/frame/0/bli_l0_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ conj_t conjchi, \ ctype* chi, \ @@ -56,7 +56,7 @@ INSERT_GENTPROT_BASIC0( subsc ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ conj_t conjchi, \ ctype* chi \ @@ -68,7 +68,7 @@ INSERT_GENTPROT_BASIC0( invertsc ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ ctype* chi, \ ctype_r* absq \ @@ -81,7 +81,7 @@ INSERT_GENTPROTR_BASIC0( normfsc ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ ctype* chi, \ ctype* psi \ @@ -93,7 +93,7 @@ INSERT_GENTPROT_BASIC0( sqrtsc ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ ctype* chi, \ double* zeta_r, \ @@ -106,7 +106,7 @@ INSERT_GENTPROT_BASIC0( getsc ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ double zeta_r, \ double zeta_i, \ @@ -119,7 +119,7 @@ INSERT_GENTPROT_BASIC0( setsc ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ ctype* chi, \ ctype_r* zeta_r, \ @@ -132,7 +132,7 @@ INSERT_GENTPROTR_BASIC0( unzipsc ) #undef 
GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ ctype_r* zeta_r, \ ctype_r* zeta_i, \ @@ -143,14 +143,14 @@ INSERT_GENTPROTR_BASIC0( zipsc ) // ----------------------------------------------------------------------------- -void bli_igetsc +BLIS_EXPORT_BLIS void bli_igetsc ( dim_t* chi, double* zeta_r, double* zeta_i ); -void bli_isetsc +BLIS_EXPORT_BLIS void bli_isetsc ( double zeta_r, double zeta_i, diff --git a/frame/0/copysc/bli_copysc.h b/frame/0/copysc/bli_copysc.h index 1d43919ca..1dfd9d7bc 100644 --- a/frame/0/copysc/bli_copysc.h +++ b/frame/0/copysc/bli_copysc.h @@ -40,7 +40,7 @@ #undef GENFRONT #define GENFRONT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* chi, \ obj_t* psi \ @@ -55,7 +55,7 @@ GENFRONT( copysc ) #undef GENTPROT2 #define GENTPROT2( ctype_x, ctype_y, chx, chy, varname ) \ \ -void PASTEMAC2(chx,chy,varname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(chx,chy,varname) \ ( \ conj_t conjchi, \ void* chi, \ diff --git a/frame/1/bli_l1v_oapi.c b/frame/1/bli_l1v_oapi.c index 19e61bb7a..201af2e09 100644 --- a/frame/1/bli_l1v_oapi.c +++ b/frame/1/bli_l1v_oapi.c @@ -67,7 +67,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -113,7 +113,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, index ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -174,7 +174,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -232,7 +232,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -282,7 +282,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y, rho ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -349,7 +349,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -394,7 +394,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -445,7 +445,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -490,7 +490,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -545,7 +545,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/1/bli_l1v_oapi.h b/frame/1/bli_l1v_oapi.h index 3124db9c3..41aecdc4d 100644 --- a/frame/1/bli_l1v_oapi.h +++ b/frame/1/bli_l1v_oapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* y \ @@ -55,7 +55,7 @@ GENTPROT( subv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* index \ @@ -68,7 +68,7 @@ GENTPROT( amaxv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -83,7 +83,7 @@ GENTPROT( axpbyv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -98,7 +98,7 @@ GENTPROT( scal2v ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* y, \ @@ -112,7 +112,7 @@ GENTPROT( dotv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -128,7 +128,7 @@ GENTPROT( dotxv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x \ BLIS_OAPI_EX_PARAMS \ @@ -140,7 +140,7 @@ GENTPROT( invertv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x \ @@ -154,7 +154,7 @@ GENTPROT( setv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* y \ @@ -167,7 +167,7 @@ GENTPROT( swapv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void 
PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* beta, \ diff --git a/frame/1/bli_l1v_tapi.h b/frame/1/bli_l1v_tapi.h index 6ddd0c1af..5cb3295ef 100644 --- a/frame/1/bli_l1v_tapi.h +++ b/frame/1/bli_l1v_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ dim_t n, \ @@ -57,7 +57,7 @@ INSERT_GENTPROT_BASIC0( subv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx, \ @@ -71,7 +71,7 @@ INSERT_GENTPROT_BASIC0( amaxv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ dim_t n, \ @@ -88,7 +88,7 @@ INSERT_GENTPROT_BASIC0( axpbyv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ dim_t n, \ @@ -105,7 +105,7 @@ INSERT_GENTPROT_BASIC0( scal2v ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ conj_t conjy, \ @@ -122,7 +122,7 @@ INSERT_GENTPROT_BASIC0( dotv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ conj_t conjy, \ @@ -141,7 +141,7 @@ INSERT_GENTPROT_BASIC0( dotxv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx \ @@ -154,7 +154,7 @@ INSERT_GENTPROT_BASIC0( invertv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjalpha, \ dim_t n, \ @@ -170,7 +170,7 @@ INSERT_GENTPROT_BASIC0( setv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx, \ @@ -184,7 +184,7 @@ INSERT_GENTPROT_BASIC0( swapv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ dim_t n, \ diff --git a/frame/1/other/packv/bli_packv_cntl.c b/frame/1/other/packv/bli_packv_cntl.c index 05f1472d0..ce4586f3f 100644 --- a/frame/1/other/packv/bli_packv_cntl.c +++ b/frame/1/other/packv/bli_packv_cntl.c @@ -36,8 +36,8 @@ cntl_t* bli_packv_cntl_obj_create ( - void* var_func, - void* packv_var_func, + void_fp var_func, + void_fp packv_var_func, bszid_t bmid, pack_t pack_schema, cntl_t* sub_node diff --git a/frame/1/other/packv/bli_packv_cntl.h b/frame/1/other/packv/bli_packv_cntl.h index 87f33524b..f1ba76a86 100644 --- a/frame/1/other/packv/bli_packv_cntl.h +++ b/frame/1/other/packv/bli_packv_cntl.h @@ -58,8 +58,8 @@ typedef struct packv_params_s packv_params_t; cntl_t* bli_packv_cntl_obj_create ( - void* var_func, - void* packv_var_func, + void_fp var_func, + void_fp packv_var_func, bszid_t bmid, pack_t pack_schema, cntl_t* sub_node diff --git a/frame/1/other/scalv/bli_scalv_cntl.h b/frame/1/other/scalv/bli_scalv_cntl.h index c97536387..057ed60e8 100644 --- 
a/frame/1/other/scalv/bli_scalv_cntl.h +++ b/frame/1/other/scalv/bli_scalv_cntl.h @@ -41,8 +41,8 @@ typedef struct scalv_s scalv_t; #define bli_cntl_sub_scalv( cntl ) cntl->sub_scalv -void bli_scalv_cntl_init( void ); -void bli_scalv_cntl_finalize( void ); +void bli_scalv_cntl_init( void ); +void bli_scalv_cntl_finalize( void ); scalv_t* bli_scalv_cntl_obj_create( impl_t impl_type, varnum_t var_num ); void bli_scalv_cntl_obj_init( scalv_t* cntl, diff --git a/frame/1/other/unpackv/bli_unpackv_cntl.h b/frame/1/other/unpackv/bli_unpackv_cntl.h index 0defc6803..3d62be8c5 100644 --- a/frame/1/other/unpackv/bli_unpackv_cntl.h +++ b/frame/1/other/unpackv/bli_unpackv_cntl.h @@ -45,8 +45,8 @@ typedef struct unpackv_s unpackv_t; #define bli_cntl_sub_unpackv_y( cntl ) cntl->sub_unpackv_y #define bli_cntl_sub_unpackv_y1( cntl ) cntl->sub_unpackv_y1 -void bli_unpackv_cntl_init( void ); -void bli_unpackv_cntl_finalize( void ); +void bli_unpackv_cntl_init( void ); +void bli_unpackv_cntl_finalize( void ); unpackv_t* bli_unpackv_cntl_obj_create( impl_t impl_type, varnum_t var_num ); void bli_unpackv_cntl_obj_init( unpackv_t* cntl, diff --git a/frame/1d/bli_l1d_oapi.c b/frame/1d/bli_l1d_oapi.c index 1a8b8f124..15e68cf50 100644 --- a/frame/1d/bli_l1d_oapi.c +++ b/frame/1d/bli_l1d_oapi.c @@ -72,7 +72,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -138,7 +138,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -187,7 +187,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -243,7 +243,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -293,7 +293,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( alpha, x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -349,7 +349,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -411,7 +411,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/1d/bli_l1d_oapi.h b/frame/1d/bli_l1d_oapi.h index d0e39b313..47129b771 100644 --- a/frame/1d/bli_l1d_oapi.h +++ b/frame/1d/bli_l1d_oapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* y \ @@ -55,7 +55,7 @@ GENTPROT( subd ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -70,7 +70,7 @@ GENTPROT( scal2d ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x \ BLIS_OAPI_EX_PARAMS \ @@ -82,7 +82,7 @@ GENTPROT( invertd ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x \ @@ -98,7 +98,7 @@ GENTPROT( shiftd ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* beta, \ diff --git a/frame/1d/bli_l1d_tapi.h b/frame/1d/bli_l1d_tapi.h index 823858578..35d093e86 100644 --- a/frame/1d/bli_l1d_tapi.h +++ b/frame/1d/bli_l1d_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -60,7 +60,7 @@ INSERT_GENTPROT_BASIC0( subd ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -80,7 +80,7 @@ INSERT_GENTPROT_BASIC0( scal2d ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ dim_t m, \ @@ -95,7 +95,7 @@ INSERT_GENTPROT_BASIC0( invertd ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjalpha, \ doff_t diagoffx, \ @@ -113,7 +113,7 @@ INSERT_GENTPROT_BASIC0( setd ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ dim_t m, \ @@ -129,7 +129,7 @@ INSERT_GENTPROTR_BASIC0( setid ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ dim_t m, \ @@ -145,7 +145,7 @@ INSERT_GENTPROT_BASIC0( shiftd ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ diff --git a/frame/1f/bli_l1f_oapi.c b/frame/1f/bli_l1f_oapi.c index d1e7f0dbe..db8fdfb68 100644 --- a/frame/1f/bli_l1f_oapi.c +++ b/frame/1f/bli_l1f_oapi.c @@ -88,7 
+88,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alphay = bli_obj_buffer_for_1x1( dt, &alphay_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -157,7 +157,7 @@ void PASTEMAC(opname,EX_SUF) \ if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -225,7 +225,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -310,7 +310,7 @@ void PASTEMAC(opname,EX_SUF) \ if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -390,7 +390,7 @@ void PASTEMAC(opname,EX_SUF) \ if ( bli_obj_has_trans( a ) ) { bli_swap_incs( &rs_a, &cs_a ); } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/1f/bli_l1f_oapi.h b/frame/1f/bli_l1f_oapi.h index 47fdf381a..0348c4871 100644 --- a/frame/1f/bli_l1f_oapi.h +++ b/frame/1f/bli_l1f_oapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alphax, \ obj_t* alphay, \ @@ -56,7 +56,7 @@ GENTPROT( axpy2v ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ @@ -71,7 +71,7 @@ GENTPROT( axpyf ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* xt, \ @@ -88,7 +88,7 @@ GENTPROT( dotaxpyv ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* at, \ @@ -107,7 +107,7 @@ GENTPROT( dotxaxpyf ) #undef GENTPROT #define GENTPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ diff --git a/frame/1f/bli_l1f_tapi.h b/frame/1f/bli_l1f_tapi.h index 54361e8e6..2138b989d 100644 --- a/frame/1f/bli_l1f_tapi.h +++ b/frame/1f/bli_l1f_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ conj_t conjy, \ @@ -59,7 +59,7 @@ INSERT_GENTPROT_BASIC0( axpy2v ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conja, \ conj_t conjx, \ @@ -78,7 +78,7 @@ INSERT_GENTPROT_BASIC0( axpyf ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjxt, \ conj_t conjx, \ @@ -98,7 +98,7 @@ INSERT_GENTPROT_BASIC0( dotaxpyv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjat, \ conj_t conja, \ @@ -122,7 +122,7 @@ INSERT_GENTPROT_BASIC0( dotxaxpyf ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjat, \ conj_t conjx, \ diff --git a/frame/1m/bli_l1m_oapi.c b/frame/1m/bli_l1m_oapi.c index 4bb0de784..224a41bc9 100644 --- a/frame/1m/bli_l1m_oapi.c +++ b/frame/1m/bli_l1m_oapi.c @@ -73,7 +73,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, y ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -141,7 +141,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -218,7 +218,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_internal_scalar_buffer( &x_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -280,7 +280,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -349,7 +349,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -414,7 +414,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dty, &beta_local ); \ \ /* Query a (multi) type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp2)( dtx, dty ); \ \ diff --git a/frame/1m/bli_l1m_oapi.h b/frame/1m/bli_l1m_oapi.h index 3ca023deb..a6a94cf9f 100644 --- a/frame/1m/bli_l1m_oapi.h +++ b/frame/1m/bli_l1m_oapi.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* y \ @@ -55,7 +55,7 @@ GENPROT( subm ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -70,7 +70,7 @@ GENPROT( scal2m ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x \ @@ -84,7 +84,7 @@ GENPROT( setm ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* beta, \ diff --git a/frame/1m/bli_l1m_tapi.h b/frame/1m/bli_l1m_tapi.h index a2592f8ba..03a1196ed 100644 --- a/frame/1m/bli_l1m_tapi.h +++ b/frame/1m/bli_l1m_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -61,7 +61,7 @@ INSERT_GENTPROT_BASIC0( subm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -82,7 +82,7 @@ INSERT_GENTPROT_BASIC0( scal2m ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjalpha, \ doff_t diagoffx, \ @@ -102,7 +102,7 @@ INSERT_GENTPROT_BASIC0( setm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ 
doff_t diagoffx, \ diag_t diagx, \ @@ -122,7 +122,7 @@ INSERT_GENTPROT_BASIC0( xpbym ) #undef GENTPROT2 #define GENTPROT2( ctype_x, ctype_y, chx, chy, opname ) \ \ -void PASTEMAC3(chx,chy,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC3(chx,chy,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ diff --git a/frame/1m/other/bli_scalm_cntl.c b/frame/1m/other/bli_scalm_cntl.c index afff3c164..b815dc530 100644 --- a/frame/1m/other/bli_scalm_cntl.c +++ b/frame/1m/other/bli_scalm_cntl.c @@ -36,7 +36,7 @@ cntl_t* bli_scalm_cntl_create_node ( - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/frame/1m/other/bli_scalm_cntl.h b/frame/1m/other/bli_scalm_cntl.h index 0d589f207..32a02f5da 100644 --- a/frame/1m/other/bli_scalm_cntl.h +++ b/frame/1m/other/bli_scalm_cntl.h @@ -35,6 +35,6 @@ cntl_t* bli_scalm_cntl_create_node ( - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index 6afc48fad..0c19829f2 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -57,7 +57,7 @@ typedef void (*FUNCPTR_T) void* p, inc_t rs_p, inc_t cs_p, inc_t is_p, dim_t pd_p, inc_t ps_p, - void* packm_ker, + void_fp packm_ker, cntx_t* cntx, thrinfo_t* thread ); @@ -152,7 +152,7 @@ void bli_packm_blk_var1 void* buf_kappa; func_t* packm_kers; - void* packm_ker; + void_fp packm_ker; FUNCPTR_T f; @@ -296,7 +296,7 @@ void PASTEMAC(ch,varname) \ void* p, inc_t rs_p, inc_t cs_p, \ inc_t is_p, \ dim_t pd_p, inc_t ps_p, \ - void* packm_ker, \ + void_fp packm_ker, \ cntx_t* cntx, \ thrinfo_t* thread \ ) \ diff --git a/frame/1m/packm/bli_packm_cntl.c b/frame/1m/packm/bli_packm_cntl.c index 5321d873a..5b3e26421 100644 --- a/frame/1m/packm/bli_packm_cntl.c +++ b/frame/1m/packm/bli_packm_cntl.c @@ -38,8 +38,8 @@ cntl_t* bli_packm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* packm_var_func, + void_fp var_func, + void_fp packm_var_func, bszid_t bmid_m, bszid_t bmid_n, bool_t does_invert_diag, diff --git a/frame/1m/packm/bli_packm_cntl.h b/frame/1m/packm/bli_packm_cntl.h index bbf0ea332..5ce1801b9 100644 --- a/frame/1m/packm/bli_packm_cntl.h +++ b/frame/1m/packm/bli_packm_cntl.h @@ -92,8 +92,8 @@ static packbuf_t bli_cntl_packm_params_pack_buf_type( cntl_t* cntl ) cntl_t* bli_packm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* packm_var_func, + void_fp var_func, + void_fp packm_var_func, bszid_t bmid_m, bszid_t bmid_n, bool_t does_invert_diag, diff --git a/frame/1m/packm/bli_packm_init.h b/frame/1m/packm/bli_packm_init.h index 6896ab913..9365a131e 100644 --- a/frame/1m/packm/bli_packm_init.h +++ b/frame/1m/packm/bli_packm_init.h @@ -40,7 +40,7 @@ siz_t bli_packm_init cntl_t* cntl ); -siz_t bli_packm_init_pack +BLIS_EXPORT_BLIS siz_t bli_packm_init_pack ( invdiag_t invert_diag, pack_t schema, diff --git a/frame/1m/packm/bli_packm_var.h b/frame/1m/packm/bli_packm_var.h index ce2065fd1..97aa875bd 100644 --- a/frame/1m/packm/bli_packm_var.h +++ b/frame/1m/packm/bli_packm_var.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* c, \ obj_t* p, \ @@ -101,7 +101,7 @@ void PASTEMAC(ch,varname) \ void* p, inc_t rs_p, inc_t cs_p, \ inc_t is_p, \ dim_t pd_p, inc_t ps_p, \ - void* packm_ker, \ + void_fp packm_ker, \ cntx_t* cntx, \ thrinfo_t* thread \ ); diff --git a/frame/1m/unpackm/bli_unpackm_cntl.c b/frame/1m/unpackm/bli_unpackm_cntl.c index 7c634a11e..f2be05a54 100644 --- 
a/frame/1m/unpackm/bli_unpackm_cntl.c +++ b/frame/1m/unpackm/bli_unpackm_cntl.c @@ -38,8 +38,8 @@ cntl_t* bli_unpackm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* unpackm_var_func, + void_fp var_func, + void_fp unpackm_var_func, cntl_t* sub_node ) { diff --git a/frame/1m/unpackm/bli_unpackm_cntl.h b/frame/1m/unpackm/bli_unpackm_cntl.h index 49a8b19f6..5c41d9465 100644 --- a/frame/1m/unpackm/bli_unpackm_cntl.h +++ b/frame/1m/unpackm/bli_unpackm_cntl.h @@ -49,8 +49,8 @@ typedef struct unpackm_params_s unpackm_params_t; cntl_t* bli_unpackm_cntl_create_node ( rntm_t* rntm, - void* var_func, - void* unpackm_var_func, + void_fp var_func, + void_fp unpackm_var_func, cntl_t* sub_node ); diff --git a/frame/2/bli_l2_oapi.c b/frame/2/bli_l2_oapi.c index 25acb4207..cc32fb61e 100644 --- a/frame/2/bli_l2_oapi.c +++ b/frame/2/bli_l2_oapi.c @@ -90,7 +90,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -157,7 +157,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -229,7 +229,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_beta = bli_obj_buffer_for_1x1( dt, &beta_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -293,7 +293,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -358,7 +358,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -422,7 +422,7 @@ void PASTEMAC(opname,EX_SUF) \ buf_alpha = bli_obj_buffer_for_1x1( dt, &alpha_local ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/2/bli_l2_oapi.h b/frame/2/bli_l2_oapi.h index eb0f47249..6b6a1d77e 100644 --- a/frame/2/bli_l2_oapi.h +++ b/frame/2/bli_l2_oapi.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ @@ -58,7 +58,7 @@ GENPROT( symv ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -75,7 +75,7 @@ GENPROT( syr2 ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* x, \ @@ -90,7 +90,7 @@ GENPROT( syr ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ diff --git a/frame/2/bli_l2_tapi.h b/frame/2/bli_l2_tapi.h index b66ccfb7a..4b45236e2 100644 --- a/frame/2/bli_l2_tapi.h +++ b/frame/2/bli_l2_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ trans_t transa, \ conj_t conjx, \ @@ -60,7 +60,7 @@ INSERT_GENTPROT_BASIC0( gemv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ conj_t conjx, \ conj_t conjy, \ @@ -79,7 +79,7 @@ INSERT_GENTPROT_BASIC0( ger ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploa, \ conj_t conja, \ @@ -100,7 +100,7 @@ INSERT_GENTPROT_BASIC0( symv ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploa, \ conj_t conjx, \ @@ -117,7 +117,7 @@ INSERT_GENTPROTR_BASIC0( her ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploa, \ conj_t conjx, \ @@ -134,7 +134,7 @@ INSERT_GENTPROT_BASIC0( syr ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploa, \ conj_t conjx, \ @@ -154,7 +154,7 @@ INSERT_GENTPROT_BASIC0( syr2 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploa, \ trans_t transa, \ diff --git a/frame/2/gemv/bli_gemv_var_oapi.c b/frame/2/gemv/bli_gemv_var_oapi.c index 2e746b417..865773534 100644 --- a/frame/2/gemv/bli_gemv_var_oapi.c +++ b/frame/2/gemv/bli_gemv_var_oapi.c @@ -72,7 +72,7 @@ void PASTEMAC0(varname) \ void* buf_beta = bli_obj_buffer_for_1x1( dt, beta ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/gemv/other/bli_gemv_cntl.h b/frame/2/gemv/other/bli_gemv_cntl.h index f505e8997..e60fd8b5a 100644 --- a/frame/2/gemv/other/bli_gemv_cntl.h +++ b/frame/2/gemv/other/bli_gemv_cntl.h @@ -54,8 +54,8 @@ typedef struct gemv_s gemv_t; #define bli_cntl_sub_gemv_t_rp( cntl ) cntl->sub_gemv_t_rp #define bli_cntl_sub_gemv_t_cp( cntl ) cntl->sub_gemv_t_cp -void bli_gemv_cntl_init( void ); -void bli_gemv_cntl_finalize( void ); +void bli_gemv_cntl_init( void ); +void bli_gemv_cntl_finalize( void ); gemv_t* bli_gemv_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/2/ger/bli_ger_var_oapi.c b/frame/2/ger/bli_ger_var_oapi.c index 3fd95e89f..f125efdf8 100644 --- a/frame/2/ger/bli_ger_var_oapi.c +++ b/frame/2/ger/bli_ger_var_oapi.c @@ -70,7 +70,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/ger/other/bli_ger_cntl.h b/frame/2/ger/other/bli_ger_cntl.h index b7b460177..c5e0ebd3a 100644 --- a/frame/2/ger/other/bli_ger_cntl.h +++ b/frame/2/ger/other/bli_ger_cntl.h @@ -49,8 +49,8 @@ typedef struct ger_s ger_t; #define bli_cntl_sub_ger_rp( cntl ) cntl->sub_ger_rp #define bli_cntl_sub_ger_cp( cntl ) cntl->sub_ger_cp -void bli_ger_cntl_init( void ); -void bli_ger_cntl_finalize( void ); +void bli_ger_cntl_init( void ); +void bli_ger_cntl_finalize( void ); ger_t* bli_ger_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/2/hemv/bli_hemv_var_oapi.c b/frame/2/hemv/bli_hemv_var_oapi.c index 845f288c3..bf0e4b202 100644 --- a/frame/2/hemv/bli_hemv_var_oapi.c +++ b/frame/2/hemv/bli_hemv_var_oapi.c @@ -73,7 +73,7 @@ void PASTEMAC0(varname) \ void* buf_beta = bli_obj_buffer_for_1x1( dt, beta ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/hemv/other/bli_hemv_cntl.h b/frame/2/hemv/other/bli_hemv_cntl.h index fba7b19b4..12f927056 100644 --- a/frame/2/hemv/other/bli_hemv_cntl.h +++ b/frame/2/hemv/other/bli_hemv_cntl.h @@ -52,8 +52,8 @@ typedef struct hemv_s hemv_t; #define bli_cntl_sub_hemv( cntl ) cntl->sub_hemv -void bli_hemv_cntl_init( void ); -void bli_hemv_cntl_finalize( void ); +void bli_hemv_cntl_init( void ); +void bli_hemv_cntl_finalize( void ); hemv_t* bli_hemv_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/2/her/bli_her_var_oapi.c b/frame/2/her/bli_her_var_oapi.c index ffca2e71e..44c6d090d 100644 --- a/frame/2/her/bli_her_var_oapi.c +++ b/frame/2/her/bli_her_var_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/her/other/bli_her_cntl.h b/frame/2/her/other/bli_her_cntl.h index 280492b29..45d24ba28 100644 --- a/frame/2/her/other/bli_her_cntl.h +++ b/frame/2/her/other/bli_her_cntl.h @@ -47,8 +47,8 @@ typedef struct her_s her_t; #define bli_cntl_sub_her( cntl ) cntl->sub_her -void bli_her_cntl_init( void ); -void bli_her_cntl_finalize( void ); +void bli_her_cntl_init( void ); +void bli_her_cntl_finalize( void ); her_t* bli_her_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/2/her2/bli_her2_var_oapi.c b/frame/2/her2/bli_her2_var_oapi.c index 2b26e5476..dce87a1cd 100644 --- a/frame/2/her2/bli_her2_var_oapi.c +++ b/frame/2/her2/bli_her2_var_oapi.c @@ -72,7 +72,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/her2/other/bli_her2_cntl.h b/frame/2/her2/other/bli_her2_cntl.h index 4eca63af0..4034b7260 100644 --- a/frame/2/her2/other/bli_her2_cntl.h +++ b/frame/2/her2/other/bli_her2_cntl.h @@ -49,8 +49,8 @@ typedef struct her2_s her2_t; #define bli_cntl_sub_her2( cntl ) cntl->sub_her2 -void bli_her2_cntl_init( void ); -void bli_her2_cntl_finalize( void ); +void bli_her2_cntl_init( void ); +void bli_her2_cntl_finalize( void ); her2_t* bli_her2_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/2/trmv/bli_trmv_var_oapi.c b/frame/2/trmv/bli_trmv_var_oapi.c index 931eb2abb..c74d31223 100644 --- a/frame/2/trmv/bli_trmv_var_oapi.c +++ b/frame/2/trmv/bli_trmv_var_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/trmv/other/bli_trmv_cntl.h b/frame/2/trmv/other/bli_trmv_cntl.h index 2474f5f6d..9693a5ad9 100644 --- a/frame/2/trmv/other/bli_trmv_cntl.h +++ b/frame/2/trmv/other/bli_trmv_cntl.h @@ -48,8 +48,8 @@ typedef struct trmv_s trmv_t; #define bli_cntl_sub_trmv( cntl ) cntl->sub_trmv -void bli_trmv_cntl_init( void ); -void bli_trmv_cntl_finalize( void ); +void bli_trmv_cntl_init( void ); +void bli_trmv_cntl_finalize( void ); trmv_t* bli_trmv_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/2/trsv/bli_trsv_var_oapi.c b/frame/2/trsv/bli_trsv_var_oapi.c index 4cf346acf..62ac33e45 100644 --- a/frame/2/trsv/bli_trsv_var_oapi.c +++ b/frame/2/trsv/bli_trsv_var_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC0(varname) \ void* buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,_unb,_vft) f = \ PASTEMAC(varname,_qfp)( dt ); \ \ diff --git a/frame/2/trsv/other/bli_trsv_cntl.h b/frame/2/trsv/other/bli_trsv_cntl.h index cb53c0fe9..4cb90a17d 100644 --- a/frame/2/trsv/other/bli_trsv_cntl.h +++ b/frame/2/trsv/other/bli_trsv_cntl.h @@ -49,8 +49,8 @@ typedef struct trsv_s trsv_t; #define bli_cntl_sub_trsv( cntl ) cntl->sub_trsv -void bli_trsv_cntl_init( void ); -void bli_trsv_cntl_finalize( void ); +void bli_trsv_cntl_init( void ); +void bli_trsv_cntl_finalize( void ); trsv_t* bli_trsv_cntl_obj_create( impl_t impl_type, varnum_t var_num, bszid_t bszid, diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index 7f2879c02..e60af306b 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -60,14 +61,28 @@ #include "bli_tapi_ba.h" #include "bli_l3_tapi.h" -// Prototype microkernel wrapper APIs +// Define function types for small/unpacked handlers/kernels. +#include "bli_l3_sup_oft.h" +#include "bli_l3_sup_ft_ker.h" + +// Define static edge case logic for use in small/unpacked kernels. +//#include "bli_l3_sup_edge.h" + +// Prototype object API to small/unpacked matrix dispatcher. +#include "bli_l3_sup.h" + +// Prototype reference implementation of small/unpacked matrix handler. +#include "bli_l3_sup_ref.h" +#include "bli_l3_sup_vars.h" + +// Prototype microkernel wrapper APIs. #include "bli_l3_ukr_oapi.h" #include "bli_l3_ukr_tapi.h" // Generate function pointer arrays for tapi microkernel functions. #include "bli_l3_ukr_fpa.h" -// Operation-specific headers +// Operation-specific headers. #include "bli_gemm.h" #include "bli_hemm.h" #include "bli_herk.h" diff --git a/frame/3/bli_l3_oapi.c b/frame/3/bli_l3_oapi.c index d9ba27369..29eca1bf1 100644 --- a/frame/3/bli_l3_oapi.c +++ b/frame/3/bli_l3_oapi.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,13 +57,30 @@ void PASTEMAC(opname,EX_SUF) \ bli_init_once(); \ \ BLIS_OAPI_EX_DECLS \ +\ + /* If the rntm is non-NULL, it may indicate that we should forgo sup + handling altogether. */ \ + bool_t enable_sup = TRUE; \ + if ( rntm != NULL ) enable_sup = bli_rntm_l3_sup( rntm ); \ +\ + if ( enable_sup ) \ + { \ + /* Execute the small/unpacked oapi handler. If it finds that the problem + does not fall within the thresholds that define "small", or for some + other reason decides not to use the small/unpacked implementation, + the function returns with BLIS_FAILURE, which causes execution to + proceed towards the conventional implementation. */ \ + err_t result = PASTEMAC(opname,sup)( alpha, a, b, beta, c, cntx, rntm ); \ + if ( result == BLIS_SUCCESS ) return; \ + } \ \ /* Only proceed with an induced method if each of the operands have a complex storage datatype. NOTE: Allowing precisions to vary while using 1m, which is what we do here, is unique to gemm; other level-3 - operations use 1m only if all storage datatypes are equal (including - the computation datatype). If any operands are real, skip the induced - method chooser function and proceed directly with native execution. 
*/ \ + operations use 1m only if all storage datatypes are equal (and they + ignore the computation precision). If any operands are real, skip the + induced method chooser function and proceed directly with native + execution. */ \ if ( bli_obj_is_complex( c ) && \ bli_obj_is_complex( a ) && \ bli_obj_is_complex( b ) ) \ @@ -81,6 +99,49 @@ void PASTEMAC(opname,EX_SUF) \ } GENFRONT( gemm ) + + +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c \ + BLIS_OAPI_EX_PARAMS \ + ) \ +{ \ + bli_init_once(); \ +\ + BLIS_OAPI_EX_DECLS \ +\ + /* Only proceed with an induced method if each of the operands have a + complex storage datatype. NOTE: Allowing precisions to vary while + using 1m, which is what we do here, is unique to gemm; other level-3 + operations use 1m only if all storage datatypes are equal (and they + ignore the computation precision). If any operands are real, skip the + induced method chooser function and proceed directly with native + execution. */ \ + if ( bli_obj_is_complex( c ) && \ + bli_obj_is_complex( a ) && \ + bli_obj_is_complex( b ) ) \ + { \ + /* Invoke the operation's "ind" function--its induced method front-end. + For complex problems, it calls the highest priority induced method + that is available (ie: implemented and enabled), and if none are + enabled, it calls native execution. (For real problems, it calls + the operation's native execution interface.) */ \ + PASTEMAC(opname,ind)( alpha, a, b, beta, c, cntx, rntm ); \ + } \ + else \ + { \ + PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ + } \ +} + GENFRONT( her2k ) GENFRONT( syr2k ) diff --git a/frame/3/bli_l3_oapi.h b/frame/3/bli_l3_oapi.h index 2f0af81b2..4f9f20608 100644 --- a/frame/3/bli_l3_oapi.h +++ b/frame/3/bli_l3_oapi.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ @@ -58,7 +58,7 @@ GENPROT( syr2k ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ side_t side, \ obj_t* alpha, \ @@ -77,7 +77,7 @@ GENPROT( trmm3 ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* alpha, \ obj_t* a, \ @@ -93,7 +93,7 @@ GENPROT( syrk ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ side_t side, \ obj_t* alpha, \ diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c new file mode 100644 index 000000000..05aaf7bcb --- /dev/null +++ b/frame/3/bli_l3_sup.c @@ -0,0 +1,136 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +err_t bli_gemmsup + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Return early if small matrix handling is disabled at configure-time. + #ifdef BLIS_DISABLE_SUP_HANDLING + return BLIS_FAILURE; + #endif + + // Return early if this is a mixed-datatype computation. + if ( bli_obj_dt( c ) != bli_obj_dt( a ) || + bli_obj_dt( c ) != bli_obj_dt( b ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) return BLIS_FAILURE; + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + +#if 0 + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_thread_init_rntm( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } +#endif + + // Return early if a microkernel preference-induced transposition would + // have been performed and shifted the dimensions outside of the space + // of sup-handled problems. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( c, BLIS_GEMM_UKR, cntx ) ) + { + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width_after_trans( a ); + + // Pass in m and n reversed, which simulates a transposition of the + // entire operation pursuant to the microkernel storage preference. + if ( !bli_cntx_l3_sup_thresh_is_met( dt, n, m, k, cntx ) ) + return BLIS_FAILURE; + } + else // ukr_prefers_storage_of( c, ... 
)
+	{
+		const num_t dt = bli_obj_dt( c );
+		const dim_t m  = bli_obj_length( c );
+		const dim_t n  = bli_obj_width( c );
+		const dim_t k  = bli_obj_width_after_trans( a );
+
+		if ( !bli_cntx_l3_sup_thresh_is_met( dt, m, n, k, cntx ) )
+			return BLIS_FAILURE;
+	}
+
+#if 0
+const num_t dt = bli_obj_dt( c );
+const dim_t m  = bli_obj_length( c );
+const dim_t n  = bli_obj_width( c );
+const dim_t k  = bli_obj_width_after_trans( a );
+const dim_t tm = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx );
+const dim_t tn = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx );
+const dim_t tk = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx );
+
+printf( "dims: %d %d %d (threshs: %d %d %d)\n",
+        (int)m, (int)n, (int)k, (int)tm, (int)tn, (int)tk );
+#endif
+
+	// We've now ruled out the following two possibilities:
+	// - the ukernel prefers the operation as-is, and the sup thresholds are
+	//   unsatisfied.
+	// - the ukernel prefers a transposed operation, and the sup thresholds are
+	//   unsatisfied after taking into account the transposition.
+	// This implies that the sup thresholds (at least one of them) are met,
+	// and the small/unpacked handler should be called.
+	// NOTE: The sup handler is free to enforce a stricter threshold regime
+	// if it so chooses, in which case it can/should return BLIS_FAILURE.
+
+	// Query the small/unpacked handler from the context and invoke it.
+	gemmsup_oft gemmsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMM, cntx );
+
+	return
+	gemmsup_fp
+	(
+	  alpha,
+	  a,
+	  b,
+	  beta,
+	  c,
+	  cntx,
+	  rntm
+	);
+}
+
+
diff --git a/frame/3/bli_l3_sup.h b/frame/3/bli_l3_sup.h
new file mode 100644
index 000000000..f0a3a559b
--- /dev/null
+++ b/frame/3/bli_l3_sup.h
@@ -0,0 +1,45 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2019, Advanced Micro Devices, Inc.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ +*/ + +err_t bli_gemmsup + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + diff --git a/frame/3/bli_l3_sup_ft_ker.h b/frame/3/bli_l3_sup_ft_ker.h new file mode 100644 index 000000000..5bb2218f3 --- /dev/null +++ b/frame/3/bli_l3_sup_ft_ker.h @@ -0,0 +1,68 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_FT_KER_H +#define BLIS_L3_SUP_FT_KER_H + + +// +// -- Level-3 small/unpacked kernel function types ----------------------------- +// + +// gemmsup + +#undef GENTDEF +#define GENTDEF( ctype, ch, opname, tsuf ) \ +\ +typedef void (*PASTECH3(ch,opname,_ker,tsuf)) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ); + +INSERT_GENTDEF( gemmsup ) + + +#endif + diff --git a/frame/3/bli_l3_sup_ker.h b/frame/3/bli_l3_sup_ker.h new file mode 100644 index 000000000..6c77fffe0 --- /dev/null +++ b/frame/3/bli_l3_sup_ker.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Define template prototypes for level-3 kernels on small/unpacked matrices. +// + +// Note: Instead of defining function prototype macro templates and then +// instantiating those macros to define the individual function prototypes, +// we simply alias the official operations' prototypes as defined in +// bli_l3_ker_prot.h. + +#undef GENTPROT +#define GENTPROT GEMMSUP_KER_PROT + +INSERT_GENTPROT_BASIC0( gemmsup_rv_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_rg_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_cv_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_cg_ukr_name ) + +INSERT_GENTPROT_BASIC0( gemmsup_rd_ukr_name ) +INSERT_GENTPROT_BASIC0( gemmsup_cd_ukr_name ) + +INSERT_GENTPROT_BASIC0( gemmsup_gx_ukr_name ) + diff --git a/frame/3/bli_l3_sup_ker_prot.h b/frame/3/bli_l3_sup_ker_prot.h new file mode 100644 index 000000000..899a47d3f --- /dev/null +++ b/frame/3/bli_l3_sup_ker_prot.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +// +// Define template prototypes for level-3 kernels on small/unpacked matrices. +// + +#define GEMMSUP_KER_PROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ); + diff --git a/frame/3/bli_l3_sup_oft.h b/frame/3/bli_l3_sup_oft.h new file mode 100644 index 000000000..a06d28789 --- /dev/null +++ b/frame/3/bli_l3_sup_oft.h @@ -0,0 +1,62 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_L3_SUP_OFT_H +#define BLIS_L3_SUP_OFT_H + + +// +// -- Level-3 small/unpacked object function types ----------------------------- +// + +// gemm + +#undef GENTDEF +#define GENTDEF( opname ) \ +\ +typedef err_t (*PASTECH(opname,_oft)) \ +( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ +); + +GENTDEF( gemmsup ) + +#endif + diff --git a/frame/3/bli_l3_sup_ref.c b/frame/3/bli_l3_sup_ref.c new file mode 100644 index 000000000..bf4494077 --- /dev/null +++ b/frame/3/bli_l3_sup_ref.c @@ -0,0 +1,209 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+err_t bli_gemmsup_ref
+     (
+       obj_t*  alpha,
+       obj_t*  a,
+       obj_t*  b,
+       obj_t*  beta,
+       obj_t*  c,
+       cntx_t* cntx,
+       rntm_t* rntm
+     )
+{
+	// Check parameters.
+	if ( bli_error_checking_is_enabled() )
+		bli_gemm_check( alpha, a, b, beta, c, cntx );
+
+#if 0
+	// FGVZ: The datatype-specific variant is now responsible for checking for
+	// alpha == 0.0.
+
+	// If alpha is zero, scale by beta and return.
+	if ( bli_obj_equals( alpha, &BLIS_ZERO ) )
+	{
+		bli_scalm( beta, c );
+		return BLIS_SUCCESS;
+	}
+#endif
+
+#if 0
+	// FGVZ: Will this be needed for constructing thrinfo_t's (recall: the
+	// sba needs to be attached to the rntm; see below)? Or will those nodes
+	// just be created "locally," in an exposed manner?
+
+	// Parse and interpret the contents of the rntm_t object to properly
+	// set the ways of parallelism for each loop, and then make any
+	// additional modifications necessary for the current operation.
+	bli_rntm_set_ways_for_op
+	(
+	  BLIS_GEMM,
+	  BLIS_LEFT, // ignored for gemm/hemm/symm
+	  bli_obj_length( &c_local ),
+	  bli_obj_width( &c_local ),
+	  bli_obj_width( &a_local ),
+	  rntm
+	);
+
+	// FGVZ: the sba needs to be attached to the rntm. But it needs
+	// to be done in the thread region, since it needs a thread id.
+	//bli_sba_rntm_set_pool( tid, array, rntm_p );
+#endif
+
+#if 0
+	// FGVZ: The datatype-specific variant is now responsible for inducing a
+	// transposition, if needed.
+
+	// Induce transpositions on A and/or B if either object is marked for
+	// transposition. We can induce "fast" transpositions since the objects
+	// are guaranteed to not have structure or be packed.
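+	// NOTE: A "fast" transposition is purely logical: it swaps an object's
+	// dimensions and strides rather than moving any data. A minimal sketch
+	// of the idea (the fast_trans helper below is hypothetical, shown for
+	// illustration only; it is not a BLIS API):
+	//
+	//   void fast_trans( dim_t* m, dim_t* n, inc_t* rs, inc_t* cs )
+	//   {
+	//       dim_t dtmp = *m;  *m  = *n;  *n  = dtmp;  // swap dimensions
+	//       inc_t stmp = *rs; *rs = *cs; *cs = stmp;  // swap strides
+	//   }
+	//
+	// After the swaps, element (i,j) of the new view aliases element (j,i)
+	// of the original matrix, so the transposition incurs no memory traffic.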
+ if ( bli_obj_has_trans( a ) ) + { + bli_obj_induce_fast_trans( a ); + bli_obj_toggle_trans( a ); + } + if ( bli_obj_has_trans( b ) ) + { + bli_obj_induce_fast_trans( b ); + bli_obj_toggle_trans( b ); + } +#endif + +#if 0 + //bli_gemmsup_ref_var2 + //bli_gemmsup_ref_var1 + #if 0 + bli_gemmsup_ref_var1n + #else + #endif + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool_t is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + if ( is_rrr_rrc_rcr_crr ) + { + bli_gemmsup_ref_var2m + ( + BLIS_NO_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm + ); + } + else + { + bli_gemmsup_ref_var2m + ( + BLIS_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm + ); + } +#else + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) return BLIS_FAILURE; + + const bool_t is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + const bool_t is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + + const num_t dt = bli_obj_dt( c ); + const bool_t row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + + const bool_t is_primary = ( row_pref ? is_rrr_rrc_rcr_crr + : is_rcc_crc_ccr_ccc ); + + if ( is_primary ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + if ( mu >= nu ) + { + // block-panel macrokernel; m -> mc, mr; n -> nc, nr: var2() + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, stor_id, cntx, rntm ); + } + else // if ( mu < nu ) + { + // panel-block macrokernel; m -> nc*,mr; n -> mc*,nr: var1() + bli_gemmsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, stor_id, cntx, rntm ); + } + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + + const dim_t mt = bli_obj_width( c ); + const dim_t nt = bli_obj_length( c ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t mu = mt / MR; + const dim_t nu = nt / NR; + + if ( mu >= nu ) + { + // panel-block macrokernel; m -> nc, nr; n -> mc, mr: var2() + trans + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, stor_id, cntx, rntm ); + } + else // if ( mu < nu ) + { + // block-panel macrokernel; m -> mc*,nr; n -> nc*,mr: var1() + trans + bli_gemmsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, stor_id, cntx, rntm ); + } + // *requires nudging of mc,nc up to be a multiple of nr,mr. + } +#endif + + // Return success so that the caller knows that we computed the solution. + return BLIS_SUCCESS; +} + diff --git a/frame/3/bli_l3_sup_ref.h b/frame/3/bli_l3_sup_ref.h new file mode 100644 index 000000000..5d400a985 --- /dev/null +++ b/frame/3/bli_l3_sup_ref.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +err_t bli_gemmsup_ref + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + diff --git a/frame/3/bli_l3_sup_var12.c b/frame/3/bli_l3_sup_var12.c new file mode 100644 index 000000000..106ad86e4 --- /dev/null +++ b/frame/3/bli_l3_sup_var12.c @@ -0,0 +1,735 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+*/
+
+#include "blis.h"
+
+#define FUNCPTR_T gemmsup_fp
+
+typedef void (*FUNCPTR_T)
+     (
+       conj_t           conja,
+       conj_t           conjb,
+       dim_t            m,
+       dim_t            n,
+       dim_t            k,
+       void*   restrict alpha,
+       void*   restrict a, inc_t rs_a, inc_t cs_a,
+       void*   restrict b, inc_t rs_b, inc_t cs_b,
+       void*   restrict beta,
+       void*   restrict c, inc_t rs_c, inc_t cs_c,
+       stor3_t          eff_id,
+       cntx_t* restrict cntx,
+       rntm_t* restrict rntm
+     );
+
+#if 0
+//
+// -- var2 ---------------------------------------------------------------------
+//
+
+static FUNCPTR_T GENARRAY(ftypes_var2,gemmsup_ref_var2);
+
+void bli_gemmsup_ref_var2
+     (
+       trans_t trans,
+       obj_t*  alpha,
+       obj_t*  a,
+       obj_t*  b,
+       obj_t*  beta,
+       obj_t*  c,
+       stor3_t eff_id,
+       cntx_t* cntx,
+       rntm_t* rntm
+     )
+{
+#if 0
+	obj_t at, bt;
+
+	bli_obj_alias_to( a, &at );
+	bli_obj_alias_to( b, &bt );
+
+	// Induce transpositions on A and/or B if either object is marked for
+	// transposition. We can induce "fast" transpositions since the objects
+	// are guaranteed to not have structure or be packed.
+	if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); }
+	if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); }
+
+	const num_t dt_exec = bli_obj_dt( c );
+
+	const conj_t conja = bli_obj_conj_status( a );
+	const conj_t conjb = bli_obj_conj_status( b );
+
+	const dim_t m = bli_obj_length( c );
+	const dim_t n = bli_obj_width( c );
+
+	const dim_t k = bli_obj_width( &at );
+
+	void* restrict buf_a = bli_obj_buffer_at_off( &at );
+	const inc_t rs_a = bli_obj_row_stride( &at );
+	const inc_t cs_a = bli_obj_col_stride( &at );
+
+	void* restrict buf_b = bli_obj_buffer_at_off( &bt );
+	const inc_t rs_b = bli_obj_row_stride( &bt );
+	const inc_t cs_b = bli_obj_col_stride( &bt );
+
+	void* restrict buf_c = bli_obj_buffer_at_off( c );
+	const inc_t rs_c = bli_obj_row_stride( c );
+	const inc_t cs_c = bli_obj_col_stride( c );
+
+	void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha );
+	void* restrict buf_beta  = bli_obj_buffer_for_1x1( dt_exec, beta );
+
+#else
+
+	const num_t dt_exec = bli_obj_dt( c );
+
+	const conj_t conja = bli_obj_conj_status( a );
+	const conj_t conjb = bli_obj_conj_status( b );
+
+	const dim_t m = bli_obj_length( c );
+	const dim_t n = bli_obj_width( c );
+	dim_t k;
+
+	void* restrict buf_a = bli_obj_buffer_at_off( a );
+	inc_t rs_a;
+	inc_t cs_a;
+
+	void* restrict buf_b = bli_obj_buffer_at_off( b );
+	inc_t rs_b;
+	inc_t cs_b;
+
+	if ( bli_obj_has_notrans( a ) )
+	{
+		k    = bli_obj_width( a );
+
+		rs_a = bli_obj_row_stride( a );
+		cs_a = bli_obj_col_stride( a );
+	}
+	else // if ( bli_obj_has_trans( a ) )
+	{
+		// Assign the variables with an implicit transposition.
+		k    = bli_obj_length( a );
+
+		rs_a = bli_obj_col_stride( a );
+		cs_a = bli_obj_row_stride( a );
+	}
+
+	if ( bli_obj_has_notrans( b ) )
+	{
+		rs_b = bli_obj_row_stride( b );
+		cs_b = bli_obj_col_stride( b );
+	}
+	else // if ( bli_obj_has_trans( b ) )
+	{
+		// Assign the variables with an implicit transposition.
+		rs_b = bli_obj_col_stride( b );
+		cs_b = bli_obj_row_stride( b );
+	}
+
+	void* restrict buf_c = bli_obj_buffer_at_off( c );
+	const inc_t rs_c = bli_obj_row_stride( c );
+	const inc_t cs_c = bli_obj_col_stride( c );
+
+	void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha );
+	void* restrict buf_beta  = bli_obj_buffer_for_1x1( dt_exec, beta );
+
+#endif
+
+	// Index into the type combination array to extract the correct
+	// function pointer.
+	FUNCPTR_T f = ftypes_var2[dt_exec];
+
+	// Invoke the function.
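+	// NOTE: GENARRAY expands to a num_t-indexed array of typed function
+	// pointers. Conceptually (a sketch, assuming the usual BLIS s/c/d/z
+	// datatype ordering), ftypes_var2 resembles:
+	//
+	//   static FUNCPTR_T ftypes_var2[] =
+	//   {
+	//       bli_sgemmsup_ref_var2, bli_cgemmsup_ref_var2,
+	//       bli_dgemmsup_ref_var2, bli_zgemmsup_ref_var2
+	//   };
+	//
+	// so the lookup above selects, at runtime, the variant instantiated by
+	// the GENTFUNC macro below for the problem's storage datatype.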
+ f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ) \ +{ \ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* If alpha is zero, scale by beta and return. */ \ + if ( PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c * NC; \ + const inc_t jcstep_b = cs_b * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = rs_c * MC; \ + const inc_t icstep_a = rs_a * MC; \ +\ + const inc_t jrstep_c = cs_c * NR; \ + const inc_t jrstep_b = cs_b * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ + const inc_t irstep_a = rs_a * MR; \ +\ + /* Query a stor3_t enum value to characterize the problem. + Examples: BLIS_RRR, BLIS_RRC, BLIS_RCR, BLIS_RCC, etc. + NOTE: If any matrix is general-stored, we use the all-purpose sup + microkernel corresponding to the stor3_t enum value BLIS_XXX. */ \ + const stor3_t stor_id = bli_stor3_from_strides( rs_c, cs_c, \ + rs_a, cs_a, rs_b, cs_b ); \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = n / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( n + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. 
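+	   For example, with n = 100 and NC = 64, jc_iter =
+	   ( 100 + 64 - 1 ) / 64 = 2 and jc_left = 100 % 64 = 36, so the
+	   final jc iteration operates on a partial block of 36 columns
+	   rather than a full NC-sized block.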
*/ \ + const dim_t jc_iter = ( n + NC - 1 ) / NC; \ + const dim_t jc_left = n % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( m + MC - 1 ) / MC; \ + const dim_t ic_left = m % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + const dim_t ir_inc = 1; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + const dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + const dim_t jr_left = nc_cur % NR; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + const dim_t ir_left = mc_cur % MR; \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc + j * jrstep_b; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ +/* + ctype* restrict b2 = b_jr; \ +*/ \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = 0; i < ir_iter; i += ir_inc ) \ + { \ + const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic + i * irstep_a; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ +/* + ctype* restrict a2 = bli_gemm_get_next_a_upanel( a_ir, irstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, ir_iter, 0, 1 ) ) \ + { \ + a2 = a_00; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, jrstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, jr_iter, 0, 1 ) ) \ + b2 = b_00; \ + } \ +\ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +*/ \ +\ + /* Invoke the gemmsup micro-kernel. 
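+	   In effect, each call performs the block update
+	   c_ir := beta_use * c_ir + alpha_cast * a_ir * b_jr, where a_ir is
+	   mr_cur x kc_cur, b_jr is kc_cur x nr_cur, and c_ir is mr_cur x nr_cur;
+	   beta_use is beta on the first pc iteration and one thereafter.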
*/ \
			gemmsup_ker \
			( \
			  conja, \
			  conjb, \
			  mr_cur, \
			  nr_cur, \
			  kc_cur, \
			  alpha_cast, \
			  a_ir, rs_a, cs_a, \
			  b_jr, rs_b, cs_b, \
			  beta_use, \
			  c_ir, rs_c, cs_c, \
			  &aux, \
			  cntx \
			); \
		} \
		} \
		} \
	} \
	} \
\
/*
PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \
PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \
PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \
*/ \
}

INSERT_GENTFUNC_BASIC0( gemmsup_ref_var2 )


//
// -- var1 ---------------------------------------------------------------------
//

static FUNCPTR_T GENARRAY(ftypes_var1,gemmsup_ref_var1);

void bli_gemmsup_ref_var1
     (
       trans_t trans,
       obj_t*  alpha,
       obj_t*  a,
       obj_t*  b,
       obj_t*  beta,
       obj_t*  c,
       stor3_t eff_id,
       cntx_t* cntx,
       rntm_t* rntm
     )
{
#if 0
	obj_t at, bt;

	bli_obj_alias_to( a, &at );
	bli_obj_alias_to( b, &bt );

	// Induce transpositions on A and/or B if either object is marked for
	// transposition. We can induce "fast" transpositions since the objects
	// are guaranteed to not have structure or be packed.
	if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); }
	if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); }

	const num_t dt_exec = bli_obj_dt( c );

	const conj_t conja = bli_obj_conj_status( a );
	const conj_t conjb = bli_obj_conj_status( b );

	const dim_t m = bli_obj_length( c );
	const dim_t n = bli_obj_width( c );

	const dim_t k = bli_obj_width( &at );

	void* restrict buf_a = bli_obj_buffer_at_off( &at );
	const inc_t rs_a = bli_obj_row_stride( &at );
	const inc_t cs_a = bli_obj_col_stride( &at );

	void* restrict buf_b = bli_obj_buffer_at_off( &bt );
	const inc_t rs_b = bli_obj_row_stride( &bt );
	const inc_t cs_b = bli_obj_col_stride( &bt );

	void* restrict buf_c = bli_obj_buffer_at_off( c );
	const inc_t rs_c = bli_obj_row_stride( c );
	const inc_t cs_c = bli_obj_col_stride( c );

	void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha );
	void* restrict buf_beta  = bli_obj_buffer_for_1x1( dt_exec, beta );

#else

	const num_t dt_exec = bli_obj_dt( c );

	const conj_t conja = bli_obj_conj_status( a );
	const conj_t conjb = bli_obj_conj_status( b );

	const dim_t m = bli_obj_length( c );
	const dim_t n = bli_obj_width( c );
	dim_t k;

	void* restrict buf_a = bli_obj_buffer_at_off( a );
	inc_t rs_a;
	inc_t cs_a;

	void* restrict buf_b = bli_obj_buffer_at_off( b );
	inc_t rs_b;
	inc_t cs_b;

	if ( bli_obj_has_notrans( a ) )
	{
		k    = bli_obj_width( a );

		rs_a = bli_obj_row_stride( a );
		cs_a = bli_obj_col_stride( a );
	}
	else // if ( bli_obj_has_trans( a ) )
	{
		// Assign the variables with an implicit transposition.
		k    = bli_obj_length( a );

		rs_a = bli_obj_col_stride( a );
		cs_a = bli_obj_row_stride( a );
	}

	if ( bli_obj_has_notrans( b ) )
	{
		rs_b = bli_obj_row_stride( b );
		cs_b = bli_obj_col_stride( b );
	}
	else // if ( bli_obj_has_trans( b ) )
	{
		// Assign the variables with an implicit transposition.
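		// Reading B's column stride as rs_b and its row stride as cs_b
		// yields a transposed view of B without copying: element (i,j)
		// of the view addresses element (j,i) of B in memory.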
+ rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var1[dt_exec]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm + ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ) \ +{ \ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* If alpha is zero, scale by beta and return. */ \ + if ( PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ + const dim_t MC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ +\ + /* Nudge NC up to a multiple of MR and MC up to a multiple of NR. */ \ + const dim_t NC = bli_align_dim_to_mult( NC0, MR ); \ + const dim_t MC = bli_align_dim_to_mult( MC0, NR ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = rs_c * NC; \ + const inc_t jcstep_a = rs_a * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = cs_c * MC; \ + const inc_t icstep_b = cs_b * MC; \ +\ + const inc_t jrstep_c = rs_c * MR; \ + const inc_t jrstep_a = rs_a * MR; \ +\ + const inc_t irstep_c = cs_c * NR; \ + const inc_t irstep_b = cs_b * NR; \ +\ + /* Query a stor3_t enum value to characterize the problem. + Examples: BLIS_RRR, BLIS_RRC, BLIS_RCR, BLIS_RCC, etc. + NOTE: If any matrix is general-stored, we use the all-purpose sup + microkernel corresponding to the stor3_t enum value BLIS_XXX. */ \ + const stor3_t stor_id = bli_stor3_from_strides( rs_c, cs_c, \ + rs_a, cs_a, rs_b, cs_b ); \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. 
+ NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = m / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( m + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( m + NC - 1 ) / NC; \ + const dim_t jc_left = m % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( n + MC - 1 ) / MC; \ + const dim_t ic_left = n % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + const dim_t ir_inc = 1; \ +\ + /* Loop over the m dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict a_jc = a_00 + jj * jcstep_a; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + const dim_t jr_iter = ( nc_cur + MR - 1 ) / MR; \ + const dim_t jr_left = nc_cur % MR; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_jc + pp * pcstep_a; \ + ctype* restrict b_pc = b_00 + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the n dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict b_ic = b_pc + ii * icstep_b; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + const dim_t ir_iter = ( mc_cur + NR - 1 ) / NR; \ + const dim_t ir_left = mc_cur % NR; \ +\ + /* Loop over the m dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict a_jr = a_pc + j * jrstep_a; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Loop over the n dimension (MR rows at a time). */ \ + for ( dim_t i = 0; i < ir_iter; i += ir_inc ) \ + { \ + const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict b_ir = b_ic + i * irstep_b; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + /* Invoke the gemmsup micro-kernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_jr, rs_a, cs_a, \ + b_ir, rs_b, cs_b, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var1 ) +#endif + + diff --git a/frame/3/bli_l3_sup_var1n2m.c b/frame/3/bli_l3_sup_var1n2m.c new file mode 100644 index 000000000..df4c42526 --- /dev/null +++ b/frame/3/bli_l3_sup_var1n2m.c @@ -0,0 +1,803 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+
+   Copyright (C) 2019, Advanced Micro Devices, Inc.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define FUNCPTR_T gemmsup_fp
+
+typedef void (*FUNCPTR_T)
+     (
+       conj_t           conja,
+       conj_t           conjb,
+       dim_t            m,
+       dim_t            n,
+       dim_t            k,
+       void*   restrict alpha,
+       void*   restrict a, inc_t rs_a, inc_t cs_a,
+       void*   restrict b, inc_t rs_b, inc_t cs_b,
+       void*   restrict beta,
+       void*   restrict c, inc_t rs_c, inc_t cs_c,
+       stor3_t          eff_id,
+       cntx_t* restrict cntx,
+       rntm_t* restrict rntm
+     );
+
+//
+// -- var1n --------------------------------------------------------------------
+//
+
+static FUNCPTR_T GENARRAY(ftypes_var1n,gemmsup_ref_var1n);
+
+void bli_gemmsup_ref_var1n
+     (
+       trans_t trans,
+       obj_t*  alpha,
+       obj_t*  a,
+       obj_t*  b,
+       obj_t*  beta,
+       obj_t*  c,
+       stor3_t eff_id,
+       cntx_t* cntx,
+       rntm_t* rntm
+     )
+{
+#if 0
+	obj_t at, bt;
+
+	bli_obj_alias_to( a, &at );
+	bli_obj_alias_to( b, &bt );
+
+	// Induce transpositions on A and/or B if either object is marked for
+	// transposition. We can induce "fast" transpositions since the objects
+	// are guaranteed to not have structure or be packed.
+ if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#else + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition. + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var1n[dt_exec]; + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm + ); + } + else + { + // Invoke the function (transposing the operation). + f + ( + conjb, // swap the conj values. + conja, + n, // swap the m and n dimensions. + m, + k, + buf_alpha, + buf_b, cs_b, rs_b, // swap the positions of A and B. + buf_a, cs_a, rs_a, // swap the strides of A and B. + buf_beta, + buf_c, cs_c, rs_c, // swap the strides of C. + bli_stor3_trans( eff_id ), // transpose the stor3_t id.
+ cntx, + rntm + ); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ) \ +{ \ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* This transposition of the stor3_t id value is inherent to variant 1. + The reason: we assume that variant 2 is the "main" variant. The + consequence of this is that we assume that the millikernels that + iterate over m are registered to the kernel group associated with + the kernel preference. So, regardless of whether the mkernels are + row- or column-preferential, millikernels that iterate over n are + always placed in the slots for the opposite kernel group. */ \ + stor_id = bli_stor3_trans( stor_id ); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + dim_t KC; \ + if ( FALSE ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( m <= MR && n <= NR ) KC = KC0; \ + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ + else KC = (( KC0 / 5 ) / 4 ) * 4; \ +\ + /* Nudge NC up to a multiple of MR and MC up to a multiple of NR. */ \ + const dim_t NC = bli_align_dim_to_mult( NC0, MR ); \ + const dim_t MC = bli_align_dim_to_mult( MC0, NR ); \ +\ + /* Query the maximum blocksize for MR, which implies a maximum blocksize + extension for the final iteration. */ \ + const dim_t MRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_MR, cntx ); \ + const dim_t MRE = MRM - MR; \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = rs_c * NC; \ + const inc_t jcstep_a = rs_a * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = cs_c * MC; \ + const inc_t icstep_b = cs_b * MC; \ +\ + const inc_t jrstep_c = rs_c * MR; \ + const inc_t jrstep_a = rs_a * MR; \ +\ + /* + const inc_t irstep_c = cs_c * NR; \ + const inc_t irstep_b = cs_b * NR; \ + */ \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. 
*/ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = m / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( m + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( m + NC - 1 ) / NC; \ + const dim_t jc_left = m % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( n + MC - 1 ) / MC; \ + const dim_t ic_left = n % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + /* + const dim_t ir_inc = 1; \ + */ \ +\ + /* Loop over the m dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict a_jc = a_00 + jj * jcstep_a; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + dim_t jr_iter = ( nc_cur + MR - 1 ) / MR; \ + dim_t jr_left = nc_cur % MR; \ +\ + /* An optimization: allow the last jr iteration to contain up to MRX + rows of C and A. (If MRX > MR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. */ \ + if ( 1 ) \ + if ( MRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= MRE ) \ + { \ + jr_iter--; jr_left += MR; \ + } \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_jc + pp * pcstep_a; \ + ctype* restrict b_pc = b_00 + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the n dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict b_ic = b_pc + ii * icstep_b; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + /* + const dim_t ir_iter = ( mc_cur + NR - 1 ) / NR; \ + const dim_t ir_left = mc_cur % NR; \ + */ \ +\ + /* Loop over the m dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? MR : jr_left ); \ +\ + ctype* restrict a_jr = a_pc + j * jrstep_a; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Loop over the n dimension (MR rows at a time). */ \ + { \ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + nr_cur, /* Notice: nr_cur <= MR. */ \ + mc_cur, /* Recall: mc_cur partitions the n dimension! 
*/ \ + kc_cur, \ + alpha_cast, \ + a_jr, rs_a, cs_a, \ + b_ic, rs_b, cs_b, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var1n ) + + +// +// -- var2m -------------------------------------------------------------------- +// + +static FUNCPTR_T GENARRAY(ftypes_var2m,gemmsup_ref_var2m); + +void bli_gemmsup_ref_var2m + ( + trans_t trans, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + stor3_t eff_id, + cntx_t* cntx, + rntm_t* rntm + ) +{ +#if 0 + obj_t at, bt; + + bli_obj_alias_to( a, &at ); + bli_obj_alias_to( b, &bt ); + + // Induce transpositions on A and/or B if either object is marked for + // transposition. We can induce "fast" transpositions since the objects + // are guaranteed to not have structure or be packed. + if ( bli_obj_has_trans( &at ) ) { bli_obj_induce_fast_trans( &at ); } + if ( bli_obj_has_trans( &bt ) ) { bli_obj_induce_fast_trans( &bt ); } + + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + + const dim_t k = bli_obj_width( &at ); + + void* restrict buf_a = bli_obj_buffer_at_off( &at ); + const inc_t rs_a = bli_obj_row_stride( &at ); + const inc_t cs_a = bli_obj_col_stride( &at ); + + void* restrict buf_b = bli_obj_buffer_at_off( &bt ); + const inc_t rs_b = bli_obj_row_stride( &bt ); + const inc_t cs_b = bli_obj_col_stride( &bt ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#else + const num_t dt_exec = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + dim_t k; + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + inc_t rs_a; + inc_t cs_a; + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b; + inc_t cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + k = bli_obj_width( a ); + + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else // if ( bli_obj_has_trans( a ) ) + { + // Assign the variables with an implicit transposition. + k = bli_obj_length( a ); + + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else // if ( bli_obj_has_trans( b ) ) + { + // Assign the variables with an implicit transposition.
+ rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt_exec, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt_exec, beta ); + +#endif + + // Index into the type combination array to extract the correct + // function pointer. + FUNCPTR_T f = ftypes_var2m[dt_exec]; + + if ( bli_is_notrans( trans ) ) + { + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + eff_id, + cntx, + rntm + ); + } + else + { + // Invoke the function (transposing the operation). + f + ( + conjb, // swap the conj values. + conja, + n, // swap the m and n dimensions. + m, + k, + buf_alpha, + buf_b, cs_b, rs_b, // swap the positions of A and B. + buf_a, cs_a, rs_a, // swap the strides of A and B. + buf_beta, + buf_c, cs_c, rs_c, // swap the strides of C. + bli_stor3_trans( eff_id ), // transpose the stor3_t id. + cntx, + rntm + ); + } +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t stor_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ) \ +{ \ + /* If m or n is zero, return immediately. */ \ + if ( bli_zero_dim2( m, n ) ) return; \ +\ + /* If k < 1 or alpha is zero, scale by beta and return. */ \ + if ( k < 1 || PASTEMAC(ch,eq0)( *(( ctype* )alpha) ) ) \ + { \ + PASTEMAC(ch,scalm) \ + ( \ + BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m, n, \ + beta, \ + c, rs_c, cs_c \ + ); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC0 = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + dim_t KC; \ + if ( stor_id == BLIS_RRR || \ + stor_id == BLIS_CCC ) KC = KC0; \ + else if ( stor_id == BLIS_RRC || \ + stor_id == BLIS_CRC ) KC = KC0; \ + else if ( m <= MR && n <= NR ) KC = KC0; \ + else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ + else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ + else KC = (( KC0 / 5 ) / 4 ) * 4; \ +\ + /* Query the maximum blocksize for NR, which implies a maximum blocksize + extension for the final iteration. */ \ + const dim_t NRM = bli_cntx_get_l3_sup_blksz_max_dt( dt, BLIS_NR, cntx ); \ + const dim_t NRE = NRM - NR; \ +\ + /* Compute partitioning step values for each matrix of each loop.
*/ \ + const inc_t jcstep_c = cs_c * NC; \ + const inc_t jcstep_b = cs_b * NC; \ +\ + const inc_t pcstep_a = cs_a * KC; \ + const inc_t pcstep_b = rs_b * KC; \ +\ + const inc_t icstep_c = rs_c * MC; \ + const inc_t icstep_a = rs_a * MC; \ +\ + const inc_t jrstep_c = cs_c * NR; \ + const inc_t jrstep_b = cs_b * NR; \ +\ + /* + const inc_t irstep_c = rs_c * MR; \ + const inc_t irstep_a = rs_a * MR; \ + */ \ +\ + /* Query the context for the sup microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemmsup_ker_ft) \ + gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + ctype* restrict one = PASTEMAC(ch,1); \ +\ + auxinfo_t aux; \ +\ + /* Compute number of primary and leftover components of the outer + dimensions. + NOTE: Functionally speaking, we compute jc_iter as: + jc_iter = n / NC; if ( jc_left ) ++jc_iter; + However, this is implemented as: + jc_iter = ( n + NC - 1 ) / NC; + This avoids a branch at the cost of two additional integer instructions. + The pc_iter, mc_iter, nr_iter, and mr_iter variables are computed in + similar manner. */ \ + const dim_t jc_iter = ( n + NC - 1 ) / NC; \ + const dim_t jc_left = n % NC; \ +\ + const dim_t pc_iter = ( k + KC - 1 ) / KC; \ + const dim_t pc_left = k % KC; \ +\ + const dim_t ic_iter = ( m + MC - 1 ) / MC; \ + const dim_t ic_left = m % MC; \ +\ + const dim_t jc_inc = 1; \ + const dim_t pc_inc = 1; \ + const dim_t ic_inc = 1; \ + const dim_t jr_inc = 1; \ + /* + const dim_t ir_inc = 1; \ + */ \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = 0; jj < jc_iter; jj += jc_inc ) \ + { \ + const dim_t nc_cur = ( bli_is_not_edge_f( jj, jc_iter, jc_left ) ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* An optimization: allow the last jr iteration to contain up to NRX + columns of C and B. (If NRX > NR, the mkernel has agreed to handle + these cases.) Note that this prevents us from declaring jr_iter and + jr_left as const. */ \ + if ( 1 ) \ + if ( NRE != 0 && 1 < jr_iter && jr_left != 0 && jr_left <= NRE ) \ + { \ + jr_iter--; jr_left += NR; \ + } \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = 0; pp < pc_iter; pp += pc_inc ) \ + { \ + const dim_t kc_cur = ( bli_is_not_edge_f( pp, pc_iter, pc_left ) ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? beta_cast : one ); \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = 0; ii < ic_iter; ii += ic_inc ) \ + { \ + const dim_t mc_cur = ( bli_is_not_edge_f( ii, ic_iter, ic_left ) ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + /* + const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + const dim_t ir_left = mc_cur % MR; \ + */ \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = 0; j < jr_iter; j += jr_inc ) \ + { \ + const dim_t nr_cur = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? 
NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc + j * jrstep_b; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + { \ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mc_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ic, rs_a, cs_a, \ + b_jr, rs_b, cs_b, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: b1", kc_cur, nr_cur, b_jr, rs_b, cs_b, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: a1", mr_cur, kc_cur, a_ir, rs_a, cs_a, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +} + +INSERT_GENTFUNC_BASIC0( gemmsup_ref_var2m ) + diff --git a/frame/3/bli_l3_sup_vars.h b/frame/3/bli_l3_sup_vars.h new file mode 100644 index 000000000..bc1d2b3ef --- /dev/null +++ b/frame/3/bli_l3_sup_vars.h @@ -0,0 +1,92 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype object-based interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC0(opname) \ + ( \ + trans_t trans, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + stor3_t eff_id, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ); + +GENPROT( gemmsup_ref_var1 ) +GENPROT( gemmsup_ref_var2 ) + +GENPROT( gemmsup_ref_var1n ) +GENPROT( gemmsup_ref_var2m ) + + +// +// Prototype BLAS-like interfaces with void pointer operands. 
+// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + stor3_t eff_id, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm \ + ); + +INSERT_GENTPROT_BASIC0( gemmsup_ref_var1 ) +INSERT_GENTPROT_BASIC0( gemmsup_ref_var2 ) + +INSERT_GENTPROT_BASIC0( gemmsup_ref_var1n ) +INSERT_GENTPROT_BASIC0( gemmsup_ref_var2m ) + diff --git a/frame/3/bli_l3_tapi.c b/frame/3/bli_l3_tapi.c index 4eeba1971..7b7f758ab 100644 --- a/frame/3/bli_l3_tapi.c +++ b/frame/3/bli_l3_tapi.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -64,7 +65,11 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, bo, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t m_a, n_a; \ dim_t m_b, n_b; \ @@ -72,12 +77,12 @@ void PASTEMAC2(ch,opname,EX_SUF) \ bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ \ bli_obj_set_conjtrans( transa, &ao ); \ bli_obj_set_conjtrans( transb, &bo ); \ @@ -122,7 +127,11 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, bo, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t mn_a; \ dim_t m_b, n_b; \ @@ -130,12 +139,12 @@ void PASTEMAC2(ch,opname,EX_SUF) \ bli_set_dim_with_side( side, m, n, &mn_a ); \ bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ \ bli_obj_set_uplo( uploa, &ao ); \ 
bli_obj_set_conj( conja, &ao ); \ @@ -183,17 +192,20 @@ void PASTEMAC2(ch,opname,EX_SUF) \ const num_t dt_r = PASTEMAC(chr,type); \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t m_a, n_a; \ \ bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt_r, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ + bli_obj_init_finish_1x1( dt_r, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ \ bli_obj_set_uplo( uploc, &co ); \ bli_obj_set_conjtrans( transa, &ao ); \ @@ -239,7 +251,11 @@ void PASTEMAC2(ch,opname,EX_SUF) \ const num_t dt_r = PASTEMAC(chr,type); \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, bo, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t m_a, n_a; \ dim_t m_b, n_b; \ @@ -247,12 +263,12 @@ void PASTEMAC2(ch,opname,EX_SUF) \ bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt_r, beta, &betao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ \ bli_obj_set_uplo( uploc, &co ); \ bli_obj_set_conjtrans( transa, &ao ); \ @@ -297,17 +313,20 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t m_a, n_a; \ \ bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ \ bli_obj_set_uplo( uploc, &co ); \ bli_obj_set_conjtrans( transa, &ao ); \ @@ -352,7 +371,11 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, bo, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + 
obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t m_a, n_a; \ dim_t m_b, n_b; \ @@ -360,12 +383,12 @@ void PASTEMAC2(ch,opname,EX_SUF) \ bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ bli_set_dims_with_trans( transb, m, k, &m_b, &n_b ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, m, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, m, c, rs_c, cs_c, &co ); \ \ bli_obj_set_uplo( uploc, &co ); \ bli_obj_set_conjtrans( transa, &ao ); \ @@ -414,7 +437,11 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, bo, betao, co; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ dim_t mn_a; \ dim_t m_b, n_b; \ @@ -422,12 +449,12 @@ void PASTEMAC2(ch,opname,EX_SUF) \ bli_set_dim_with_side( side, m, n, &mn_a ); \ bli_set_dims_with_trans( transb, m, n, &m_b, &n_b ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ - bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, beta, &betao ); \ \ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ - bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m, n, c, rs_c, cs_c, &co ); \ \ bli_obj_set_uplo( uploa, &ao ); \ bli_obj_set_diag( diaga, &ao ); \ @@ -475,16 +502,18 @@ void PASTEMAC2(ch,opname,EX_SUF) \ \ const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao, ao, bo; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ \ dim_t mn_a; \ \ bli_set_dim_with_side( side, m, n, &mn_a ); \ \ - bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, alpha, &alphao ); \ \ - bli_obj_create_with_attached_buffer( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ - bli_obj_create_with_attached_buffer( dt, m, n, b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, mn_a, mn_a, a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m, n, b, rs_b, cs_b, &bo ); \ \ bli_obj_set_uplo( uploa, &ao ); \ bli_obj_set_diag( diaga, &ao ); \ diff --git a/frame/3/bli_l3_tapi.h b/frame/3/bli_l3_tapi.h index 4ae9d6921..a809c2a68 100644 --- a/frame/3/bli_l3_tapi.h +++ b/frame/3/bli_l3_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ trans_t transa, \ trans_t transb, \ @@ -61,7 +61,7 @@ INSERT_GENTPROT_BASIC0( gemm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void 
PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -84,7 +84,7 @@ INSERT_GENTPROT_BASIC0( symm ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -103,7 +103,7 @@ INSERT_GENTPROTR_BASIC0( herk ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -124,7 +124,7 @@ INSERT_GENTPROTR_BASIC0( her2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -143,7 +143,7 @@ INSERT_GENTPROT_BASIC0( syrk ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploc, \ trans_t transa, \ @@ -164,7 +164,7 @@ INSERT_GENTPROT_BASIC0( syr2k ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ @@ -187,7 +187,7 @@ INSERT_GENTPROT_BASIC0( trmm3 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ side_t side, \ uplo_t uploa, \ diff --git a/frame/3/bli_l3_thrinfo.c b/frame/3/bli_l3_thrinfo.c index 1dbd101f3..0eaf10840 100644 --- a/frame/3/bli_l3_thrinfo.c +++ b/frame/3/bli_l3_thrinfo.c @@ -99,35 +99,84 @@ void bli_l3_thrinfo_print_gemm_paths thrinfo_t** threads ) { + // In order to query the number of threads, we query the only thread we + // know exists: thread 0. dim_t n_threads = bli_thread_num_threads( threads[0] ); - dim_t gl_id; - thrinfo_t* jc_info = threads[0]; - thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info ); - thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info ); - thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info ); - thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info ); - thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info ); - thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info ); + // For the purposes of printing the "header" information that is common + // to the various instances of a thrinfo_t (ie: across all threads), we + // choose the last thread in case the problem is so small that there is + // only an "edge" case, which will always be assigned to the last thread + // (at least for higher levels of partitioning). + thrinfo_t* jc_info = threads[n_threads-1]; + thrinfo_t* pc_info = NULL; + thrinfo_t* pb_info = NULL; + thrinfo_t* ic_info = NULL; + thrinfo_t* pa_info = NULL; + thrinfo_t* jr_info = NULL; + thrinfo_t* ir_info = NULL; - dim_t jc_way = bli_thread_n_way( jc_info ); - dim_t pc_way = bli_thread_n_way( pc_info ); - dim_t pb_way = bli_thread_n_way( pb_info ); - dim_t ic_way = bli_thread_n_way( ic_info ); - dim_t pa_way = bli_thread_n_way( pa_info ); - dim_t jr_way = bli_thread_n_way( jr_info ); - dim_t ir_way = bli_thread_n_way( ir_info ); + // Initialize the n_ways and n_threads fields of each thrinfo_t "level" + // to -1. More than likely, these will all be overwritten with meaningful + // values, but in case some thrinfo_t trees are not fully built (see + // next comment), these will be the placeholder values.
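To make the "edge case goes to the last thread" reasoning above concrete, here is a hypothetical example (all numbers invented for illustration):

    /* Suppose n = 6 and NR = 8, with the jr loop parallelized 4 ways. */
    dim_t n = 6, NR = 8;
    dim_t n_units = ( n + NR - 1 ) / NR;  /* = 1 unit of work (an edge case) */
    /* That lone edge unit lands on the last thread, so threads[n_threads-1]
       is the thread whose thrinfo_t tree is most likely to be fully built;
       lower thread ids may stop building their trees at whatever level
       ran out of work. */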
+ dim_t jc_way = -1, pc_way = -1, pb_way = -1, ic_way = -1, + pa_way = -1, jr_way = -1, ir_way = -1; - dim_t jc_nt = bli_thread_num_threads( jc_info ); - dim_t pc_nt = bli_thread_num_threads( pc_info ); - dim_t pb_nt = bli_thread_num_threads( pb_info ); - dim_t ic_nt = bli_thread_num_threads( ic_info ); - dim_t pa_nt = bli_thread_num_threads( pa_info ); - dim_t jr_nt = bli_thread_num_threads( jr_info ); - dim_t ir_nt = bli_thread_num_threads( ir_info ); + dim_t jc_nt = -1, pc_nt = -1, pb_nt = -1, ic_nt = -1, + pa_nt = -1, jr_nt = -1, ir_nt = -1; + + // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads + // may not fully build their thrinfo_t structures--specifically when the + // dimension being parallelized is not large enough for each thread to have + // even one unit of work (where a unit is usually a single micropanel's + // width, MR or NR). + + if ( !jc_info ) goto print_header; + + jc_way = bli_thread_n_way( jc_info ); + jc_nt = bli_thread_num_threads( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); + + if ( !pc_info ) goto print_header; + + pc_way = bli_thread_n_way( pc_info ); + pc_nt = bli_thread_num_threads( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + + if ( !pb_info ) goto print_header; + + pb_way = bli_thread_n_way( pb_info ); + pb_nt = bli_thread_num_threads( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_header; + + ic_way = bli_thread_n_way( ic_info ); + ic_nt = bli_thread_num_threads( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + + if ( !pa_info ) goto print_header; + + pa_way = bli_thread_n_way( pa_info ); + pa_nt = bli_thread_num_threads( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_header; + + jr_way = bli_thread_n_way( jr_info ); + jr_nt = bli_thread_num_threads( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_header; + + ir_way = bli_thread_n_way( ir_info ); + ir_nt = bli_thread_num_threads( ir_info ); + + print_header: printf( " jc kc pb ic pa jr ir\n" ); - printf( "xx_nt: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", + printf( "xx_nt: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n", ( unsigned long )jc_nt, ( unsigned long )pc_nt, ( unsigned long )pb_nt, @@ -135,7 +184,7 @@ void bli_l3_thrinfo_print_gemm_paths ( unsigned long )pa_nt, ( unsigned long )jr_nt, ( unsigned long )ir_nt ); - printf( "xx_way: %4lu %4lu %4lu %4lu %4lu %4lu %4lu\n", + printf( "xx_way: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n", ( unsigned long )jc_way, ( unsigned long )pc_way, ( unsigned long )pb_way, @@ -145,116 +194,59 @@ void bli_l3_thrinfo_print_gemm_paths ( unsigned long )ir_way ); printf( "============================================\n" ); - dim_t jc_comm_id; - dim_t pc_comm_id; - dim_t pb_comm_id; - dim_t ic_comm_id; - dim_t pa_comm_id; - dim_t jr_comm_id; - dim_t ir_comm_id; - - dim_t jc_work_id; - dim_t pc_work_id; - dim_t pb_work_id; - dim_t ic_work_id; - dim_t pa_work_id; - dim_t jr_work_id; - dim_t ir_work_id; - - for ( gl_id = 0; gl_id < n_threads; ++gl_id ) + for ( dim_t gl_id = 0; gl_id < n_threads; ++gl_id ) { jc_info = threads[gl_id]; - // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads - // may not fully build their thrinfo_t structures--specifically when the - // dimension being parallelized is not large enough for each thread to have - // even one unit of work (where as unit is usually a single micropanel's - // width, MR or NR).
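Distilled, the NULL-guarded traversal above follows this pattern (a sketch that merely restates the logic already shown, using the same names as the diff):

    if ( !jc_info ) goto print_header;      /* level never built: keep -1 */
    jc_way  = bli_thread_n_way( jc_info );
    jc_nt   = bli_thread_num_threads( jc_info );
    pc_info = bli_thrinfo_sub_node( jc_info );
    if ( !pc_info ) goto print_header;
    /* ...and likewise for the pb, ic, pa, jr, and ir levels... */

Each guard replaces one level of if/else nesting in the old implementation (removed below), which is what flattens the function.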
- if ( !jc_info ) - { - jc_comm_id = pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - jc_work_id = pc_work_id = pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - jc_comm_id = bli_thread_ocomm_id( jc_info ); - jc_work_id = bli_thread_work_id( jc_info ); - pc_info = bli_thrinfo_sub_node( jc_info ); + dim_t jc_comm_id = -1, pc_comm_id = -1, pb_comm_id = -1, ic_comm_id = -1, + pa_comm_id = -1, jr_comm_id = -1, ir_comm_id = -1; - if ( !pc_info ) - { - pc_comm_id = pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - pc_work_id = pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - pc_comm_id = bli_thread_ocomm_id( pc_info ); - pc_work_id = bli_thread_work_id( pc_info ); - pb_info = bli_thrinfo_sub_node( pc_info ); + dim_t jc_work_id = -1, pc_work_id = -1, pb_work_id = -1, ic_work_id = -1, + pa_work_id = -1, jr_work_id = -1, ir_work_id = -1; - if ( !pb_info ) - { - pb_comm_id = ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - pb_work_id = ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - pb_comm_id = bli_thread_ocomm_id( pb_info ); - pb_work_id = bli_thread_work_id( pb_info ); - ic_info = bli_thrinfo_sub_node( pb_info ); + if ( !jc_info ) goto print_thrinfo; - if ( !ic_info ) - { - ic_comm_id = pa_comm_id = jr_comm_id = ir_comm_id = -1; - ic_work_id = pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - ic_comm_id = bli_thread_ocomm_id( ic_info ); - ic_work_id = bli_thread_work_id( ic_info ); - pa_info = bli_thrinfo_sub_node( ic_info ); + jc_comm_id = bli_thread_ocomm_id( jc_info ); + jc_work_id = bli_thread_work_id( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); - if ( !pa_info ) - { - pa_comm_id = jr_comm_id = ir_comm_id = -1; - pa_work_id = jr_work_id = ir_work_id = -1; - } - else - { - pa_comm_id = bli_thread_ocomm_id( pa_info ); - pa_work_id = bli_thread_work_id( pa_info ); - jr_info = bli_thrinfo_sub_node( pa_info ); + if ( !pc_info ) goto print_thrinfo; - if ( !jr_info ) - { - jr_comm_id = ir_comm_id = -1; - jr_work_id = ir_work_id = -1; - } - else - { - jr_comm_id = bli_thread_ocomm_id( jr_info ); - jr_work_id = bli_thread_work_id( jr_info ); - ir_info = bli_thrinfo_sub_node( jr_info ); + pc_comm_id = bli_thread_ocomm_id( pc_info ); + pc_work_id = bli_thread_work_id( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); - if ( !ir_info ) - { - ir_comm_id = -1; - ir_work_id = -1; - } - else - { - ir_comm_id = bli_thread_ocomm_id( ir_info ); - ir_work_id = bli_thread_work_id( ir_info ); - } - } - } - } - } - } - } + if ( !pb_info ) goto print_thrinfo; + + pb_comm_id = bli_thread_ocomm_id( pb_info ); + pb_work_id = bli_thread_work_id( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_thrinfo; + + ic_comm_id = bli_thread_ocomm_id( ic_info ); + ic_work_id = bli_thread_work_id( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + + if ( !pa_info ) goto print_thrinfo; + + pa_comm_id = bli_thread_ocomm_id( pa_info ); + pa_work_id = bli_thread_work_id( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_thrinfo; + + jr_comm_id = bli_thread_ocomm_id( jr_info ); + jr_work_id = bli_thread_work_id( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_thrinfo; + + ir_comm_id = bli_thread_ocomm_id( ir_info ); + ir_work_id = bli_thread_work_id( ir_info ); + + print_thrinfo: - //printf( " gl jc pb kc pa ic jr \n" ); - 
//printf( " gl jc kc pb ic pa jr \n" ); printf( "comm ids: %4ld %4ld %4ld %4ld %4ld %4ld %4ld\n", ( long )jc_comm_id, ( long )pc_comm_id, @@ -285,44 +277,105 @@ void bli_l3_thrinfo_print_trsm_paths thrinfo_t** threads ) { + // In order to query the number of threads, we query the only thread we + // know exists: thread 0. dim_t n_threads = bli_thread_num_threads( threads[0] ); - dim_t gl_id; - thrinfo_t* jc_info = threads[0]; - thrinfo_t* pc_info = bli_thrinfo_sub_node( jc_info ); - thrinfo_t* pb_info = bli_thrinfo_sub_node( pc_info ); - thrinfo_t* ic_info = bli_thrinfo_sub_node( pb_info ); + // For the purposes of printing the "header" information that is common + // to the various instances of a thrinfo_t (ie: across all threads), we + // choose the last thread in case the problem is so small that there is + // only an "edge" case, which will always be assigned to the last thread + // (at least for higher levels of partitioning). + thrinfo_t* jc_info = threads[n_threads-1]; + thrinfo_t* pc_info = NULL; + thrinfo_t* pb_info = NULL; + thrinfo_t* ic_info = NULL; + thrinfo_t* pa_info = NULL; thrinfo_t* pa_info0 = NULL; + thrinfo_t* jr_info = NULL; thrinfo_t* jr_info0 = NULL; + thrinfo_t* ir_info = NULL; thrinfo_t* ir_info0 = NULL; - thrinfo_t* pa_info = bli_thrinfo_sub_node( ic_info ); - thrinfo_t* jr_info = bli_thrinfo_sub_node( pa_info ); - thrinfo_t* ir_info = bli_thrinfo_sub_node( jr_info ); - thrinfo_t* pa_info0 = bli_thrinfo_sub_prenode( ic_info ); - thrinfo_t* jr_info0 = ( pa_info0 ? bli_thrinfo_sub_node( pa_info0 ) : NULL ); - thrinfo_t* ir_info0 = ( jr_info0 ? bli_thrinfo_sub_node( jr_info0 ) : NULL ); + // Initialize the n_ways and n_threads fields of each thrinfo_t "level" + // to -1. More than likely, these will all be overwritten with meaningful + // values, but in case some thrinfo_t trees are not fully built (see + // next commnet), these will be the placeholder values. + dim_t jc_way = -1, pc_way = -1, pb_way = -1, ic_way = -1, + pa_way = -1, jr_way = -1, ir_way = -1, + pa_way0 = -1, jr_way0 = -1, ir_way0 = -1; - dim_t jc_way = bli_thread_n_way( jc_info ); - dim_t pc_way = bli_thread_n_way( pc_info ); - dim_t pb_way = bli_thread_n_way( pb_info ); - dim_t ic_way = bli_thread_n_way( ic_info ); + dim_t jc_nt = -1, pc_nt = -1, pb_nt = -1, ic_nt = -1, + pa_nt = -1, jr_nt = -1, ir_nt = -1, + pa_nt0 = -1, jr_nt0 = -1, ir_nt0 = -1; - dim_t pa_way = bli_thread_n_way( pa_info ); - dim_t jr_way = bli_thread_n_way( jr_info ); - dim_t ir_way = bli_thread_n_way( ir_info ); - dim_t pa_way0 = ( pa_info0 ? bli_thread_n_way( pa_info0 ) : -1 ); - dim_t jr_way0 = ( jr_info0 ? bli_thread_n_way( jr_info0 ) : -1 ); - dim_t ir_way0 = ( ir_info0 ? bli_thread_n_way( ir_info0 ) : -1 ); + // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads + // may not fully build their thrinfo_t structures--specifically when the + // dimension being parallelized is not large enough for each thread to have + // even one unit of work (where as unit is usually a single micropanel's + // width, MR or NR). - dim_t jc_nt = bli_thread_num_threads( jc_info ); - dim_t pc_nt = bli_thread_num_threads( pc_info ); - dim_t pb_nt = bli_thread_num_threads( pb_info ); - dim_t ic_nt = bli_thread_num_threads( ic_info ); + if ( !jc_info ) goto print_header; - dim_t pa_nt = bli_thread_num_threads( pa_info ); - dim_t jr_nt = bli_thread_num_threads( jr_info ); - dim_t ir_nt = bli_thread_num_threads( ir_info ); - dim_t pa_nt0 = ( pa_info0 ? bli_thread_num_threads( pa_info0 ) : -1 ); - dim_t jr_nt0 = ( jr_info0 ? 
bli_thread_num_threads( jr_info0 ) : -1 ); - dim_t ir_nt0 = ( ir_info0 ? bli_thread_num_threads( ir_info0 ) : -1 ); + jc_way = bli_thread_n_way( jc_info ); + jc_nt = bli_thread_num_threads( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); + + if ( !pc_info ) goto print_header; + + pc_way = bli_thread_n_way( pc_info ); + pc_nt = bli_thread_num_threads( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + + if ( !pb_info ) goto print_header; + + pb_way = bli_thread_n_way( pb_info ); + pb_nt = bli_thread_num_threads( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_header; + + ic_way = bli_thread_n_way( ic_info ); + ic_nt = bli_thread_num_threads( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + pa_info0 = bli_thrinfo_sub_prenode( ic_info ); + + // check_header_prenode: + + if ( !pa_info0 ) goto check_header_node; + + pa_way0 = bli_thread_n_way( pa_info0 ); + pa_nt0 = bli_thread_num_threads( pa_info0 ); + jr_info0 = bli_thrinfo_sub_node( pa_info0 ); + + if ( !jr_info0 ) goto check_header_node; + + jr_way0 = bli_thread_n_way( jr_info0 ); + jr_nt0 = bli_thread_num_threads( jr_info0 ); + ir_info0 = bli_thrinfo_sub_node( jr_info0 ); + + if ( !ir_info0 ) goto check_header_node; + + ir_way0 = bli_thread_n_way( ir_info0 ); + ir_nt0 = bli_thread_num_threads( ir_info0 ); + + check_header_node: + + if ( !pa_info ) goto print_header; + + pa_way = bli_thread_n_way( pa_info ); + pa_nt = bli_thread_num_threads( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_header; + + jr_way = bli_thread_n_way( jr_info ); + jr_nt = bli_thread_num_threads( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_header; + + ir_way = bli_thread_n_way( ir_info ); + ir_nt = bli_thread_num_threads( ir_info ); + + print_header: printf( " jc kc pb ic pa jr ir\n" ); printf( "xx_nt: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n", @@ -343,26 +396,105 @@ void bli_l3_thrinfo_print_trsm_paths ( long )ir_way0, ( long )ir_way ); printf( "==================================================\n" ); - dim_t jc_comm_id; - dim_t pc_comm_id; - dim_t pb_comm_id; - dim_t ic_comm_id; - dim_t pa_comm_id0, pa_comm_id; - dim_t jr_comm_id0, jr_comm_id; - dim_t ir_comm_id0, ir_comm_id; - dim_t jc_work_id; - dim_t pc_work_id; - dim_t pb_work_id; - dim_t ic_work_id; - dim_t pa_work_id0, pa_work_id; - dim_t jr_work_id0, jr_work_id; - dim_t ir_work_id0, ir_work_id; - - for ( gl_id = 0; gl_id < n_threads; ++gl_id ) + for ( dim_t gl_id = 0; gl_id < n_threads; ++gl_id ) { jc_info = threads[gl_id]; +#if 1 + // NOTE: This cpp branch contains code that is safe to execute + // for small problems that are parallelized enough that one or + // more threads gets no work. 
+ + dim_t jc_comm_id = -1, pc_comm_id = -1, pb_comm_id = -1, ic_comm_id = -1, + pa_comm_id = -1, jr_comm_id = -1, ir_comm_id = -1, + pa_comm_id0 = -1, jr_comm_id0 = -1, ir_comm_id0 = -1; + + dim_t jc_work_id = -1, pc_work_id = -1, pb_work_id = -1, ic_work_id = -1, + pa_work_id = -1, jr_work_id = -1, ir_work_id = -1, + pa_work_id0 = -1, jr_work_id0 = -1, ir_work_id0 = -1; + + if ( !jc_info ) goto print_thrinfo; + + jc_comm_id = bli_thread_ocomm_id( jc_info ); + jc_work_id = bli_thread_work_id( jc_info ); + pc_info = bli_thrinfo_sub_node( jc_info ); + + if ( !pc_info ) goto print_thrinfo; + + pc_comm_id = bli_thread_ocomm_id( pc_info ); + pc_work_id = bli_thread_work_id( pc_info ); + pb_info = bli_thrinfo_sub_node( pc_info ); + + if ( !pb_info ) goto print_thrinfo; + + pb_comm_id = bli_thread_ocomm_id( pb_info ); + pb_work_id = bli_thread_work_id( pb_info ); + ic_info = bli_thrinfo_sub_node( pb_info ); + + if ( !ic_info ) goto print_thrinfo; + + ic_comm_id = bli_thread_ocomm_id( ic_info ); + ic_work_id = bli_thread_work_id( ic_info ); + pa_info = bli_thrinfo_sub_node( ic_info ); + pa_info0 = bli_thrinfo_sub_prenode( ic_info ); + + // check_thrinfo_prenode: + + if ( !pa_info0 ) goto check_thrinfo_node; + + pa_comm_id0 = bli_thread_ocomm_id( pa_info0 ); + pa_work_id0 = bli_thread_work_id( pa_info0 ); + jr_info0 = bli_thrinfo_sub_node( pa_info0 ); + + if ( !jr_info0 ) goto check_thrinfo_node; + + jr_comm_id0 = bli_thread_ocomm_id( jr_info0 ); + jr_work_id0 = bli_thread_work_id( jr_info0 ); + ir_info0 = bli_thrinfo_sub_node( jr_info0 ); + + if ( !ir_info0 ) goto check_thrinfo_node; + + ir_comm_id0 = bli_thread_ocomm_id( ir_info0 ); + ir_work_id0 = bli_thread_work_id( ir_info0 ); + + check_thrinfo_node: + + if ( !pa_info ) goto print_thrinfo; + + pa_comm_id = bli_thread_ocomm_id( pa_info ); + pa_work_id = bli_thread_work_id( pa_info ); + jr_info = bli_thrinfo_sub_node( pa_info ); + + if ( !jr_info ) goto print_thrinfo; + + jr_comm_id = bli_thread_ocomm_id( jr_info ); + jr_work_id = bli_thread_work_id( jr_info ); + ir_info = bli_thrinfo_sub_node( jr_info ); + + if ( !ir_info ) goto print_thrinfo; + + ir_comm_id = bli_thread_ocomm_id( ir_info ); + ir_work_id = bli_thread_work_id( ir_info ); + + print_thrinfo: +#else + dim_t jc_comm_id; + dim_t pc_comm_id; + dim_t pb_comm_id; + dim_t ic_comm_id; + dim_t pa_comm_id0, pa_comm_id; + dim_t jr_comm_id0, jr_comm_id; + dim_t ir_comm_id0, ir_comm_id; + + dim_t jc_work_id; + dim_t pc_work_id; + dim_t pb_work_id; + dim_t ic_work_id; + dim_t pa_work_id0, pa_work_id; + dim_t jr_work_id0, jr_work_id; + dim_t ir_work_id0, ir_work_id; + // NOTE: We must check each thrinfo_t pointer for NULLness. Certain threads // may not fully build their thrinfo_t structures--specifically when the // dimension being parallelized is not large enough for each thread to have @@ -488,6 +620,7 @@ void bli_l3_thrinfo_print_trsm_paths } } } +#endif printf( "comm ids: %4ld %4ld %4ld %4ld %2ld|%2ld %2ld|%2ld %2ld|%2ld\n", ( long )jc_comm_id, diff --git a/frame/3/bli_l3_ukr_oapi.c b/frame/3/bli_l3_ukr_oapi.c index a8191b1aa..33262b0bb 100644 --- a/frame/3/bli_l3_ukr_oapi.c +++ b/frame/3/bli_l3_ukr_oapi.c @@ -69,7 +69,7 @@ void PASTEMAC0(opname) \ bli_auxinfo_set_is_b( 1, &data ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opname,_qfp)( dt ); \ \ @@ -130,7 +130,7 @@ void PASTEMAC0(opname) \ if ( bli_obj_is_lower( a11 ) ) \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnamel,_qfp)( dt ); \ \ @@ -150,7 +150,7 @@ void PASTEMAC0(opname) \ else /* if ( bli_obj_is_upper( a11 ) ) */ \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnameu,_qfp)( dt ); \ \ @@ -205,7 +205,7 @@ void PASTEMAC0(opname) \ if ( bli_obj_is_lower( a ) ) \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnamel,_qfp)( dt ); \ \ @@ -221,7 +221,7 @@ void PASTEMAC0(opname) \ else /* if ( bli_obj_is_upper( a ) ) */ \ { \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(tname,_ukr,_vft) f = \ PASTEMAC(opnameu,_qfp)( dt ); \ \ diff --git a/frame/3/bli_l3_ukr_oapi.h b/frame/3/bli_l3_ukr_oapi.h index 512df492b..5fed11ede 100644 --- a/frame/3/bli_l3_ukr_oapi.h +++ b/frame/3/bli_l3_ukr_oapi.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a, \ @@ -56,7 +56,7 @@ GENPROT( gemm_ukernel ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* alpha, \ obj_t* a1x, \ @@ -73,7 +73,7 @@ GENPROT( gemmtrsm_ukernel ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0(opname) \ +BLIS_EXPORT_BLIS void PASTEMAC0(opname) \ ( \ obj_t* a, \ obj_t* b, \ diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index fc0a4a786..d7cd0a92c 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -56,9 +56,9 @@ cntl_t* bli_gemmbp_cntl_create pack_t schema_b ) { - void* macro_kernel_fp; - void* packa_fp; - void* packb_fp; + void_fp macro_kernel_fp; + void_fp packa_fp; + void_fp packb_fp; // Use the function pointers to the macrokernels that use slab // assignment of micropanels to threads in the jr and ir loops. @@ -165,7 +165,7 @@ cntl_t* bli_gemmpb_cntl_create opid_t family ) { - void* macro_kernel_p = bli_gemm_ker_var1; + void_fp macro_kernel_p = bli_gemm_ker_var1; // Change the macro-kernel if the operation family is herk or trmm. 
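The void*-to-void_fp conversions in this and the neighboring hunks address a portability wrinkle: ISO C does not guarantee that a function pointer survives a round trip through void*, which is an object pointer type, and some compilers warn about the conversion. Storing such addresses in a generic function-pointer type avoids the issue. A sketch of the idea (BLIS's actual typedef lives elsewhere in the tree and is not shown in this diff; my_kernel is a stand-in):

    /* Any function pointer may be cast to another function pointer type
       and back without loss, so a generic *function* pointer is the
       portable way to pass around macrokernel and packm addresses. */
    typedef void (*void_fp)( void );

    void my_kernel( int n );             /* stand-in for a macrokernel */
    void_fp fp = ( void_fp )my_kernel;   /* fn ptr -> fn ptr: well-defined */

Callers later cast fp back to the correct signature before invoking it.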
//if ( family == BLIS_HERK ) macro_kernel_p = bli_herk_x_ker_var2; @@ -270,7 +270,7 @@ cntl_t* bli_gemm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/frame/3/gemm/bli_gemm_cntl.h b/frame/3/gemm/bli_gemm_cntl.h index c6d20b170..bff91b58a 100644 --- a/frame/3/gemm/bli_gemm_cntl.h +++ b/frame/3/gemm/bli_gemm_cntl.h @@ -74,7 +74,7 @@ cntl_t* bli_gemm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 20928e198..bb6f2ee3b 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -54,6 +54,7 @@ void bli_gemm_front obj_t c_local; + #ifdef BLIS_ENABLE_SMALL_MATRIX // Only handle small problems separately for homogeneous datatypes. if ( bli_obj_dt( a ) == bli_obj_dt( b ) && diff --git a/frame/3/gemm/other/bli_gemm_ker_var5.c b/frame/3/gemm/other/bli_gemm_ker_var5.c index 0d0c914d8..9e13c4edd 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var5.c +++ b/frame/3/gemm/other/bli_gemm_ker_var5.c @@ -45,7 +45,7 @@ typedef void (*FUNCPTR_T)( void* b, inc_t rs_b, dim_t pd_b, inc_t ps_b, void* beta, void* c, inc_t rs_c, inc_t cs_c, - void* gemm_ukr + void_fp gemm_ukr ); static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var5); @@ -87,7 +87,7 @@ void bli_gemm_ker_var5( obj_t* a, FUNCPTR_T f; func_t* gemm_ukrs; - void* gemm_ukr; + void_fp gemm_ukr; // Detach and multiply the scalars attached to A and B. @@ -135,7 +135,7 @@ void PASTEMAC(ch,varname)( \ void* b, inc_t rs_b, dim_t pd_b, inc_t ps_b, \ void* beta, \ void* c, inc_t rs_c, inc_t cs_c, \ - void* gemm_ukr \ + void_fp gemm_ukr \ ) \ { \ /* Cast the micro-kernel address to its function pointer type. */ \ diff --git a/frame/3/gemm/other/bli_gemm_ker_var5.h b/frame/3/gemm/other/bli_gemm_ker_var5.h index 7e24bb5f9..e88db5cb5 100644 --- a/frame/3/gemm/other/bli_gemm_ker_var5.h +++ b/frame/3/gemm/other/bli_gemm_ker_var5.h @@ -59,7 +59,7 @@ void PASTEMAC(ch,varname)( \ void* b, inc_t rs_b, dim_t pd_b, inc_t ps_b, \ void* beta, \ void* c, inc_t rs_c, inc_t cs_c, \ - void* gemm_ukr \ + void_fp gemm_ukr \ ); INSERT_GENTPROT_BASIC( gemm_ker_var5 ) diff --git a/frame/3/old/bli_l3_sup_edge.h b/frame/3/old/bli_l3_sup_edge.h new file mode 100644 index 000000000..06f3bb18b --- /dev/null +++ b/frame/3/old/bli_l3_sup_edge.h @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +static +void bli_dgemmsup_ker_edge_dispatcher + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx, + const dim_t num_mr, + const dim_t num_nr, + dim_t* restrict mrs, + dim_t* restrict nrs, + dgemmsup_ker_ft* kmap + ) +{ + #if 1 + + // outer loop = nr; inner loop = mr + + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + if ( mr_cur <= m_left ) + { + dgemmsup_ker_ft ker_fp = kmap[ i*num_nr + j*1 ]; + + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + #else + + // outer loop = mr; inner loop = nr + + dim_t m_left = m0; + double* restrict ci = c; + double* restrict ai = a; + + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + if ( mr_cur <= m_left ) + { + dim_t n_left = n0; + double* restrict cij = ci; + double* restrict bj = b; + + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + if ( nr_cur <= n_left ) + { + dgemmsup_ker_ft ker_fp = kmap[ i*num_nr + j*1 ]; + + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + + } + } + + ci += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + #endif +} + diff --git a/frame/3/syrk/bli_syrk_front.h b/frame/3/syrk/bli_syrk_front.h index 98b1e1251..de6b0ed0a 100644 --- a/frame/3/syrk/bli_syrk_front.h +++ b/frame/3/syrk/bli_syrk_front.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +43,7 @@ void bli_syrk_front rntm_t* rntm, cntl_t* cntl ); - + err_t bli_syrk_small ( obj_t* alpha, @@ -52,4 +53,4 @@ err_t bli_syrk_small obj_t* c, cntx_t* cntx, cntl_t* cntl - ); \ No newline at end of file + ); diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index ff2f18a1f..845370448 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -56,9 +56,9 @@ cntl_t* bli_trsm_l_cntl_create pack_t schema_b ) { - void* macro_kernel_p; - void* packa_fp; - void* packb_fp; + void_fp macro_kernel_p; + void_fp packa_fp; + void_fp packb_fp; // Use the function pointer to the macrokernels that use slab // assignment of micropanels to threads in the jr and ir loops. @@ -210,10 +210,10 @@ cntl_t* bli_trsm_r_cntl_create ) { // NOTE: trsm macrokernels are presently disabled for right-side execution. - void* macro_kernel_p = bli_trsm_xx_ker_var2; + void_fp macro_kernel_p = bli_trsm_xx_ker_var2; - void* packa_fp = bli_packm_blk_var1; - void* packb_fp = bli_packm_blk_var1; + void_fp packa_fp = bli_packm_blk_var1; + void_fp packb_fp = bli_packm_blk_var1; const opid_t family = BLIS_TRSM; @@ -318,7 +318,7 @@ cntl_t* bli_trsm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/frame/3/trsm/bli_trsm_cntl.h b/frame/3/trsm/bli_trsm_cntl.h index f81f70d07..7fdb1fc4f 100644 --- a/frame/3/trsm/bli_trsm_cntl.h +++ b/frame/3/trsm/bli_trsm_cntl.h @@ -69,7 +69,7 @@ cntl_t* bli_trsm_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index 3c3b08a08..eb88913ea 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -34,7 +34,6 @@ */ #include "blis.h" - //#define PRINT_SMALL_TRSM_INFO void bli_trsm_front @@ -49,11 +48,15 @@ void bli_trsm_front ) { bli_init_once(); - + int i, j; obj_t a_local; obj_t b_local; obj_t c_local; +//int m = bli_obj_length(*b); +//int n = bli_obj_width(*b); +//float *L = a->buffer; + // float *B = b->buffer; #ifdef PRINT_SMALL_TRSM_INFO printf("Side:: %c\n", side ? 
'R' : 'L'); diff --git a/frame/base/bli_apool.h b/frame/base/bli_apool.h index 7d61e8eb1..f08d89c9d 100644 --- a/frame/base/bli_apool.h +++ b/frame/base/bli_apool.h @@ -56,7 +56,7 @@ static pool_t* bli_apool_pool( apool_t* apool ) return &(apool->pool); } -static bli_pthread_mutex_t* bli_apool_mutex( apool_t* apool ) +static bli_pthread_mutex_t* bli_apool_mutex( apool_t* apool ) { return &(apool->mutex); } diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index 9cbbade72..09388da0d 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -36,6 +36,7 @@ #ifndef BLIS_CONFIGURETIME_CPUID #include "blis.h" #else + #define BLIS_EXPORT_BLIS #include "bli_system.h" #include "bli_type_defs.h" #include "bli_arch.h" diff --git a/frame/base/bli_arch.h b/frame/base/bli_arch.h index 4299a12a0..6b8a38ebd 100644 --- a/frame/base/bli_arch.h +++ b/frame/base/bli_arch.h @@ -35,12 +35,12 @@ #ifndef BLIS_ARCH_H #define BLIS_ARCH_H -arch_t bli_arch_query_id( void ); +BLIS_EXPORT_BLIS arch_t bli_arch_query_id( void ); -void bli_arch_set_id_once( void ); -void bli_arch_set_id( void ); +void bli_arch_set_id_once( void ); +void bli_arch_set_id( void ); -char* bli_arch_string( arch_t id ); +BLIS_EXPORT_BLIS char* bli_arch_string( arch_t id ); #endif diff --git a/frame/base/bli_blksz.h b/frame/base/bli_blksz.h index 15280ca18..a3400b2fa 100644 --- a/frame/base/bli_blksz.h +++ b/frame/base/bli_blksz.h @@ -186,7 +186,7 @@ static void bli_blksz_scale_def_max // ----------------------------------------------------------------------------- -blksz_t* bli_blksz_create_ed +BLIS_EXPORT_BLIS blksz_t* bli_blksz_create_ed ( dim_t b_s, dim_t be_s, dim_t b_d, dim_t be_d, @@ -194,13 +194,13 @@ blksz_t* bli_blksz_create_ed dim_t b_z, dim_t be_z ); -blksz_t* bli_blksz_create +BLIS_EXPORT_BLIS blksz_t* bli_blksz_create ( dim_t b_s, dim_t b_d, dim_t b_c, dim_t b_z, dim_t be_s, dim_t be_d, dim_t be_c, dim_t be_z ); -void bli_blksz_init_ed +BLIS_EXPORT_BLIS void bli_blksz_init_ed ( blksz_t* b, dim_t b_s, dim_t be_s, @@ -209,20 +209,20 @@ void bli_blksz_init_ed dim_t b_z, dim_t be_z ); -void bli_blksz_init +BLIS_EXPORT_BLIS void bli_blksz_init ( blksz_t* b, dim_t b_s, dim_t b_d, dim_t b_c, dim_t b_z, dim_t be_s, dim_t be_d, dim_t be_c, dim_t be_z ); -void bli_blksz_init_easy +BLIS_EXPORT_BLIS void bli_blksz_init_easy ( blksz_t* b, dim_t b_s, dim_t b_d, dim_t b_c, dim_t b_z ); -void bli_blksz_free +BLIS_EXPORT_BLIS void bli_blksz_free ( blksz_t* b ); @@ -230,7 +230,7 @@ void bli_blksz_free // ----------------------------------------------------------------------------- #if 0 -void bli_blksz_reduce_dt_to +BLIS_EXPORT_BLIS void bli_blksz_reduce_dt_to ( num_t dt_bm, blksz_t* bmult, num_t dt_bs, blksz_t* blksz diff --git a/frame/base/bli_check.h b/frame/base/bli_check.h index 90de609ad..539458406 100644 --- a/frame/base/bli_check.h +++ b/frame/base/bli_check.h @@ -34,7 +34,7 @@ */ -err_t bli_check_error_code_helper( gint_t code, char* file, guint_t line ); +BLIS_EXPORT_BLIS err_t bli_check_error_code_helper( gint_t code, char* file, guint_t line ); err_t bli_check_valid_error_level( errlev_t level ); diff --git a/frame/base/bli_clock.c b/frame/base/bli_clock.c index a4df37e16..f5bf2c49f 100644 --- a/frame/base/bli_clock.c +++ b/frame/base/bli_clock.c @@ -59,6 +59,9 @@ double bli_clock_min_diff( double time_min, double time_start ) // - under a nanosecond // is actually garbled due to the clocks being taken too closely together. 
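For context on how the two routines in bli_clock.c are meant to be used together: bli_clock() returns a wall-clock timestamp, and bli_clock_min_diff() folds one trial's elapsed time into a running minimum while discarding garbled readings, as the comment above describes. A sketch of the resulting best-of-n timing idiom (the wrapper function and its initial value are illustrative, not taken from the BLIS drivers):

#include <float.h>
#include "blis.h"

double time_best_of( int nrepeats )
{
    double time_min = DBL_MAX;

    for ( int r = 0; r < nrepeats; ++r )
    {
        double time_start = bli_clock();

        /* ... invoke the operation being timed ... */

        /* Keep the smaller of the running minimum and this trial's
           elapsed time; non-positive and sub-nanosecond readings are
           replaced by the previous minimum, per the logic above. */
        time_min = bli_clock_min_diff( time_min, time_start );
    }

    return time_min;
}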
if ( time_min <= 0.0 ) time_min = time_min_prev; + // To allow timing of applications that run for more than an hour, the line + // below is commented out. If a spuriously large time is ever measured, we can always fall back to the previous minimum. + /* else if ( time_min > 3600.0 ) time_min = time_min_prev; */ else if ( time_min < 1.0e-9 ) time_min = time_min_prev; return time_min; diff --git a/frame/base/bli_clock.h b/frame/base/bli_clock.h index c17eafdd1..d52e26ef6 100644 --- a/frame/base/bli_clock.h +++ b/frame/base/bli_clock.h @@ -32,7 +32,8 @@ */ -double bli_clock( void ); -double bli_clock_min_diff( double time_min, double time_start ); +BLIS_EXPORT_BLIS double bli_clock( void ); +BLIS_EXPORT_BLIS double bli_clock_min_diff( double time_min, double time_start ); + double bli_clock_helper( void ); diff --git a/frame/base/bli_cntl.c b/frame/base/bli_cntl.c index 21027ee2c..0de6fbc39 100644 --- a/frame/base/bli_cntl.c +++ b/frame/base/bli_cntl.c @@ -40,7 +40,7 @@ cntl_t* bli_cntl_create_node rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, void* params, cntl_t* sub_node ) diff --git a/frame/base/bli_cntl.h b/frame/base/bli_cntl.h index a697f44ab..998a92571 100644 --- a/frame/base/bli_cntl.h +++ b/frame/base/bli_cntl.h @@ -42,7 +42,7 @@ struct cntl_s // Basic fields (usually required). opid_t family; bszid_t bszid; - void* var_func; + void_fp var_func; struct cntl_s* sub_prenode; struct cntl_s* sub_node; @@ -60,56 +60,56 @@ typedef struct cntl_s cntl_t; // -- Control tree prototypes -- -cntl_t* bli_cntl_create_node +BLIS_EXPORT_BLIS cntl_t* bli_cntl_create_node ( rntm_t* rntm, opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, void* params, cntl_t* sub_node ); -void bli_cntl_free_node +BLIS_EXPORT_BLIS void bli_cntl_free_node ( rntm_t* rntm, cntl_t* cntl ); -void bli_cntl_clear_node +BLIS_EXPORT_BLIS void bli_cntl_clear_node ( cntl_t* cntl ); // ----------------------------------------------------------------------------- -void bli_cntl_free +BLIS_EXPORT_BLIS void bli_cntl_free ( rntm_t* rntm, cntl_t* cntl, thrinfo_t* thread ); -void bli_cntl_free_w_thrinfo +BLIS_EXPORT_BLIS void bli_cntl_free_w_thrinfo ( rntm_t* rntm, cntl_t* cntl, thrinfo_t* thread ); -void bli_cntl_free_wo_thrinfo +BLIS_EXPORT_BLIS void bli_cntl_free_wo_thrinfo ( rntm_t* rntm, cntl_t* cntl ); -cntl_t* bli_cntl_copy +BLIS_EXPORT_BLIS cntl_t* bli_cntl_copy ( rntm_t* rntm, cntl_t* cntl ); -void bli_cntl_mark_family +BLIS_EXPORT_BLIS void bli_cntl_mark_family ( opid_t family, cntl_t* cntl @@ -137,7 +137,7 @@ static bszid_t bli_cntl_bszid( cntl_t* cntl ) return cntl->bszid; } -static void* bli_cntl_var_func( cntl_t* cntl ) +static void_fp bli_cntl_var_func( cntl_t* cntl ) { return cntl->var_func; } @@ -200,7 +200,7 @@ static void bli_cntl_set_bszid( bszid_t bszid, cntl_t* cntl ) cntl->bszid = bszid; } -static void bli_cntl_set_var_func( void* var_func, cntl_t* cntl ) +static void bli_cntl_set_var_func( void_fp var_func, cntl_t* cntl ) { cntl->var_func = var_func; } diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 3d47ff604..b612518b8 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -48,9 +48,8 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default blocksizes.
It should be called after - // bli_cntx_init_defaults() so that default blocksizes remain - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default + // blocksizes across all datatypes. /* Example prototypes: @@ -76,49 +75,36 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) cntx_t* cntx ); */ + va_list args; dim_t i; - bszid_t* bszids; - blksz_t** blkszs; - bszid_t* bmults; - double* dsclrs; - double* msclrs; - - cntx_t* cntx; - - blksz_t* cntx_blkszs; - bszid_t* cntx_bmults; - - // Allocate some temporary local arrays. - + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_blkszs(): " ); + #endif + bszid_t* bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + blksz_t** blkszs = bli_malloc_intl( n_bs * sizeof( blksz_t* ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - blkszs = bli_malloc_intl( n_bs * sizeof( blksz_t* ) ); + bszid_t* bmults = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - bmults = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + double* dsclrs = bli_malloc_intl( n_bs * sizeof( double ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_blkszs(): " ); #endif - dsclrs = bli_malloc_intl( n_bs * sizeof( double ) ); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_cntx_set_blkszs(): " ); - #endif - msclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* msclrs = bli_malloc_intl( n_bs * sizeof( double ) ); // -- Begin variable argument section -- @@ -175,7 +161,7 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) } // The last argument should be the context pointer. - cntx = ( cntx_t* )va_arg( args, cntx_t* ); + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); // Shutdown variable argument environment and clean up stack. va_end( args ); @@ -188,8 +174,9 @@ void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ) // Query the context for the addresses of: // - the blocksize object array // - the blocksize multiple array - cntx_blkszs = bli_cntx_blkszs_buf( cntx ); - cntx_bmults = bli_cntx_bmults_buf( cntx ); + + blksz_t* cntx_blkszs = bli_cntx_blkszs_buf( cntx ); + bszid_t* cntx_bmults = bli_cntx_bmults_buf( cntx ); // Now that we have the context address, we want to copy the values // from the temporary buffers into the corresponding buffers in the @@ -353,15 +340,10 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) NOTE: This function modifies an existing context that is presumed to have been initialized for native execution. */ + va_list args; dim_t i; - bszid_t* bszids; - double* dsclrs; - double* msclrs; - - cntx_t* cntx; - // Return early if called with BLIS_NAT. if ( method == BLIS_NAT ) return; @@ -370,17 +352,17 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... 
) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_ind_blkszs(): " ); #endif - bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + bszid_t* bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_ind_blkszs(): " ); #endif - dsclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* dsclrs = bli_malloc_intl( n_bs * sizeof( double ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_ind_blkszs(): " ); #endif - msclrs = bli_malloc_intl( n_bs * sizeof( double ) ); + double* msclrs = bli_malloc_intl( n_bs * sizeof( double ) ); // -- Begin variable argument section -- @@ -408,7 +390,7 @@ void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ) } // The last argument should be the context pointer. - cntx = ( cntx_t* )va_arg( args, cntx_t* ); + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); // Shutdown variable argument environment and clean up stack. va_end( args ); @@ -523,22 +505,22 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default level-3 microkernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default + // microkernels across all datatypes. /* Example prototypes: void bli_cntx_set_l3_nat_ukrs ( dim_t n_ukrs, - l3ukr_t ukr0_id, num_t dt0, void* ukr0_fp, bool_t pref0, - l3ukr_t ukr1_id, num_t dt1, void* ukr1_fp, bool_t pref1, - l3ukr_t ukr2_id, num_t dt2, void* ukr2_fp, bool_t pref2, + l3ukr_t ukr0_id, num_t dt0, void_fp ukr0_fp, bool_t pref0, + l3ukr_t ukr1_id, num_t dt1, void_fp ukr1_fp, bool_t pref1, + l3ukr_t ukr2_id, num_t dt2, void_fp ukr2_fp, bool_t pref2, ... cntx_t* cntx ); */ + va_list args; dim_t i; @@ -557,7 +539,7 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_nat_ukrs(): " ); #endif - void** ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void* ) ); + void_fp* ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void_fp ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_nat_ukrs(): " ); @@ -590,7 +572,7 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // within a bool_t afterwards. const l3ukr_t ukr_id = ( l3ukr_t )va_arg( args, l3ukr_t ); const num_t ukr_dt = ( num_t )va_arg( args, num_t ); - void* ukr_fp = ( void* )va_arg( args, void* ); + void_fp ukr_fp = ( void_fp )va_arg( args, void_fp ); const bool_t ukr_pref = ( bool_t )va_arg( args, int ); // Store the values in our temporary arrays. @@ -623,11 +605,11 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_ukrs; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current ukernel id, ukernel datatype, ukernel function + // pointer, and ukernel preference. const l3ukr_t ukr_id = ukr_ids[ i ]; const num_t ukr_dt = ukr_dts[ i ]; - void* ukr_fp = ukr_fps[ i ]; + void_fp ukr_fp = ukr_fps[ i ]; const bool_t ukr_pref = ukr_prefs[ i ]; // Index into the func_t and mbool_t for the current kernel id @@ -672,27 +654,513 @@ void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... 
) // ----------------------------------------------------------------------------- +void bli_cntx_set_l3_sup_thresh( dim_t n_thresh, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default thresholds for small/unpacked matrix handling. It should + // be called after bli_cntx_init_defaults() so that the context begins + // with default thresholds. + + /* Example prototypes: + + void bli_cntx_set_l3_sup_thresh + ( + dim_t n_thresh, + threshid_t th0_id, blksz_t* blksz0, + threshid_t th1_id, blksz_t* blksz1, + ... + cntx_t* cntx + ); + + */ + + va_list args; + dim_t i; + + // Allocate some temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + threshid_t* threshids = bli_malloc_intl( n_thresh * sizeof( threshid_t ) ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + blksz_t** threshs = bli_malloc_intl( n_thresh * sizeof( blksz_t* ) ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_thresh ); + + // Process n_thresh tuples. + for ( i = 0; i < n_thresh; ++i ) + { + // Here, we query the variable argument list for: + // - the threshid_t of the threshold we're about to process, + // - the address of the blksz_t object. + threshid_t th_id = ( threshid_t )va_arg( args, threshid_t ); + blksz_t* thresh = ( blksz_t* )va_arg( args, blksz_t* ); + + // Store the values in our temporary arrays. + threshids[ i ] = th_id; + threshs[ i ] = thresh; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the threshold array + blksz_t* cntx_threshs = bli_cntx_l3_sup_thresh_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. Notice that the blksz_t* pointers were saved, rather than + // the objects themselves, but we copy the contents of the objects + // when copying into the context. + + // Process each threshold id tuple provided. + for ( i = 0; i < n_thresh; ++i ) + { + // Read the current threshold id and blksz_t* pointer. + threshid_t th_id = threshids[ i ]; + blksz_t* thresh = threshs[ i ]; + + blksz_t* cntx_thresh = &cntx_threshs[ th_id ]; + + // Copy the blksz_t object contents into the appropriate + // location within the context's blksz_t array. + //cntx_threshs[ th_id ] = *thresh; + //bli_blksz_copy( thresh, cntx_thresh ); + bli_blksz_copy_if_pos( thresh, cntx_thresh ); + } + + // Free the temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + bli_free_intl( threshs ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_thresh(): " ); + #endif + bli_free_intl( threshids ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_handlers( dim_t n_ops, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default level-3 operation handlers for small/unpacked matrices.
It + // should be called after bli_cntx_init_defaults() so that the context + // begins with default sup handlers across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_l3_sup_handlers + ( + dim_t n_ops, + opid_t op0_id, void* handler0_fp, + opid_t op1_id, void* handler1_fp, + opid_t op2_id, void* handler2_fp, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + + // Allocate some temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + opid_t* op_ids = bli_malloc_intl( n_ops * sizeof( opid_t ) ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + void** op_fps = bli_malloc_intl( n_ops * sizeof( void* ) ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_ops ); + + // Process n_ukrs tuples. + for ( i = 0; i < n_ops; ++i ) + { + // Here, we query the variable argument list for: + // - the opid_t of the operation we're about to process, + // - the sup handler function pointer + // that we need to store to the context. + const opid_t op_id = ( opid_t )va_arg( args, opid_t ); + void* op_fp = ( void* )va_arg( args, void* ); + + // Store the values in our temporary arrays. + op_ids[ i ] = op_id; + op_fps[ i ] = op_fp; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the l3 small/unpacked handlers array + void** cntx_l3_sup_handlers = bli_cntx_l3_sup_handlers_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. + + // Process each operation id tuple provided. + for ( i = 0; i < n_ops; ++i ) + { + // Read the current ukernel id, ukernel datatype, and ukernel function + // pointer. + const opid_t op_id = op_ids[ i ]; + void* op_fp = op_fps[ i ]; + + // Store the sup handler function pointer into the slot for the + // specified operation id. + cntx_l3_sup_handlers[ op_id ] = op_fp; + } + + // Free the temporary local arrays. + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + bli_free_intl( op_ids ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_handlers(): " ); + #endif + bli_free_intl( op_fps ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_blkszs( dim_t n_bs, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default l3 sup blocksizes. It should be called after + // bli_cntx_init_defaults() so that the context begins with default + // blocksizes across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_blkszs + ( + dim_t n_bs, + bszid_t bs0_id, blksz_t* blksz0, + bszid_t bs1_id, blksz_t* blksz1, + bszid_t bs2_id, blksz_t* blksz2, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + + // Allocate some temporary local arrays. 
+ #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_blkszs(): " ); + #endif + bszid_t* bszids = bli_malloc_intl( n_bs * sizeof( bszid_t ) ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_blkszs(): " ); + #endif + blksz_t** blkszs = bli_malloc_intl( n_bs * sizeof( blksz_t* ) ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_bs ); + + // Process n_bs tuples. + for ( i = 0; i < n_bs; ++i ) + { + // Here, we query the variable argument list for: + // - the bszid_t of the blocksize we're about to process, + // - the address of the blksz_t object. + bszid_t bs_id = ( bszid_t )va_arg( args, bszid_t ); + blksz_t* blksz = ( blksz_t* )va_arg( args, blksz_t* ); + + // Store the values in our temporary arrays. + bszids[ i ] = bs_id; + blkszs[ i ] = blksz; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the blocksize object array + blksz_t* cntx_l3_sup_blkszs = bli_cntx_l3_sup_blkszs_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. Notice that the blksz_t* pointers were saved, rather than + // the objects themselves, but we copy the contents of the objects + // when copying into the context. + + // Process each blocksize id tuple provided. + for ( i = 0; i < n_bs; ++i ) + { + // Read the current blocksize id and blksz_t* pointer. + bszid_t bs_id = bszids[ i ]; + blksz_t* blksz = blkszs[ i ]; + + blksz_t* cntx_l3_sup_blksz = &cntx_l3_sup_blkszs[ bs_id ]; + + // Copy the blksz_t object contents into the appropriate + // location within the context's blksz_t array. + //cntx_l3_sup_blkszs[ bs_id ] = *blksz; + //bli_blksz_copy( blksz, cntx_l3_sup_blksz ); + bli_blksz_copy_if_pos( blksz, cntx_l3_sup_blksz ); + } + + // Free the temporary local arrays. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_blkszs(): " ); + #endif + bli_free_intl( blkszs ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_blkszs(): " ); + #endif + bli_free_intl( bszids ); +} + +// ----------------------------------------------------------------------------- + +void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... ) +{ + // This function can be called from the bli_cntx_init_*() function for + // a particular architecture if the kernel developer wishes to use + // non-default level-3 microkernels for small/unpacked matrices. It + // should be called after bli_cntx_init_defaults() so that the context + // begins with default sup micro/millikernels across all datatypes. + + /* Example prototypes: + + void bli_cntx_set_l3_sup_kers + ( + dim_t n_ukrs, + stor3_t stor_id0, num_t dt0, void* ukr0_fp, bool_t pref0, + stor3_t stor_id1, num_t dt1, void* ukr1_fp, bool_t pref1, + stor3_t stor_id2, num_t dt2, void* ukr2_fp, bool_t pref2, + ... + cntx_t* cntx + ); + */ + + va_list args; + dim_t i; + + // Allocate some temporary local arrays.
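Likewise, a sketch of how bli_cntx_set_l3_sup_blkszs(), defined just above, might be called from a subconfiguration's init function. The blocksize values are placeholders rather than tuned numbers, and bli_blksz_init_easy() is prototyped earlier in this diff; because the copy into the context goes through bli_blksz_copy_if_pos(), non-positive entries leave the defaults untouched:

#include "blis.h"

void set_my_sup_blkszs( cntx_t* cntx )
{
    blksz_t blkszs[ BLIS_NUM_BLKSZS ];

    /* Initialize sup blocksizes per datatype: ( blksz, s, d, c, z ).
       The -1 entries are skipped by bli_blksz_copy_if_pos(), so those
       datatypes keep the context's default values. */
    bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1,    6, -1, -1 );
    bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1,    8, -1, -1 );
    bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1,   72, -1, -1 );
    bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1,  256, -1, -1 );
    bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 );

    /* Tuple format: bszid_t bs_id, blksz_t* blksz, ..., cntx last. */
    bli_cntx_set_l3_sup_blkszs
    (
      5,
      BLIS_NC, &blkszs[ BLIS_NC ],
      BLIS_KC, &blkszs[ BLIS_KC ],
      BLIS_MC, &blkszs[ BLIS_MC ],
      BLIS_NR, &blkszs[ BLIS_NR ],
      BLIS_MR, &blkszs[ BLIS_MR ],
      cntx
    );
}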
+ + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + stor3_t* st3_ids = bli_malloc_intl( n_ukrs * sizeof( stor3_t ) ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + num_t* ukr_dts = bli_malloc_intl( n_ukrs * sizeof( num_t ) ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + void** ukr_fps = bli_malloc_intl( n_ukrs * sizeof( void* ) ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bool_t* ukr_prefs = bli_malloc_intl( n_ukrs * sizeof( bool_t ) ); + + // -- Begin variable argument section -- + + // Initialize variable argument environment. + va_start( args, n_ukrs ); + + // Process n_ukrs tuples. + for ( i = 0; i < n_ukrs; ++i ) + { + // Here, we query the variable argument list for: + // - the stor3_t storage case being assigned to the kernel we're + // about to process, + // - the datatype of the kernel, + // - the kernel function pointer, and + // - the kernel function storage preference + // that we need to store to the context. + const stor3_t st3_id = ( stor3_t )va_arg( args, stor3_t ); + const num_t ukr_dt = ( num_t )va_arg( args, num_t ); + void* ukr_fp = ( void* )va_arg( args, void* ); + const bool_t ukr_pref = ( bool_t )va_arg( args, int ); + + // Store the values in our temporary arrays. + st3_ids[ i ] = st3_id; + ukr_dts[ i ] = ukr_dt; + ukr_fps[ i ] = ukr_fp; + ukr_prefs[ i ] = ukr_pref; + } + + // The last argument should be the context pointer. + cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* ); + + // Shutdown variable argument environment and clean up stack. + va_end( args ); + + // -- End variable argument section -- + + // Query the context for the addresses of: + // - the l3 small/unpacked ukernel func_t array + // - the l3 small/unpacked ukernel preferences array + func_t* cntx_l3_sup_kers = bli_cntx_l3_sup_kers_buf( cntx ); + mbool_t* cntx_l3_sup_kers_prefs = bli_cntx_l3_sup_kers_prefs_buf( cntx ); + + // Now that we have the context address, we want to copy the values + // from the temporary buffers into the corresponding buffers in the + // context. + +#if 0 + dim_t sup_map[ BLIS_NUM_LEVEL3_SUP_UKRS ][2]; + + // Create the small/unpacked ukernel mappings: + // - rv -> rrr 0, rcr 2 + // - rg -> rrc 1, rcc 3 + // - cv -> ccr 6, ccc 7 + // - cg -> crr 4, crc 5 + // - rd -> rrc 1 + // - cd -> crc 5 + // - rc -> rcc 3 + // - cr -> crr 4 + // - gx -> xxx 8 + // NOTE: We only need to set one slot in the context l3_sup_kers array + // for the general-stride/generic ukernel type, but since the loop below + // needs to be set up to set two slots to accommodate the RV, RG, CV, and + // CG ukernel types, we will just be okay with the GX ukernel being set + // redundantly. (The RD, CD, CR, and RC ukernel types are set redundantly + // for the same reason.)
+ sup_map[ BLIS_GEMMSUP_RV_UKR ][0] = BLIS_RRR; + sup_map[ BLIS_GEMMSUP_RV_UKR ][1] = BLIS_RCR; + sup_map[ BLIS_GEMMSUP_RG_UKR ][0] = BLIS_RRC; + sup_map[ BLIS_GEMMSUP_RG_UKR ][1] = BLIS_RCC; + sup_map[ BLIS_GEMMSUP_CV_UKR ][0] = BLIS_CCR; + sup_map[ BLIS_GEMMSUP_CV_UKR ][1] = BLIS_CCC; + sup_map[ BLIS_GEMMSUP_CG_UKR ][0] = BLIS_CRR; + sup_map[ BLIS_GEMMSUP_CG_UKR ][1] = BLIS_CRC; + + sup_map[ BLIS_GEMMSUP_RD_UKR ][0] = BLIS_RRC; + sup_map[ BLIS_GEMMSUP_RD_UKR ][1] = BLIS_RRC; + sup_map[ BLIS_GEMMSUP_CD_UKR ][0] = BLIS_CRC; + sup_map[ BLIS_GEMMSUP_CD_UKR ][1] = BLIS_CRC; + + sup_map[ BLIS_GEMMSUP_RC_UKR ][0] = BLIS_RCC; + sup_map[ BLIS_GEMMSUP_RC_UKR ][1] = BLIS_RCC; + sup_map[ BLIS_GEMMSUP_CR_UKR ][0] = BLIS_CRR; + sup_map[ BLIS_GEMMSUP_CR_UKR ][1] = BLIS_CRR; + + sup_map[ BLIS_GEMMSUP_GX_UKR ][0] = BLIS_XXX; + sup_map[ BLIS_GEMMSUP_GX_UKR ][1] = BLIS_XXX; +#endif + + // Process each ukernel tuple provided. + for ( i = 0; i < n_ukrs; ++i ) + { + // Read the current stor3_t id, ukernel datatype, ukernel function + // pointer, and ukernel preference. + const stor3_t st3_id = st3_ids[ i ]; + const num_t ukr_dt = ukr_dts[ i ]; + void* ukr_fp = ukr_fps[ i ]; + const bool_t ukr_pref = ukr_prefs[ i ]; + + // Index into the func_t and mbool_t for the current stor3_t id + // being processed. + func_t* ukrs = &cntx_l3_sup_kers[ st3_id ]; + mbool_t* prefs = &cntx_l3_sup_kers_prefs[ st3_id ]; + + // Store the ukernel function pointer and preference values into + // the stor3_t location in the context. + bli_func_set_dt( ukr_fp, ukr_dt, ukrs ); + bli_mbool_set_dt( ukr_pref, ukr_dt, prefs ); + } + + // Free the temporary local arrays. + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( st3_ids ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( ukr_dts ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( ukr_fps ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_cntx_set_l3_sup_kers(): " ); + #endif + bli_free_intl( ukr_prefs ); +} + +// ----------------------------------------------------------------------------- + void bli_cntx_set_l1f_kers( dim_t n_kers, ... ) { // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default level-1f kernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default l1f + // kernels across all datatypes. /* Example prototypes: void bli_cntx_set_l1f_kers ( dim_t n_ukrs, - l1fkr_t ker0_id, num_t ker0_dt, void* ker0_fp, - l1fkr_t ker1_id, num_t ker1_dt, void* ker1_fp, - l1fkr_t ker2_id, num_t ker2_dt, void* ker2_fp, + l1fkr_t ker0_id, num_t ker0_dt, void_fp ker0_fp, + l1fkr_t ker1_id, num_t ker1_dt, void_fp ker1_fp, + l1fkr_t ker2_id, num_t ker2_dt, void_fp ker2_fp, ... cntx_t* cntx ); */ + va_list args; dim_t i; @@ -711,7 +1179,7 @@ void bli_cntx_set_l1f_kers( dim_t n_kers, ...
) // that we need to store to the context. const l1fkr_t ker_id = ( l1fkr_t )va_arg( args, l1fkr_t ); const num_t ker_dt = ( num_t )va_arg( args, num_t ); - void* ker_fp = ( void* )va_arg( args, void* ); + void_fp ker_fp = ( void_fp )va_arg( args, void_fp ); // Store the values in our temporary arrays. ker_ids[ i ] = ker_id; @@ -755,11 +1223,11 @@ void bli_cntx_set_l1f_kers( dim_t n_kers, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_kers; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current kernel id, kernel datatype, and kernel function + // pointer. const l1fkr_t ker_id = ker_ids[ i ]; const num_t ker_dt = ker_dts[ i ]; - void* ker_fp = ker_fps[ i ]; + void_fp ker_fp = ker_fps[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -795,22 +1263,22 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default level-1v kernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. + // bli_cntx_init_defaults() so that the context begins with default l1v + // kernels across all datatypes. /* Example prototypes: void bli_cntx_set_l1v_kers ( dim_t n_ukrs, - l1vkr_t ker0_id, num_t ker0_dt, void* ker0_fp, - l1vkr_t ker1_id, num_t ker1_dt, void* ker1_fp, - l1vkr_t ker2_id, num_t ker2_dt, void* ker2_fp, + l1vkr_t ker0_id, num_t ker0_dt, void_fp ker0_fp, + l1vkr_t ker1_id, num_t ker1_dt, void_fp ker1_fp, + l1vkr_t ker2_id, num_t ker2_dt, void_fp ker2_fp, ... cntx_t* cntx ); */ + va_list args; dim_t i; @@ -829,7 +1297,7 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l1v_kers(): " ); #endif - void** ker_fps = bli_malloc_intl( n_kers * sizeof( void* ) ); + void_fp* ker_fps = bli_malloc_intl( n_kers * sizeof( void_fp ) ); // -- Begin variable argument section -- @@ -846,7 +1314,7 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) // that we need to store to the context. const l1vkr_t ker_id = ( l1vkr_t )va_arg( args, l1vkr_t ); const num_t ker_dt = ( num_t )va_arg( args, num_t ); - void* ker_fp = ( void* )va_arg( args, void* ); + void_fp ker_fp = ( void_fp )va_arg( args, void_fp ); // Store the values in our temporary arrays. ker_ids[ i ] = ker_id; @@ -873,11 +1341,11 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_kers; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current kernel id, kernel datatype, and kernel function + // pointer. const l1vkr_t ker_id = ker_ids[ i ]; const num_t ker_dt = ker_dts[ i ]; - void* ker_fp = ker_fps[ i ]; + void_fp ker_fp = ker_fps[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -913,22 +1381,22 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) // This function can be called from the bli_cntx_init_*() function for // a particular architecture if the kernel developer wishes to use // non-default packing kernels. It should be called after - // bli_cntx_init_defaults() so that default functions are still called - // for any datatypes / register blocksizes that were not targed for - // optimization. 
+ // bli_cntx_init_defaults() so that the context begins with default packm + // kernels across all datatypes. /* Example prototypes: void bli_cntx_set_packm_kers ( dim_t n_ukrs, - l1mkr_t ker0_id, num_t ker0_dt, void* ker0_fp, - l1mkr_t ker1_id, num_t ker1_dt, void* ker1_fp, - l1mkr_t ker2_id, num_t ker2_dt, void* ker2_fp, + l1mkr_t ker0_id, num_t ker0_dt, void_fp ker0_fp, + l1mkr_t ker1_id, num_t ker1_dt, void_fp ker1_fp, + l1mkr_t ker2_id, num_t ker2_dt, void_fp ker2_fp, ... cntx_t* cntx ); */ + va_list args; dim_t i; @@ -947,7 +1415,7 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_packm_kers(): " ); #endif - void** ker_fps = bli_malloc_intl( n_kers * sizeof( void* ) ); + void_fp* ker_fps = bli_malloc_intl( n_kers * sizeof( void_fp ) ); // -- Begin variable argument section -- @@ -964,7 +1432,7 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) // that we need to store to the context. const l1mkr_t ker_id = ( l1mkr_t )va_arg( args, l1mkr_t ); const num_t ker_dt = ( num_t )va_arg( args, num_t ); - void* ker_fp = ( void* )va_arg( args, void* ); + void_fp ker_fp = ( void_fp )va_arg( args, void_fp ); // Store the values in our temporary arrays. ker_ids[ i ] = ker_id; @@ -991,11 +1459,11 @@ void bli_cntx_set_packm_kers( dim_t n_kers, ... ) // Process each blocksize id tuple provided. for ( i = 0; i < n_kers; ++i ) { - // Read the current blocksize id, blksz_t* pointer, blocksize - // multiple id, and blocksize scalar. + // Read the current kernel id, kernel datatype, and kernel function + // pointer. const l1mkr_t ker_id = ker_ids[ i ]; const num_t ker_dt = ker_dts[ i ]; - void* ker_fp = ker_fps[ i ]; + void_fp ker_fp = ker_fps[ i ]; // Index into the func_t and mbool_t for the current kernel id // being processed. @@ -1061,11 +1529,11 @@ void bli_cntx_print( cntx_t* cntx ) ); } - for ( i = 0; i < BLIS_NUM_LEVEL3_UKRS; ++i ) + for ( i = 0; i < BLIS_NUM_3OP_RC_COMBOS; ++i ) { - func_t* ukr = bli_cntx_get_l3_nat_ukrs( i, cntx ); + func_t* ukr = bli_cntx_get_l3_sup_kers( i, cntx ); - printf( "l3 nat ukr %2lu: %16p %16p %16p %16p\n", + printf( "l3 sup ukr %2lu: %16p %16p %16p %16p\n", ( unsigned long )i, bli_func_get_dt( BLIS_FLOAT, ukr ), bli_func_get_dt( BLIS_DOUBLE, ukr ), diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index 450c753b6..fae7b5f6e 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -6,6 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP + Copyright (C) 2019, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,6 +50,12 @@ typedef struct cntx_s func_t* l3_nat_ukrs; mbool_t* l3_nat_ukrs_prefs; + blksz_t* l3_sup_thresh; + void** l3_sup_handlers; + blksz_t* l3_sup_blkszs; + func_t* l3_sup_kers; + mbool_t* l3_sup_kers_prefs; + func_t* l1f_kers; func_t* l1v_kers; @@ -89,6 +96,26 @@ static mbool_t* bli_cntx_l3_nat_ukrs_prefs_buf( cntx_t* cntx ) { return cntx->l3_nat_ukrs_prefs; } +static blksz_t* bli_cntx_l3_sup_thresh_buf( cntx_t* cntx ) +{ + return cntx->l3_sup_thresh; +} +static void** bli_cntx_l3_sup_handlers_buf( cntx_t* cntx ) +{ + return cntx->l3_sup_handlers; +} +static blksz_t* bli_cntx_l3_sup_blkszs_buf( cntx_t* cntx ) +{ + return cntx->l3_sup_blkszs; +} +static func_t* bli_cntx_l3_sup_kers_buf( cntx_t* cntx ) +{ + return cntx->l3_sup_kers; +} +static mbool_t* bli_cntx_l3_sup_kers_prefs_buf( cntx_t* cntx ) +{ + return cntx->l3_sup_kers_prefs; +} static func_t* bli_cntx_l1f_kers_buf( cntx_t* cntx ) { return cntx->l1f_kers; @@ -217,7 +244,7 @@ static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l3_vir_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +static void_fp bli_cntx_get_l3_vir_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l3_vir_ukrs( ukr_id, cntx ); @@ -232,7 +259,7 @@ static func_t* bli_cntx_get_l3_nat_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l3_nat_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +static void_fp bli_cntx_get_l3_nat_ukr_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l3_nat_ukrs( ukr_id, cntx ); @@ -258,6 +285,108 @@ static bool_t bli_cntx_get_l3_nat_ukr_prefs_dt( num_t dt, l3ukr_t ukr_id, cntx_t // ----------------------------------------------------------------------------- +static blksz_t* bli_cntx_get_l3_sup_thresh( threshid_t thresh_id, cntx_t* cntx ) +{ + blksz_t* threshs = bli_cntx_l3_sup_thresh_buf( cntx ); + blksz_t* thresh = &threshs[ thresh_id ]; + + // Return the address of the blksz_t identified by thresh_id. + return thresh; +} + +static dim_t bli_cntx_get_l3_sup_thresh_dt( num_t dt, threshid_t thresh_id, cntx_t* cntx ) +{ + blksz_t* threshs = bli_cntx_get_l3_sup_thresh( thresh_id, cntx ); + dim_t thresh_dt = bli_blksz_get_def( dt, threshs ); + + // Return the main (default) threshold value for the datatype given. + return thresh_dt; +} + +static bool_t bli_cntx_l3_sup_thresh_is_met( num_t dt, dim_t m, dim_t n, dim_t k, cntx_t* cntx ) +{ + if ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) return TRUE; + if ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) return TRUE; + if ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) ) return TRUE; + + return FALSE; +} + +// ----------------------------------------------------------------------------- + +static void* bli_cntx_get_l3_sup_handler( opid_t op, cntx_t* cntx ) +{ + void** funcs = bli_cntx_l3_sup_handlers_buf( cntx ); + void* func = funcs[ op ]; + + return func; +} + +// ----------------------------------------------------------------------------- + +static blksz_t* bli_cntx_get_l3_sup_blksz( bszid_t bs_id, cntx_t* cntx ) +{ + blksz_t* blkszs = bli_cntx_l3_sup_blkszs_buf( cntx ); + blksz_t* blksz = &blkszs[ bs_id ]; + + // Return the address of the blksz_t identified by bs_id. 
+ return blksz; +} + +static dim_t bli_cntx_get_l3_sup_blksz_def_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +{ + blksz_t* blksz = bli_cntx_get_l3_sup_blksz( bs_id, cntx ); + dim_t bs_dt = bli_blksz_get_def( dt, blksz ); + + // Return the main (default) blocksize value for the datatype given. + return bs_dt; +} + +static dim_t bli_cntx_get_l3_sup_blksz_max_dt( num_t dt, bszid_t bs_id, cntx_t* cntx ) +{ + blksz_t* blksz = bli_cntx_get_l3_sup_blksz( bs_id, cntx ); + dim_t bs_dt = bli_blksz_get_max( dt, blksz ); + + // Return the auxiliary (maximum) blocksize value for the datatype given. + return bs_dt; +} + +// ----------------------------------------------------------------------------- + +static func_t* bli_cntx_get_l3_sup_kers( stor3_t stor_id, cntx_t* cntx ) +{ + func_t* funcs = bli_cntx_l3_sup_kers_buf( cntx ); + func_t* func = &funcs[ stor_id ]; + + return func; +} + +static void* bli_cntx_get_l3_sup_ker_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + func_t* func = bli_cntx_get_l3_sup_kers( stor_id, cntx ); + + return bli_func_get_dt( dt, func ); +} + +// ----------------------------------------------------------------------------- + +static mbool_t* bli_cntx_get_l3_sup_ker_prefs( stor3_t stor_id, cntx_t* cntx ) +{ + mbool_t* mbools = bli_cntx_l3_sup_kers_prefs_buf( cntx ); + mbool_t* mbool = &mbools[ stor_id ]; + + return mbool; +} + +static bool_t bli_cntx_get_l3_sup_ker_prefs_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + mbool_t* mbool = bli_cntx_get_l3_sup_ker_prefs( stor_id, cntx ); + + return bli_mbool_get_dt( dt, mbool ); +} + +// ----------------------------------------------------------------------------- + static func_t* bli_cntx_get_l1f_kers( l1fkr_t ker_id, cntx_t* cntx ) { func_t* funcs = bli_cntx_l1f_kers_buf( cntx ); @@ -266,7 +395,7 @@ static func_t* bli_cntx_get_l1f_kers( l1fkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l1f_ker_dt( num_t dt, l1fkr_t ker_id, cntx_t* cntx ) +static void_fp bli_cntx_get_l1f_ker_dt( num_t dt, l1fkr_t ker_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l1f_kers( ker_id, cntx ); @@ -283,7 +412,7 @@ static func_t* bli_cntx_get_l1v_kers( l1vkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_l1v_ker_dt( num_t dt, l1vkr_t ker_id, cntx_t* cntx ) +static void_fp bli_cntx_get_l1v_ker_dt( num_t dt, l1vkr_t ker_id, cntx_t* cntx ) { func_t* func = bli_cntx_get_l1v_kers( ker_id, cntx ); @@ -309,9 +438,9 @@ static func_t* bli_cntx_get_packm_kers( l1mkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +static void_fp bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { - void* fp = NULL; + void_fp fp = NULL; // Only query the context for the packm func_t (and then extract the // datatype-specific function pointer) if the packm kernel being @@ -344,9 +473,9 @@ static func_t* bli_cntx_get_unpackm_kers( l1mkr_t ker_id, cntx_t* cntx ) return func; } -static void* bli_cntx_get_unpackm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +static void_fp bli_cntx_get_unpackm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { - void* fp = NULL; + void_fp fp = NULL; // Only query the context for the unpackm func_t (and then extract the // datatype-specific function pointer) if the unpackm kernel being @@ -366,7 +495,7 @@ static void* bli_cntx_get_unpackm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t* cntx static bool_t bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { - bool_t prefs = 
bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); + const bool_t prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); // A ukernel preference of TRUE means the ukernel prefers row storage. return ( bool_t ) @@ -375,7 +504,7 @@ static bool_t bli_cntx_l3_nat_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cnt static bool_t bli_cntx_l3_nat_ukr_prefers_cols_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) { - bool_t prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); + const bool_t prefs = bli_cntx_get_l3_nat_ukr_prefs_dt( dt, ukr_id, cntx ); // A ukernel preference of FALSE means the ukernel prefers column storage. return ( bool_t ) @@ -458,10 +587,58 @@ static bool_t bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t ukr_i // ----------------------------------------------------------------------------- +static bool_t bli_cntx_l3_sup_ker_prefers_rows_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + const bool_t prefs = bli_cntx_get_l3_sup_ker_prefs_dt( dt, stor_id, cntx ); + + // A ukernel preference of TRUE means the ukernel prefers row storage. + return ( bool_t ) + ( prefs == TRUE ); +} + +static bool_t bli_cntx_l3_sup_ker_prefers_cols_dt( num_t dt, stor3_t stor_id, cntx_t* cntx ) +{ + const bool_t prefs = bli_cntx_get_l3_sup_ker_prefs_dt( dt, stor_id, cntx ); + + // A ukernel preference of FALSE means the ukernel prefers column storage. + return ( bool_t ) + ( prefs == FALSE ); +} + +#if 0 +// NOTE: These static functions aren't needed yet. + +static bool_t bli_cntx_l3_sup_ker_prefers_storage_of( obj_t* obj, stor3_t stor_id, cntx_t* cntx ) +{ + const num_t dt = bli_obj_dt( obj ); + const bool_t ukr_prefers_rows + = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + const bool_t ukr_prefers_cols + = bli_cntx_l3_sup_ker_prefers_cols_dt( dt, stor_id, cntx ); + bool_t r_val = FALSE; + + if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; + else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; + + return r_val; +} + +static bool_t bli_cntx_l3_sup_ker_dislikes_storage_of( obj_t* obj, stor3_t stor_id, cntx_t* cntx ) +{ + return ( bool_t ) + !bli_cntx_l3_sup_ker_prefers_storage_of( obj, stor_id, cntx ); +} +#endif + +// ----------------------------------------------------------------------------- + // // -- cntx_t modification (complex) -------------------------------------------- // +// NOTE: The framework does not use any of the following functions. We provide +// them in order to facilitate creating/modifying custom contexts. 
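To make the note above concrete, a minimal sketch of the kind of custom-context tweak these helpers enable (the function name and the value 120 are arbitrary; how the caller obtains its own cntx_t to modify is outside the scope of this sketch):

#include "blis.h"

/* Override the double-precision KC blocksize in a caller-owned context,
   using the static helpers defined just below. */
void shrink_dkc( cntx_t* cntx )
{
    bli_cntx_set_blksz_def_dt( BLIS_DOUBLE, BLIS_KC, 120, cntx );
    bli_cntx_set_blksz_max_dt( BLIS_DOUBLE, BLIS_KC, 120, cntx );
}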
+ static void bli_cntx_set_blksz( bszid_t bs_id, blksz_t* blksz, bszid_t mult_id, cntx_t* cntx ) { blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); @@ -471,6 +648,22 @@ static void bli_cntx_set_blksz( bszid_t bs_id, blksz_t* blksz, bszid_t mult_id, bmults[ bs_id ] = mult_id; } +static void bli_cntx_set_blksz_def_dt( num_t dt, bszid_t bs_id, dim_t bs, cntx_t* cntx ) +{ + blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); + blksz_t* blksz = &blkszs[ bs_id ]; + + bli_blksz_set_def( bs, dt, blksz ); +} + +static void bli_cntx_set_blksz_max_dt( num_t dt, bszid_t bs_id, dim_t bs, cntx_t* cntx ) +{ + blksz_t* blkszs = bli_cntx_blkszs_buf( cntx ); + blksz_t* blksz = &blkszs[ bs_id ]; + + bli_blksz_set_max( bs, dt, blksz ); +} + static void bli_cntx_set_l3_vir_ukr( l3ukr_t ukr_id, func_t* func, cntx_t* cntx ) { func_t* funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); @@ -513,7 +706,7 @@ static void bli_cntx_set_packm_ker( l1mkr_t ker_id, func_t* func, cntx_t* cntx ) funcs[ ker_id ] = *func; } -static void bli_cntx_set_packm_ker_dt( void* fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +static void bli_cntx_set_packm_ker_dt( void_fp fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { func_t* func = ( func_t* )bli_cntx_get_packm_kers( ker_id, cntx ); @@ -527,7 +720,7 @@ static void bli_cntx_set_unpackm_ker( l1mkr_t ker_id, func_t* func, cntx_t* cntx funcs[ ker_id ] = *func; } -static void bli_cntx_set_unpackm_ker_dt( void* fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) +static void bli_cntx_set_unpackm_ker_dt( void_fp fp, num_t dt, l1mkr_t ker_id, cntx_t* cntx ) { func_t* func = ( func_t* )bli_cntx_get_unpackm_kers( ker_id, cntx ); @@ -538,18 +731,24 @@ static void bli_cntx_set_unpackm_ker_dt( void* fp, num_t dt, l1mkr_t ker_id, cnt // Function prototypes -void bli_cntx_clear( cntx_t* cntx ); +BLIS_EXPORT_BLIS void bli_cntx_clear( cntx_t* cntx ); -void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_blkszs( ind_t method, dim_t n_bs, ... ); -void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_ind_blkszs( ind_t method, dim_t n_bs, ... ); -void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ); -void bli_cntx_set_l1f_kers( dim_t n_kers, ... ); -void bli_cntx_set_l1v_kers( dim_t n_kers, ... ); -void bli_cntx_set_packm_kers( dim_t n_kers, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_nat_ukrs( dim_t n_ukrs, ... ); -void bli_cntx_print( cntx_t* cntx ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_thresh( dim_t n_thresh, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_handlers( dim_t n_ops, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_blkszs( dim_t n_bs, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... ); + +BLIS_EXPORT_BLIS void bli_cntx_set_l1f_kers( dim_t n_kers, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_l1v_kers( dim_t n_kers, ... ); +BLIS_EXPORT_BLIS void bli_cntx_set_packm_kers( dim_t n_kers, ... 
); + +BLIS_EXPORT_BLIS void bli_cntx_print( cntx_t* cntx ); #endif diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index d7e954977..ac3346326 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -48,6 +48,7 @@ #ifndef BLIS_CONFIGURETIME_CPUID #include "blis.h" #else + #define BLIS_EXPORT_BLIS #include "bli_system.h" #include "bli_type_defs.h" #include "bli_cpuid.h" diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 70c861e2f..da53aeb57 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -51,29 +51,29 @@ #ifndef BLIS_CPUID_H #define BLIS_CPUID_H -arch_t bli_cpuid_query_id( void ); +arch_t bli_cpuid_query_id( void ); // Intel -bool_t bli_cpuid_is_skx( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_knl( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_haswell( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); +bool_t bli_cpuid_is_skx( uint32_t family, uint32_t model, uint32_t features ); +bool_t bli_cpuid_is_knl( uint32_t family, uint32_t model, uint32_t features ); +bool_t bli_cpuid_is_haswell( uint32_t family, uint32_t model, uint32_t features ); +bool_t bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t features ); +bool_t bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); // AMD -bool_t bli_cpuid_is_zen2( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_zen( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_excavator( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_steamroller( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_piledriver( uint32_t family, uint32_t model, uint32_t features ); -bool_t bli_cpuid_is_bulldozer( uint32_t family, uint32_t model, uint32_t features ); +BLIS_EXPORT_BLIS bool_t bli_cpuid_is_zen2( uint32_t family, uint32_t model, uint32_t features ); +BLIS_EXPORT_BLIS bool_t bli_cpuid_is_zen( uint32_t family, uint32_t model, uint32_t features ); +BLIS_EXPORT_BLIS bool_t bli_cpuid_is_excavator( uint32_t family, uint32_t model, uint32_t features ); +BLIS_EXPORT_BLIS bool_t bli_cpuid_is_steamroller( uint32_t family, uint32_t model, uint32_t features ); +BLIS_EXPORT_BLIS bool_t bli_cpuid_is_piledriver( uint32_t family, uint32_t model, uint32_t features ); +BLIS_EXPORT_BLIS bool_t bli_cpuid_is_bulldozer( uint32_t family, uint32_t model, uint32_t features ); // ARM -bool_t bli_cpuid_is_thunderx2( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa57( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa53( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa15( uint32_t model, uint32_t part, uint32_t features ); -bool_t bli_cpuid_is_cortexa9( uint32_t model, uint32_t part, uint32_t features ); +bool_t bli_cpuid_is_thunderx2( uint32_t model, uint32_t part, uint32_t features ); +bool_t bli_cpuid_is_cortexa57( uint32_t model, uint32_t part, uint32_t features ); +bool_t bli_cpuid_is_cortexa53( uint32_t model, uint32_t part, uint32_t features ); +bool_t bli_cpuid_is_cortexa15( uint32_t model, uint32_t part, uint32_t features ); +bool_t bli_cpuid_is_cortexa9( uint32_t model, uint32_t part, uint32_t features ); uint32_t bli_cpuid_query( uint32_t* family, uint32_t* model, 
uint32_t* features ); diff --git a/frame/base/bli_error.h b/frame/base/bli_error.h index bb624a5dc..e04c6784d 100644 --- a/frame/base/bli_error.h +++ b/frame/base/bli_error.h @@ -33,13 +33,13 @@ */ -void bli_print_msg( char* str, char* file, guint_t line ); -void bli_abort( void ); +BLIS_EXPORT_BLIS errlev_t bli_error_checking_level( void ); +BLIS_EXPORT_BLIS void bli_error_checking_level_set( errlev_t new_level ); -errlev_t bli_error_checking_level( void ); -void bli_error_checking_level_set( errlev_t new_level ); +BLIS_EXPORT_BLIS bool_t bli_error_checking_is_enabled( void ); -bool_t bli_error_checking_is_enabled( void ); +void bli_print_msg( char* str, char* file, guint_t line ); +void bli_abort( void ); -char* bli_error_string_for_code( gint_t code ); +char* bli_error_string_for_code( gint_t code ); diff --git a/frame/base/bli_func.c b/frame/base/bli_func.c index 435bd81de..d383cd0f2 100644 --- a/frame/base/bli_func.c +++ b/frame/base/bli_func.c @@ -37,10 +37,10 @@ func_t* bli_func_create ( - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ) { func_t* f; @@ -62,10 +62,10 @@ func_t* bli_func_create void bli_func_init ( func_t* f, - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ) { bli_func_set_dt( ptr_s, BLIS_FLOAT, f ); diff --git a/frame/base/bli_func.h b/frame/base/bli_func.h index 0f927ad81..a820d0b7e 100644 --- a/frame/base/bli_func.h +++ b/frame/base/bli_func.h @@ -36,7 +36,7 @@ // func_t query -static void* bli_func_get_dt +static void_fp bli_func_get_dt ( num_t dt, func_t* func @@ -49,7 +49,7 @@ static void* bli_func_get_dt static void bli_func_set_dt ( - void* fp, + void_fp fp, num_t dt, func_t* func ) @@ -63,7 +63,7 @@ static void bli_func_copy_dt num_t dt_dst, func_t* func_dst ) { - void* fp = bli_func_get_dt( dt_src, func_src ); + void_fp fp = bli_func_get_dt( dt_src, func_src ); bli_func_set_dt( fp, dt_dst, func_dst ); } @@ -72,19 +72,19 @@ static void bli_func_copy_dt func_t* bli_func_create ( - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ); void bli_func_init ( func_t* f, - void* ptr_s, - void* ptr_d, - void* ptr_c, - void* ptr_z + void_fp ptr_s, + void_fp ptr_d, + void_fp ptr_c, + void_fp ptr_z ); void bli_func_init_null diff --git a/frame/base/bli_getopt.h b/frame/base/bli_getopt.h index 215df82f7..1b5a7a002 100644 --- a/frame/base/bli_getopt.h +++ b/frame/base/bli_getopt.h @@ -40,7 +40,7 @@ typedef struct getopt_s int optopt; } getopt_t; -void bli_getopt_init_state( int opterr, getopt_t* state ); +BLIS_EXPORT_BLIS void bli_getopt_init_state( int opterr, getopt_t* state ); -int bli_getopt( int argc, char** const argv, const char* optstring, getopt_t* state ); +BLIS_EXPORT_BLIS int bli_getopt( int argc, char** const argv, const char* optstring, getopt_t* state ); diff --git a/frame/base/bli_gks.c b/frame/base/bli_gks.c index 80b239182..624163f13 100644 --- a/frame/base/bli_gks.c +++ b/frame/base/bli_gks.c @@ -41,11 +41,11 @@ static cntx_t** gks[ BLIS_NUM_ARCHS ]; // The array of function pointers holding the registered context initialization // functions for induced methods. -static void* cntx_ind_init[ BLIS_NUM_ARCHS ]; +static void_fp cntx_ind_init[ BLIS_NUM_ARCHS ]; // The array of function pointers holding the registered context initialization // functions for reference kernels. 
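A note on the void*-to-void_fp conversions running through bli_func and bli_gks here: ISO C does not guarantee that a function pointer survives a round trip through void*, whereas conversion between function pointer types is well defined. A minimal self-contained sketch of the idea, assuming a generic function-pointer typedef along the lines of what this commit uses (the actual void_fp definition lives elsewhere in the tree):

#include <stdio.h>

/* Hypothetical stand-in for BLIS's void_fp: a generic function-pointer
   type used purely for storage, e.g. in tables like cntx_ind_init[]. */
typedef void (*void_fp)( void );

static void hello( void ) { printf( "hello\n" ); }

int main( void )
{
    /* Store the address generically, then cast back to the true type
       before calling -- the pattern the func_t accessors rely on. */
    void_fp fp = ( void_fp )hello;

    void (*call)( void ) = ( void (*)( void ) )fp;
    call();

    return 0;
}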
-static void* cntx_ref_init[ BLIS_NUM_ARCHS ]; +static void_fp cntx_ref_init[ BLIS_NUM_ARCHS ]; // Define a function pointer type for context initialization functions. typedef void (*nat_cntx_init_ft)( cntx_t* cntx ); @@ -240,7 +240,7 @@ void bli_gks_init_index( void ) // architecture id elements of the internal arrays to NULL. const size_t gks_size = sizeof( cntx_t* ) * BLIS_NUM_ARCHS; - const size_t fpa_size = sizeof( void* ) * BLIS_NUM_ARCHS; + const size_t fpa_size = sizeof( void_fp ) * BLIS_NUM_ARCHS; // Set every entry in gks and context init function pointer arrays to // zero/NULL. This is done so that later on we know which ones were @@ -297,10 +297,10 @@ cntx_t* bli_gks_lookup_ind_cntx void bli_gks_register_cntx ( - arch_t id, - void* nat_fp, - void* ref_fp, - void* ind_fp + arch_t id, + void_fp nat_fp, + void_fp ref_fp, + void_fp ind_fp ) { // This function is called by bli_gks_init() for each architecture that @@ -590,8 +590,8 @@ bool_t bli_gks_cntx_l3_nat_ukr_is_ref // Query each context for the micro-kernel function pointer for the // specified datatype. - void* ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, &ref_cntx ); - void* fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, cntx ); + void_fp ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, &ref_cntx ); + void_fp fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr_id, cntx ); // Return the result. return fp == ref_fp; @@ -619,7 +619,7 @@ char* bli_gks_l3_ukr_impl_string( l3ukr_t ukr, ind_t method, num_t dt ) // then query the ukernel function pointer for the given datatype from // that context. cntx_t* cntx = bli_gks_query_ind_cntx( method, dt ); - void* fp = bli_cntx_get_l3_vir_ukr_dt( dt, ukr, cntx ); + void_fp fp = bli_cntx_get_l3_vir_ukr_dt( dt, ukr, cntx ); // Check whether the ukernel function pointer is NULL for the given // datatype. If it is NULL, return the string for not applicable. @@ -698,8 +698,8 @@ kimpl_t bli_gks_l3_ukr_impl_type( l3ukr_t ukr, ind_t method, num_t dt ) // Query the native ukernel func_t from both the native and reference // contexts. 
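The reference-vs-optimized classification in bli_gks_cntx_l3_nat_ukr_is_ref() and bli_gks_l3_ukr_impl_type() reduces to a pointer-equality test, as the lines just below show: a kernel is "reference" exactly when the address registered in the active context matches the one in the reference context. A self-contained illustration of that test (the names are demo stand-ins, not BLIS API):

typedef void (*ukr_fp)( void );

typedef enum { DEMO_REFERENCE_UKERNEL, DEMO_OPTIMIZED_UKERNEL } demo_kimpl_t;

/* Mirrors the comparison below: the same address means the context
   still carries the reference kernel; a different address means it
   was overridden with an optimized one. */
static demo_kimpl_t demo_impl_type( ukr_fp active, ukr_fp reference )
{
    return ( active == reference ) ? DEMO_REFERENCE_UKERNEL
                                   : DEMO_OPTIMIZED_UKERNEL;
}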
- void* nat_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, nat_cntx ); - void* ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, &ref_cntx_l ); + void_fp nat_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, nat_cntx ); + void_fp ref_fp = bli_cntx_get_l3_nat_ukr_dt( dt, ukr, &ref_cntx_l ); if ( nat_fp == ref_fp ) return BLIS_REFERENCE_UKERNEL; else return BLIS_OPTIMIZED_UKERNEL; diff --git a/frame/base/bli_gks.h b/frame/base/bli_gks.h index a87a07136..b84dcfc97 100644 --- a/frame/base/bli_gks.h +++ b/frame/base/bli_gks.h @@ -35,10 +35,6 @@ #ifndef BLIS_GKS_H #define BLIS_GKS_H -arch_t bli_arch_query_id( void ); - -// ----------------------------------------------------------------------------- - void bli_gks_init( void ); void bli_gks_finalize( void ); @@ -46,22 +42,23 @@ void bli_gks_init_index( void ); cntx_t* bli_gks_lookup_nat_cntx( arch_t id ); cntx_t* bli_gks_lookup_ind_cntx( arch_t id, ind_t ind ); -void bli_gks_register_cntx( arch_t id, void* nat_fp, void* ref_fp, void* ind_fp ); +void bli_gks_register_cntx( arch_t id, void_fp nat_fp, void_fp ref_fp, void_fp ind_fp ); + +BLIS_EXPORT_BLIS cntx_t* bli_gks_query_cntx( void ); +BLIS_EXPORT_BLIS cntx_t* bli_gks_query_nat_cntx( void ); -cntx_t* bli_gks_query_cntx( void ); -cntx_t* bli_gks_query_nat_cntx( void ); cntx_t* bli_gks_query_cntx_noinit( void ); -cntx_t* bli_gks_query_ind_cntx( ind_t ind, num_t dt ); -void bli_gks_init_ref_cntx( cntx_t* cntx ); +BLIS_EXPORT_BLIS cntx_t* bli_gks_query_ind_cntx( ind_t ind, num_t dt ); -bool_t bli_gks_cntx_l3_nat_ukr_is_ref( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ); +BLIS_EXPORT_BLIS void bli_gks_init_ref_cntx( cntx_t* cntx ); -char* bli_gks_l3_ukr_impl_string( l3ukr_t ukr, ind_t method, num_t dt ); -kimpl_t bli_gks_l3_ukr_impl_type( l3ukr_t ukr, ind_t method, num_t dt ); +bool_t bli_gks_cntx_l3_nat_ukr_is_ref( num_t dt, l3ukr_t ukr_id, cntx_t* cntx ); + +BLIS_EXPORT_BLIS char* bli_gks_l3_ukr_impl_string( l3ukr_t ukr, ind_t method, num_t dt ); +BLIS_EXPORT_BLIS kimpl_t bli_gks_l3_ukr_impl_type( l3ukr_t ukr, ind_t method, num_t dt ); //char* bli_gks_l3_ukr_avail_impl_string( l3ukr_t ukr, num_t dt ); - #endif diff --git a/frame/base/bli_info.h b/frame/base/bli_info.h index 2addc0c55..8f7869e51 100644 --- a/frame/base/bli_info.h +++ b/frame/base/bli_info.h @@ -36,37 +36,37 @@ // -- General library information ---------------------------------------------- -char* bli_info_get_version_str( void ); -char* bli_info_get_int_type_size_str( void ); +BLIS_EXPORT_BLIS char* bli_info_get_version_str( void ); +BLIS_EXPORT_BLIS char* bli_info_get_int_type_size_str( void ); // -- General configuration-related -------------------------------------------- -gint_t bli_info_get_int_type_size( void ); -gint_t bli_info_get_num_fp_types( void ); -gint_t bli_info_get_max_type_size( void ); -gint_t bli_info_get_page_size( void ); -gint_t bli_info_get_simd_num_registers( void ); -gint_t bli_info_get_simd_size( void ); -gint_t bli_info_get_simd_align_size( void ); -gint_t bli_info_get_stack_buf_max_size( void ); -gint_t bli_info_get_stack_buf_align_size( void ); -gint_t bli_info_get_heap_addr_align_size( void ); -gint_t bli_info_get_heap_stride_align_size( void ); -gint_t bli_info_get_pool_addr_align_size( void ); -gint_t bli_info_get_enable_stay_auto_init( void ); -gint_t bli_info_get_enable_blas( void ); -gint_t bli_info_get_enable_cblas( void ); -gint_t bli_info_get_blas_int_type_size( void ); -gint_t bli_info_get_enable_pba_pools( void ); -gint_t bli_info_get_enable_sba_pools( void ); -gint_t bli_info_get_enable_threading( void ); 
-gint_t bli_info_get_enable_openmp( void ); -gint_t bli_info_get_enable_pthreads( void ); -gint_t bli_info_get_thread_part_jrir_slab( void ); -gint_t bli_info_get_thread_part_jrir_rr( void ); -gint_t bli_info_get_enable_memkind( void ); -gint_t bli_info_get_enable_sandbox( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_int_type_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_num_fp_types( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_max_type_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_page_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_simd_num_registers( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_simd_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_simd_align_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_stack_buf_max_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_stack_buf_align_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_heap_addr_align_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_heap_stride_align_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_pool_addr_align_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_stay_auto_init( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_blas( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_cblas( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_blas_int_type_size( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_pba_pools( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_sba_pools( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_threading( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_openmp( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_pthreads( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_thread_part_jrir_slab( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_thread_part_jrir_rr( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_memkind( void ); +BLIS_EXPORT_BLIS gint_t bli_info_get_enable_sandbox( void ); // -- Kernel implementation-related -------------------------------------------- @@ -74,23 +74,23 @@ gint_t bli_info_get_enable_sandbox( void ); // -- Level-3 kernel definitions -- -char* bli_info_get_gemm_ukr_impl_string( ind_t method, num_t dt ); -char* bli_info_get_gemmtrsm_l_ukr_impl_string( ind_t method, num_t dt ); -char* bli_info_get_gemmtrsm_u_ukr_impl_string( ind_t method, num_t dt ); -char* bli_info_get_trsm_l_ukr_impl_string( ind_t method, num_t dt ); -char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_gemm_ukr_impl_string( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_gemmtrsm_l_ukr_impl_string( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_gemmtrsm_u_ukr_impl_string( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_trsm_l_ukr_impl_string( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_trsm_u_ukr_impl_string( ind_t method, num_t dt ); // -- BLIS implementation query (level-3) -------------------------------------- -char* bli_info_get_gemm_impl_string( num_t dt ); -char* bli_info_get_hemm_impl_string( num_t dt ); -char* bli_info_get_herk_impl_string( num_t dt ); -char* bli_info_get_her2k_impl_string( num_t dt ); -char* bli_info_get_symm_impl_string( num_t dt ); -char* bli_info_get_syrk_impl_string( num_t dt ); -char* bli_info_get_syr2k_impl_string( num_t dt ); -char* bli_info_get_trmm_impl_string( num_t dt ); -char* bli_info_get_trmm3_impl_string( num_t dt ); -char* bli_info_get_trsm_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_gemm_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* 
bli_info_get_hemm_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_herk_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_her2k_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_symm_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_syrk_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_syr2k_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_trmm_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_trmm3_impl_string( num_t dt ); +BLIS_EXPORT_BLIS char* bli_info_get_trsm_impl_string( num_t dt ); diff --git a/frame/base/bli_init.h b/frame/base/bli_init.h index b37a8e342..f174ac0f9 100644 --- a/frame/base/bli_init.h +++ b/frame/base/bli_init.h @@ -32,15 +32,15 @@ */ -void bli_init( void ); -void bli_finalize( void ); +BLIS_EXPORT_BLIS void bli_init( void ); +BLIS_EXPORT_BLIS void bli_finalize( void ); -void bli_init_auto( void ); -void bli_finalize_auto( void ); +void bli_init_auto( void ); +void bli_finalize_auto( void ); -void bli_init_apis( void ); -void bli_finalize_apis( void ); +void bli_init_apis( void ); +void bli_finalize_apis( void ); -void bli_init_once( void ); -void bli_finalize_once( void ); +void bli_init_once( void ); +void bli_finalize_once( void ); diff --git a/frame/base/bli_machval.h b/frame/base/bli_machval.h index 07606da3f..25177a250 100644 --- a/frame/base/bli_machval.h +++ b/frame/base/bli_machval.h @@ -39,8 +39,7 @@ // // Prototype object-based interface. // -void bli_machval( machval_t mval, - obj_t* v ); +BLIS_EXPORT_BLIS void bli_machval( machval_t mval, obj_t* v ); // @@ -49,7 +48,7 @@ void bli_machval( machval_t mval, #undef GENTPROTR #define GENTPROTR( ctype_v, ctype_vr, chv, chvr, opname ) \ \ -void PASTEMAC(chv,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(chv,opname) \ ( \ machval_t mval, \ void* v \ diff --git a/frame/base/bli_malloc.h b/frame/base/bli_malloc.h index 226f732f5..2659a81fa 100644 --- a/frame/base/bli_malloc.h +++ b/frame/base/bli_malloc.h @@ -40,16 +40,16 @@ typedef void (*free_ft) ( void* p ); // ----------------------------------------------------------------------------- #if 0 -void* bli_malloc_pool( size_t size ); -void bli_free_pool( void* p ); +BLIS_EXPORT_BLIS void* bli_malloc_pool( size_t size ); +BLIS_EXPORT_BLIS void bli_free_pool( void* p ); #endif void* bli_malloc_intl( size_t size ); void* bli_calloc_intl( size_t size ); void bli_free_intl( void* p ); -void* bli_malloc_user( size_t size ); -void bli_free_user( void* p ); +BLIS_EXPORT_BLIS void* bli_malloc_user( size_t size ); +BLIS_EXPORT_BLIS void bli_free_user( void* p ); // ----------------------------------------------------------------------------- diff --git a/frame/base/bli_obj.c b/frame/base/bli_obj.c index 2a9b3786e..f2b59e180 100644 --- a/frame/base/bli_obj.c +++ b/frame/base/bli_obj.c @@ -405,7 +405,8 @@ void bli_adjust_strides // matrix). if ( m == 0 || n == 0 ) return; - // Interpret rs = cs = 0 as request for column storage. + // Interpret rs = cs = 0 as request for column storage and -1 as a request + // for row storage. if ( *rs == 0 && *cs == 0 && ( *is == 0 || *is == 1 ) ) { // First we handle the 1x1 scalar case explicitly. @@ -414,8 +415,9 @@ void bli_adjust_strides *rs = 1; *cs = 1; } - // We use column-major storage, except when m == 1, because we don't - // want both strides to be unit. + // We use column-major storage, except when m == 1, in which case we + // use what amounts to row-major storage because we don't want both + // strides to be unit. 
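Before the remaining cases continue below, the stride conventions that bli_adjust_strides now recognizes can be condensed into a short sketch (alignment and the 1x1 and single-row/column special cases handled in the surrounding hunks are omitted): rs = cs = 0 requests column-major storage, the new rs = cs = -1 requests row-major storage, and any other user-supplied strides pass through untouched.

/* Sketch only; the real logic above and below also aligns the leading
   dimension and special-cases degenerate shapes. */
static void demo_adjust_strides( int m, int n, int* rs, int* cs )
{
    if      ( *rs ==  0 && *cs ==  0 ) { *rs = 1; *cs = m; } /* column-major */
    else if ( *rs == -1 && *cs == -1 ) { *rs = n; *cs = 1; } /* row-major    */
    /* otherwise: preserve the caller's strides. */
}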
else if ( m == 1 && n > 1 ) { *rs = n; @@ -445,6 +447,46 @@ void bli_adjust_strides BLIS_HEAP_STRIDE_ALIGN_SIZE ); } } + else if ( *rs == -1 && *cs == -1 && ( *is == 0 || *is == 1 ) ) + { + // First we handle the 1x1 scalar case explicitly. + if ( m == 1 && n == 1 ) + { + *rs = 1; + *cs = 1; + } + // We use row-major storage, except when n == 1, in which case we + // use what amounts to column-major storage because we don't want both + // strides to be unit. + else if ( n == 1 && m > 1 ) + { + *rs = 1; + *cs = m; + } + else + { + *rs = n; + *cs = 1; + } + + // Use default complex storage. + *is = 1; + + // Align the strides depending on the tilt of the matrix. Note that + // scalars are neither row nor column tilted. Also note that alignment + // is only done for rs = cs = -1, and any user-supplied row and column + // strides are preserved. + if ( bli_is_col_tilted( m, n, *rs, *cs ) ) + { + *cs = bli_align_dim_to_size( *cs, elem_size, + BLIS_HEAP_STRIDE_ALIGN_SIZE ); + } + else if ( bli_is_row_tilted( m, n, *rs, *cs ) ) + { + *rs = bli_align_dim_to_size( *rs, elem_size, + BLIS_HEAP_STRIDE_ALIGN_SIZE ); + } + } else if ( *rs == 1 && *cs == 1 ) { // If both strides are unit, this is probably a "lazy" request for a diff --git a/frame/base/bli_obj.h b/frame/base/bli_obj.h index 69f1aaa80..4436d2cd8 100644 --- a/frame/base/bli_obj.h +++ b/frame/base/bli_obj.h @@ -34,7 +34,7 @@ #include "bli_obj_check.h" -void bli_obj_create +BLIS_EXPORT_BLIS void bli_obj_create ( num_t dt, dim_t m, @@ -44,7 +44,7 @@ void bli_obj_create obj_t* obj ); -void bli_obj_create_with_attached_buffer +BLIS_EXPORT_BLIS void bli_obj_create_with_attached_buffer ( num_t dt, dim_t m, @@ -55,7 +55,7 @@ void bli_obj_create_with_attached_buffer obj_t* obj ); -void bli_obj_create_without_buffer +BLIS_EXPORT_BLIS void bli_obj_create_without_buffer ( num_t dt, dim_t m, @@ -63,7 +63,7 @@ void bli_obj_create_without_buffer obj_t* obj ); -void bli_obj_alloc_buffer +BLIS_EXPORT_BLIS void bli_obj_alloc_buffer ( inc_t rs, inc_t cs, @@ -71,7 +71,7 @@ void bli_obj_alloc_buffer obj_t* obj ); -void bli_obj_attach_buffer +BLIS_EXPORT_BLIS void bli_obj_attach_buffer ( void* p, inc_t rs, @@ -80,26 +80,26 @@ void bli_obj_attach_buffer obj_t* obj ); -void bli_obj_create_1x1 +BLIS_EXPORT_BLIS void bli_obj_create_1x1 ( num_t dt, obj_t* obj ); -void bli_obj_create_1x1_with_attached_buffer +BLIS_EXPORT_BLIS void bli_obj_create_1x1_with_attached_buffer ( num_t dt, void* p, obj_t* obj ); -void bli_obj_create_conf_to +BLIS_EXPORT_BLIS void bli_obj_create_conf_to ( obj_t* s, obj_t* d ); -void bli_obj_free +BLIS_EXPORT_BLIS void bli_obj_free ( obj_t* obj ); @@ -114,36 +114,36 @@ void bli_adjust_strides inc_t* is ); -siz_t bli_dt_size +BLIS_EXPORT_BLIS siz_t bli_dt_size ( num_t dt ); -char* bli_dt_string +BLIS_EXPORT_BLIS char* bli_dt_string ( num_t dt ); -dim_t bli_align_dim_to_mult +BLIS_EXPORT_BLIS dim_t bli_align_dim_to_mult ( dim_t dim, dim_t dim_mult ); -dim_t bli_align_dim_to_size +BLIS_EXPORT_BLIS dim_t bli_align_dim_to_size ( dim_t dim, siz_t elem_size, siz_t align_size ); -dim_t bli_align_ptr_to_size +BLIS_EXPORT_BLIS dim_t bli_align_ptr_to_size ( void* p, size_t align_size ); -void bli_obj_print +BLIS_EXPORT_BLIS void bli_obj_print ( char* label, obj_t* obj diff --git a/frame/base/bli_obj_scalar.h b/frame/base/bli_obj_scalar.h index ba890d5b7..f655ff46e 100644 --- a/frame/base/bli_obj_scalar.h +++ b/frame/base/bli_obj_scalar.h @@ -32,13 +32,13 @@ */ -void bli_obj_scalar_init_detached +BLIS_EXPORT_BLIS void bli_obj_scalar_init_detached ( num_t 
dt, obj_t* beta ); -void bli_obj_scalar_init_detached_copy_of +BLIS_EXPORT_BLIS void bli_obj_scalar_init_detached_copy_of ( num_t dt, conj_t conj, @@ -46,42 +46,42 @@ void bli_obj_scalar_init_detached_copy_of obj_t* beta ); -void bli_obj_scalar_detach +BLIS_EXPORT_BLIS void bli_obj_scalar_detach ( obj_t* a, obj_t* alpha ); -void bli_obj_scalar_attach +BLIS_EXPORT_BLIS void bli_obj_scalar_attach ( conj_t conj, obj_t* alpha, obj_t* a ); -void bli_obj_scalar_cast_to +BLIS_EXPORT_BLIS void bli_obj_scalar_cast_to ( num_t dt, obj_t* a ); -void bli_obj_scalar_apply_scalar +BLIS_EXPORT_BLIS void bli_obj_scalar_apply_scalar ( obj_t* alpha, obj_t* a ); -void bli_obj_scalar_reset +BLIS_EXPORT_BLIS void bli_obj_scalar_reset ( obj_t* a ); -bool_t bli_obj_scalar_has_nonzero_imag +BLIS_EXPORT_BLIS bool_t bli_obj_scalar_has_nonzero_imag ( obj_t* a ); -bool_t bli_obj_scalar_equals +BLIS_EXPORT_BLIS bool_t bli_obj_scalar_equals ( obj_t* a, obj_t* beta diff --git a/frame/base/bli_param_map.c b/frame/base/bli_param_map.c index de877f686..d20eece43 100644 --- a/frame/base/bli_param_map.c +++ b/frame/base/bli_param_map.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,61 +99,8 @@ void bli_param_map_blis_to_netlib_machval( machval_t machval, char* blas_machval // --- BLAS/LAPACK to BLIS mappings -------------------------------------------- -void bli_param_map_netlib_to_blis_side( char side, side_t* blis_side ) -{ - if ( side == 'l' || side == 'L' ) *blis_side = BLIS_LEFT; - else if ( side == 'r' || side == 'R' ) *blis_side = BLIS_RIGHT; - else - { - // Instead of reporting an error to the framework, default to - // an arbitrary value. This is needed because this function is - // called by the BLAS compatibility layer AFTER it has already - // checked errors and called xerbla(). If the application wants - // to override the BLAS compatibility layer's xerbla--which - // responds to errors with abort()--we need to also NOT call - // abort() here, since either way it has already been dealt - // with. - //bli_check_error_code( BLIS_INVALID_SIDE ); - *blis_side = BLIS_LEFT; - } -} - -void bli_param_map_netlib_to_blis_uplo( char uplo, uplo_t* blis_uplo ) -{ - if ( uplo == 'l' || uplo == 'L' ) *blis_uplo = BLIS_LOWER; - else if ( uplo == 'u' || uplo == 'U' ) *blis_uplo = BLIS_UPPER; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_UPLO ); - *blis_uplo = BLIS_LOWER; - } -} - -void bli_param_map_netlib_to_blis_trans( char trans, trans_t* blis_trans ) -{ - if ( trans == 'n' || trans == 'N' ) *blis_trans = BLIS_NO_TRANSPOSE; - else if ( trans == 't' || trans == 'T' ) *blis_trans = BLIS_TRANSPOSE; - else if ( trans == 'c' || trans == 'C' ) *blis_trans = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - *blis_trans = BLIS_NO_TRANSPOSE; - } -} - -void bli_param_map_netlib_to_blis_diag( char diag, diag_t* blis_diag ) -{ - if ( diag == 'n' || diag == 'N' ) *blis_diag = BLIS_NONUNIT_DIAG; - else if ( diag == 'u' || diag == 'U' ) *blis_diag = BLIS_UNIT_DIAG; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. 
- //bli_check_error_code( BLIS_INVALID_DIAG ); - *blis_diag = BLIS_NONUNIT_DIAG; - } -} +// NOTE: These functions were converted into static functions. Please see this +// file's corresponding header for those definitions. // --- BLIS char to BLIS mappings ---------------------------------------------- diff --git a/frame/base/bli_param_map.h b/frame/base/bli_param_map.h index 829fe808c..ac23684fe 100644 --- a/frame/base/bli_param_map.h +++ b/frame/base/bli_param_map.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,37 +36,91 @@ // --- BLIS to BLAS/LAPACK mappings -------------------------------------------- -void bli_param_map_blis_to_netlib_side( side_t side, char* blas_side ); -void bli_param_map_blis_to_netlib_uplo( uplo_t uplo, char* blas_uplo ); -void bli_param_map_blis_to_netlib_trans( trans_t trans, char* blas_trans ); -void bli_param_map_blis_to_netlib_diag( diag_t diag, char* blas_diag ); -void bli_param_map_blis_to_netlib_machval( machval_t machval, char* blas_machval ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_netlib_side( side_t side, char* blas_side ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_netlib_uplo( uplo_t uplo, char* blas_uplo ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_netlib_trans( trans_t trans, char* blas_trans ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_netlib_diag( diag_t diag, char* blas_diag ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_netlib_machval( machval_t machval, char* blas_machval ); // --- BLAS/LAPACK to BLIS mappings -------------------------------------------- -void bli_param_map_netlib_to_blis_side( char side, side_t* blis_side ); -void bli_param_map_netlib_to_blis_uplo( char uplo, uplo_t* blis_uplo ); -void bli_param_map_netlib_to_blis_trans( char trans, trans_t* blis_trans ); -void bli_param_map_netlib_to_blis_diag( char diag, diag_t* blis_diag ); +// NOTE: These static functions were converted from regular functions in order +// to reduce function call overhead within the BLAS compatibility layer. + +static void bli_param_map_netlib_to_blis_side( char side, side_t* blis_side ) +{ + if ( side == 'l' || side == 'L' ) *blis_side = BLIS_LEFT; + else if ( side == 'r' || side == 'R' ) *blis_side = BLIS_RIGHT; + else + { + // Instead of reporting an error to the framework, default to + // an arbitrary value. This is needed because this function is + // called by the BLAS compatibility layer AFTER it has already + // checked errors and called xerbla(). If the application wants + // to override the BLAS compatibility layer's xerbla--which + // responds to errors with abort()--we need to also NOT call + // abort() here, since either way it has already been dealt + // with. + //bli_check_error_code( BLIS_INVALID_SIDE ); + *blis_side = BLIS_LEFT; + } +} + +static void bli_param_map_netlib_to_blis_uplo( char uplo, uplo_t* blis_uplo ) +{ + if ( uplo == 'l' || uplo == 'L' ) *blis_uplo = BLIS_LOWER; + else if ( uplo == 'u' || uplo == 'U' ) *blis_uplo = BLIS_UPPER; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. 
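The mappers being moved into the header (see the bli_param_map.h hunks below) are invoked by the BLAS compatibility layer only after error checking and xerbla() have already run, which is why an unrecognized character falls back to a default value rather than aborting. A hedged usage sketch, assuming a BLIS build that exposes these static mappers via blis.h:

#include "blis.h"

/* demo_map() is illustrative, not BLIS API: it maps a BLAS 'N'/'T'/'C'
   character to a BLIS trans_t value. */
static trans_t demo_map( char t )
{
    trans_t blis_trans;

    /* An invalid char (say 'x') yields BLIS_NO_TRANSPOSE, not abort(). */
    bli_param_map_netlib_to_blis_trans( t, &blis_trans );

    return blis_trans;
}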
+ //bli_check_error_code( BLIS_INVALID_UPLO ); + *blis_uplo = BLIS_LOWER; + } +} + +static void bli_param_map_netlib_to_blis_trans( char trans, trans_t* blis_trans ) +{ + if ( trans == 'n' || trans == 'N' ) *blis_trans = BLIS_NO_TRANSPOSE; + else if ( trans == 't' || trans == 'T' ) *blis_trans = BLIS_TRANSPOSE; + else if ( trans == 'c' || trans == 'C' ) *blis_trans = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + *blis_trans = BLIS_NO_TRANSPOSE; + } +} + +static void bli_param_map_netlib_to_blis_diag( char diag, diag_t* blis_diag ) +{ + if ( diag == 'n' || diag == 'N' ) *blis_diag = BLIS_NONUNIT_DIAG; + else if ( diag == 'u' || diag == 'U' ) *blis_diag = BLIS_UNIT_DIAG; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_DIAG ); + *blis_diag = BLIS_NONUNIT_DIAG; + } +} // --- BLIS char to BLIS mappings ---------------------------------------------- -void bli_param_map_char_to_blis_side( char side, side_t* blis_side ); -void bli_param_map_char_to_blis_uplo( char uplo, uplo_t* blis_uplo ); -void bli_param_map_char_to_blis_trans( char trans, trans_t* blis_trans ); -void bli_param_map_char_to_blis_conj( char conj, conj_t* blis_conj ); -void bli_param_map_char_to_blis_diag( char diag, diag_t* blis_diag ); -void bli_param_map_char_to_blis_dt( char dt, num_t* blis_dt ); +BLIS_EXPORT_BLIS void bli_param_map_char_to_blis_side( char side, side_t* blis_side ); +BLIS_EXPORT_BLIS void bli_param_map_char_to_blis_uplo( char uplo, uplo_t* blis_uplo ); +BLIS_EXPORT_BLIS void bli_param_map_char_to_blis_trans( char trans, trans_t* blis_trans ); +BLIS_EXPORT_BLIS void bli_param_map_char_to_blis_conj( char conj, conj_t* blis_conj ); +BLIS_EXPORT_BLIS void bli_param_map_char_to_blis_diag( char diag, diag_t* blis_diag ); +BLIS_EXPORT_BLIS void bli_param_map_char_to_blis_dt( char dt, num_t* blis_dt ); // --- BLIS to BLIS char mappings ---------------------------------------------- -void bli_param_map_blis_to_char_side( side_t blis_side, char* side ); -void bli_param_map_blis_to_char_uplo( uplo_t blis_uplo, char* uplo ); -void bli_param_map_blis_to_char_trans( trans_t blis_trans, char* trans ); -void bli_param_map_blis_to_char_conj( conj_t blis_conj, char* conj ); -void bli_param_map_blis_to_char_diag( diag_t blis_diag, char* diag ); -void bli_param_map_blis_to_char_dt( num_t blis_dt, char* dt ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_char_side( side_t blis_side, char* side ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_char_uplo( uplo_t blis_uplo, char* uplo ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_char_trans( trans_t blis_trans, char* trans ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_char_conj( conj_t blis_conj, char* conj ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_char_diag( diag_t blis_diag, char* diag ); +BLIS_EXPORT_BLIS void bli_param_map_blis_to_char_dt( num_t blis_dt, char* dt ); diff --git a/frame/base/bli_part.h b/frame/base/bli_part.h index 81232bccd..5e56a9fec 100644 --- a/frame/base/bli_part.h +++ b/frame/base/bli_part.h @@ -36,7 +36,7 @@ // -- Matrix partitioning ------------------------------------------------------ -void bli_acquire_mpart +BLIS_EXPORT_BLIS void bli_acquire_mpart ( dim_t i, dim_t j, @@ -49,7 +49,7 @@ void bli_acquire_mpart #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0( opname ) \ +BLIS_EXPORT_BLIS void PASTEMAC0( opname ) \ ( \ subpart_t req_part, \ dim_t i, \ @@ -69,7 +69,7 @@ GENPROT( 
acquire_mpart_br2tl ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0( opname ) \ +BLIS_EXPORT_BLIS void PASTEMAC0( opname ) \ ( \ dir_t direct, \ subpart_t req_part, \ @@ -89,7 +89,7 @@ GENPROT( acquire_mpart_mndim ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC0( opname ) \ +BLIS_EXPORT_BLIS void PASTEMAC0( opname ) \ ( \ subpart_t req_part, \ dim_t i, \ @@ -103,7 +103,7 @@ GENPROT( acquire_vpart_b2f ) // -- Scalar acquisition ------------------------------------------------------- -void bli_acquire_mij +BLIS_EXPORT_BLIS void bli_acquire_mij ( dim_t i, dim_t j, @@ -111,7 +111,7 @@ void bli_acquire_mij obj_t* sub_obj ); -void bli_acquire_vi +BLIS_EXPORT_BLIS void bli_acquire_vi ( dim_t i, obj_t* obj, diff --git a/frame/base/bli_query.h b/frame/base/bli_query.h index 94274f1cd..2bb5b3f6b 100644 --- a/frame/base/bli_query.h +++ b/frame/base/bli_query.h @@ -32,10 +32,8 @@ */ -bool_t bli_obj_equals( obj_t* a, - obj_t* b ); +BLIS_EXPORT_BLIS bool_t bli_obj_equals( obj_t* a, obj_t* b ); -bool_t bli_obj_imag_equals( obj_t* a, - obj_t* b ); +BLIS_EXPORT_BLIS bool_t bli_obj_imag_equals( obj_t* a, obj_t* b ); -bool_t bli_obj_imag_is_zero( obj_t* a ); +BLIS_EXPORT_BLIS bool_t bli_obj_imag_is_zero( obj_t* a ); diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index 4e8e74af8..c07686414 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -45,6 +45,13 @@ typedef struct rntm_s { dim_t num_threads; dim_t* thrloop; + + pool_t* sba_pool; + + membrk_t* membrk; + + bool_t l3_sup; + } rntm_t; */ @@ -87,6 +94,11 @@ static dim_t bli_rntm_pr_ways( rntm_t* rntm ) return bli_rntm_ways_for( BLIS_KR, rntm ); } +static bool_t bli_rntm_l3_sup( rntm_t* rntm ) +{ + return rntm->l3_sup; +} + // // -- rntm_t query (internal use only) ----------------------------------------- // @@ -175,6 +187,22 @@ static void bli_rntm_set_membrk( membrk_t* membrk, rntm_t* rntm ) rntm->membrk = membrk; } +static void bli_rntm_set_l3_sup( bool_t l3_sup, rntm_t* rntm ) +{ + // Set the bool_t indicating whether level-3 sup handling is enabled. + rntm->l3_sup = l3_sup; +} + +static void bli_rntm_enable_l3_sup( rntm_t* rntm ) +{ + bli_rntm_set_l3_sup( TRUE, rntm ); +} + +static void bli_rntm_disable_l3_sup( rntm_t* rntm ) +{ + bli_rntm_set_l3_sup( FALSE, rntm ); +} + static void bli_rntm_clear_num_threads_only( rntm_t* rntm ) { bli_rntm_set_num_threads_only( -1, rntm ); @@ -187,6 +215,14 @@ static void bli_rntm_clear_sba_pool( rntm_t* rntm ) { bli_rntm_set_sba_pool( NULL, rntm ); } +static void bli_rntm_clear_membrk( rntm_t* rntm ) +{ + bli_rntm_set_membrk( NULL, rntm ); +} +static void bli_rntm_clear_l3_sup( rntm_t* rntm ) +{ + bli_rntm_set_l3_sup( 1, rntm ); +} // // -- rntm_t modification (public API) ----------------------------------------- @@ -223,9 +259,14 @@ static void bli_rntm_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir, // of the public "set" accessors, each of which guarantees that the rntm_t // will be in a good state upon return. 
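The new l3_sup field gives callers a per-call switch over the small/skinny/unpacked (sup) handling. A sketch of how it would travel with a rntm_t through the expert API, assuming a build that carries this field (the initializer updated just below defaults it to enabled):

#include "blis.h"

int main( void )
{
    rntm_t rntm = BLIS_RNTM_INITIALIZER;

    /* Force the conventional large-matrix path for calls made with
       this rntm_t; by default l3_sup is 1 (enabled). */
    bli_rntm_disable_l3_sup( &rntm );

    /* ... set up obj_t operands, then pass &rntm to an expert-level
       call such as bli_gemm_ex( &alpha, &a, &b, &beta, &c, NULL, &rntm ); */

    return 0;
}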
-#define BLIS_RNTM_INITIALIZER { .num_threads = -1, \ - .thrloop = { -1, -1, -1, -1, -1, -1 }, \ - .sba_pool = NULL } \ +#define BLIS_RNTM_INITIALIZER \ + { \ + .num_threads = -1, \ + .thrloop = { -1, -1, -1, -1, -1, -1 }, \ + .sba_pool = NULL, \ + .membrk = NULL, \ + .l3_sup = 1 \ + } \ static void bli_rntm_init( rntm_t* rntm ) { @@ -233,13 +274,16 @@ static void bli_rntm_init( rntm_t* rntm ) bli_rntm_clear_ways_only( rntm ); bli_rntm_clear_sba_pool( rntm ); + bli_rntm_clear_membrk( rntm ); + + bli_rntm_clear_l3_sup( rntm ); } // ----------------------------------------------------------------------------- // Function prototypes -void bli_rntm_set_ways_for_op +BLIS_EXPORT_BLIS void bli_rntm_set_ways_for_op ( opid_t l3_op, side_t side, diff --git a/frame/base/bli_setgetij.h b/frame/base/bli_setgetij.h index 9478bd76d..55ce0ee11 100644 --- a/frame/base/bli_setgetij.h +++ b/frame/base/bli_setgetij.h @@ -32,7 +32,7 @@ */ -err_t bli_setijm +BLIS_EXPORT_BLIS err_t bli_setijm ( double ar, double ai, @@ -44,7 +44,7 @@ err_t bli_setijm #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ double ar, \ double ai, \ @@ -57,7 +57,7 @@ INSERT_GENTPROT_BASIC0( setijm ) // ----------------------------------------------------------------------------- -err_t bli_getijm +BLIS_EXPORT_BLIS err_t bli_getijm ( dim_t i, dim_t j, @@ -69,7 +69,7 @@ err_t bli_getijm #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ dim_t i, \ dim_t j, \ diff --git a/frame/base/bli_setri.h b/frame/base/bli_setri.h index 1e124f419..dd6ce9f3f 100644 --- a/frame/base/bli_setri.h +++ b/frame/base/bli_setri.h @@ -34,13 +34,13 @@ // -- setr --------------------------------------------------------------------- -void bli_setrm +BLIS_EXPORT_BLIS void bli_setrm ( obj_t* alpha, obj_t* b ); -void bli_setrv +BLIS_EXPORT_BLIS void bli_setrv ( obj_t* alpha, obj_t* x @@ -48,13 +48,13 @@ void bli_setrv // -- seti --------------------------------------------------------------------- -void bli_setim +BLIS_EXPORT_BLIS void bli_setim ( obj_t* alpha, obj_t* b ); -void bli_setiv +BLIS_EXPORT_BLIS void bli_setiv ( obj_t* alpha, obj_t* x diff --git a/frame/base/bli_winsys.h b/frame/base/bli_winsys.h index 0ad7c408c..0c71114ad 100644 --- a/frame/base/bli_winsys.h +++ b/frame/base/bli_winsys.h @@ -33,5 +33,5 @@ */ //int bli_setenv( const char *name, const char *value, int overwrite ); -void bli_sleep( unsigned int secs ); +BLIS_EXPORT_BLIS void bli_sleep( unsigned int secs ); diff --git a/frame/base/cast/bli_castm.h b/frame/base/cast/bli_castm.h index 5ab13544b..e9e1dee21 100644 --- a/frame/base/cast/bli_castm.h +++ b/frame/base/cast/bli_castm.h @@ -36,7 +36,7 @@ // Prototype object-based interface. // -void bli_castm +BLIS_EXPORT_BLIS void bli_castm ( obj_t* a, obj_t* b @@ -49,7 +49,7 @@ void bli_castm #undef GENTPROT2 #define GENTPROT2( ctype_a, ctype_b, cha, chb, opname ) \ \ -void PASTEMAC2(cha,chb,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(cha,chb,opname) \ ( \ trans_t transa, \ dim_t m, \ diff --git a/frame/base/cast/bli_castnzm.h b/frame/base/cast/bli_castnzm.h index e4e1b1cad..42cfef8c0 100644 --- a/frame/base/cast/bli_castnzm.h +++ b/frame/base/cast/bli_castnzm.h @@ -36,7 +36,7 @@ // Prototype object-based interface. 
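All of the BLIS_EXPORT_BLIS and BLIS_EXPORT_BLAS annotations threaded through these headers re-expose symbols that would otherwise be hidden from a shared library's interface when the library is compiled with hidden default symbol visibility. A sketch of the usual expansion of such a macro; this is an assumption about the general pattern, not BLIS's exact definition, which lives in its configuration headers:

#if defined(_MSC_VER)
  #define DEMO_EXPORT __declspec(dllexport)
#elif defined(__GNUC__)
  #define DEMO_EXPORT __attribute__((visibility("default")))
#else
  #define DEMO_EXPORT
#endif

/* With -fvisibility=hidden in effect, only symbols annotated this way
   remain visible in the resulting shared object. */
DEMO_EXPORT void demo_public_api( void );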
// -void bli_castnzm +BLIS_EXPORT_BLIS void bli_castnzm ( obj_t* a, obj_t* b @@ -49,7 +49,7 @@ void bli_castnzm #undef GENTPROT2 #define GENTPROT2( ctype_a, ctype_b, cha, chb, opname ) \ \ -void PASTEMAC2(cha,chb,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(cha,chb,opname) \ ( \ trans_t transa, \ dim_t m, \ diff --git a/frame/base/cast/bli_castv.h b/frame/base/cast/bli_castv.h index eeb376a89..9a8261514 100644 --- a/frame/base/cast/bli_castv.h +++ b/frame/base/cast/bli_castv.h @@ -36,7 +36,7 @@ // Prototype object-based interface. // -void bli_castv +BLIS_EXPORT_BLIS void bli_castv ( obj_t* x, obj_t* y @@ -49,7 +49,7 @@ void bli_castv #undef GENTPROT2 #define GENTPROT2( ctype_x, ctype_y, chx, chy, opname ) \ \ -void PASTEMAC2(chx,chy,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC2(chx,chy,opname) \ ( \ conj_t conjx, \ dim_t n, \ diff --git a/frame/base/proj/bli_projm.h b/frame/base/proj/bli_projm.h index b34e63dac..e95f7f2f5 100644 --- a/frame/base/proj/bli_projm.h +++ b/frame/base/proj/bli_projm.h @@ -32,7 +32,7 @@ */ -void bli_projm +BLIS_EXPORT_BLIS void bli_projm ( obj_t* a, obj_t* b diff --git a/frame/base/proj/bli_projv.h b/frame/base/proj/bli_projv.h index 7c33d834f..b738b2f97 100644 --- a/frame/base/proj/bli_projv.h +++ b/frame/base/proj/bli_projv.h @@ -32,7 +32,7 @@ */ -void bli_projv +BLIS_EXPORT_BLIS void bli_projv ( obj_t* x, obj_t* y diff --git a/frame/compat/attic/bla_gbmv.h b/frame/compat/attic/bla_gbmv.h index 69b8ea9c5..2e60bbc7c 100644 --- a/frame/compat/attic/bla_gbmv.h +++ b/frame/compat/attic/bla_gbmv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* transa, \ f77_int* m, \ f77_int* n, \ diff --git a/frame/compat/attic/bla_hbmv.h b/frame/compat/attic/bla_hbmv.h index a6362a4ba..89054809e 100644 --- a/frame/compat/attic/bla_hbmv.h +++ b/frame/compat/attic/bla_hbmv.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ f77_int* k, \ diff --git a/frame/compat/attic/bla_hpmv.h b/frame/compat/attic/bla_hpmv.h index fe5b2238a..c58a5ebbf 100644 --- a/frame/compat/attic/bla_hpmv.h +++ b/frame/compat/attic/bla_hpmv.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ ftype* alpha, \ diff --git a/frame/compat/attic/bla_hpr.h b/frame/compat/attic/bla_hpr.h index 264cf60fe..b32c939a7 100644 --- a/frame/compat/attic/bla_hpr.h +++ b/frame/compat/attic/bla_hpr.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ ftype_r* alpha, \ diff --git a/frame/compat/attic/bla_hpr2.h b/frame/compat/attic/bla_hpr2.h index c288656e0..e62179a5a 100644 --- a/frame/compat/attic/bla_hpr2.h +++ b/frame/compat/attic/bla_hpr2.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ ftype* alpha, \ diff --git a/frame/compat/attic/bla_rot.h b/frame/compat/attic/bla_rot.h index 964d7001e..1713ccae2 100644 --- a/frame/compat/attic/bla_rot.h +++ b/frame/compat/attic/bla_rot.h @@ -39,7 
+39,7 @@ #undef GENTPROTR2 #define GENTPROTR2( ftype_xy, ftype_r, chxy, chr, blasname ) \ \ -void PASTEF772(chxy,chr,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF772(chxy,chr,blasname)( \ f77_int* n, \ ftype_xy* x, f77_int* incx, \ ftype_xy* y, f77_int* incy, \ diff --git a/frame/compat/attic/bla_rotg.h b/frame/compat/attic/bla_rotg.h index 3104aa5dc..9da266113 100644 --- a/frame/compat/attic/bla_rotg.h +++ b/frame/compat/attic/bla_rotg.h @@ -39,7 +39,7 @@ #undef GENTPROTR #define GENTPROTR( ftype_xy, ftype_r, chxy, chr, blasname ) \ \ -void PASTEF77(chxy,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(chxy,blasname)( \ ftype_xy* x, \ ftype_xy* y, \ ftype_r* c, \ diff --git a/frame/compat/attic/bla_rotm.h b/frame/compat/attic/bla_rotm.h index 77ef4a040..73dc6bec9 100644 --- a/frame/compat/attic/bla_rotm.h +++ b/frame/compat/attic/bla_rotm.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_int* n, \ ftype* x, f77_int* incx, \ ftype* y, f77_int* incy, \ diff --git a/frame/compat/attic/bla_rotmg.h b/frame/compat/attic/bla_rotmg.h index b18c867d7..dba9f6f08 100644 --- a/frame/compat/attic/bla_rotmg.h +++ b/frame/compat/attic/bla_rotmg.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ ftype* d1, \ ftype* d2, \ ftype* x, \ diff --git a/frame/compat/attic/bla_sbmv.h b/frame/compat/attic/bla_sbmv.h index 8e68d701e..e96e88975 100644 --- a/frame/compat/attic/bla_sbmv.h +++ b/frame/compat/attic/bla_sbmv.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ f77_int* k, \ diff --git a/frame/compat/attic/bla_spmv.h b/frame/compat/attic/bla_spmv.h index fb3ce55cf..60c787eb9 100644 --- a/frame/compat/attic/bla_spmv.h +++ b/frame/compat/attic/bla_spmv.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ ftype* alpha, \ diff --git a/frame/compat/attic/bla_spr.h b/frame/compat/attic/bla_spr.h index 097931e82..59407b229 100644 --- a/frame/compat/attic/bla_spr.h +++ b/frame/compat/attic/bla_spr.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ ftype* alpha, \ diff --git a/frame/compat/attic/bla_spr2.h b/frame/compat/attic/bla_spr2.h index 8864523ac..911b4301c 100644 --- a/frame/compat/attic/bla_spr2.h +++ b/frame/compat/attic/bla_spr2.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_int* m, \ ftype* alpha, \ diff --git a/frame/compat/attic/bla_tbmv.h b/frame/compat/attic/bla_tbmv.h index d28b58d12..7a343c362 100644 --- a/frame/compat/attic/bla_tbmv.h +++ b/frame/compat/attic/bla_tbmv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_char* transa, \ f77_char* diaga, \ diff --git a/frame/compat/attic/bla_tbsv.h b/frame/compat/attic/bla_tbsv.h index a41a15bf1..0837352ae 100644 --- 
a/frame/compat/attic/bla_tbsv.h +++ b/frame/compat/attic/bla_tbsv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_char* transa, \ f77_char* diaga, \ diff --git a/frame/compat/attic/bla_tpmv.h b/frame/compat/attic/bla_tpmv.h index 7689befdc..37cd494a5 100644 --- a/frame/compat/attic/bla_tpmv.h +++ b/frame/compat/attic/bla_tpmv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_char* transa, \ f77_char* diaga, \ diff --git a/frame/compat/attic/bla_tpsv.h b/frame/compat/attic/bla_tpsv.h index 526769062..179fd607d 100644 --- a/frame/compat/attic/bla_tpsv.h +++ b/frame/compat/attic/bla_tpsv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname)( \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname)( \ f77_char* uploa, \ f77_char* transa, \ f77_char* diaga, \ diff --git a/frame/compat/bla_amax.h b/frame/compat/bla_amax.h index f6e3dd0f6..1f13715dc 100644 --- a/frame/compat/bla_amax.h +++ b/frame/compat/bla_amax.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype_x, chx, blasname ) \ \ -f77_int PASTEF772(i,chx,blasname) \ +BLIS_EXPORT_BLAS f77_int PASTEF772(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_asum.h b/frame/compat/bla_asum.h index 036cd1529..a9ef27a03 100644 --- a/frame/compat/bla_asum.h +++ b/frame/compat/bla_asum.h @@ -39,7 +39,7 @@ #undef GENTPROTR2 #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ -ftype_r PASTEF772(chr,chx,blasname) \ +BLIS_EXPORT_BLAS ftype_r PASTEF772(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_axpy.h b/frame/compat/bla_axpy.h index 3c014f36f..294a385c7 100644 --- a/frame/compat/bla_axpy.h +++ b/frame/compat/bla_axpy.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ diff --git a/frame/compat/bla_copy.h b/frame/compat/bla_copy.h index 5f95afff2..679017b19 100644 --- a/frame/compat/bla_copy.h +++ b/frame/compat/bla_copy.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ diff --git a/frame/compat/bla_dot.h b/frame/compat/bla_dot.h index 01eb532e6..373e1a7b7 100644 --- a/frame/compat/bla_dot.h +++ b/frame/compat/bla_dot.h @@ -39,7 +39,7 @@ #undef GENTPROTDOT #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ -ftype PASTEF772(ch,blasname,chc) \ +BLIS_EXPORT_BLAS ftype PASTEF772(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -52,7 +52,7 @@ INSERT_GENTPROTDOT_BLAS( dot ) // -- "Black sheep" dot product function prototypes -- -float PASTEF77(sd,sdot) +BLIS_EXPORT_BLAS float PASTEF77(sd,sdot) ( const f77_int* n, const float* sb, @@ -60,7 +60,7 @@ float PASTEF77(sd,sdot) const float* y, const f77_int* incy ); -double PASTEF77(d,sdot) +BLIS_EXPORT_BLAS double PASTEF77(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 1effececa..e04e48cf5 100644 --- a/frame/compat/bla_gemm.c +++ 
b/frame/compat/bla_gemm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -84,7 +88,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ \ - /* Convert/typecast negative values of m, n, and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ bli_convert_blas_dim1( *k, k0 ); \ @@ -118,6 +122,105 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. 
*/ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemm, gemm ) #endif diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index 18a101da1..77111dbd8 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ diff --git a/frame/compat/bla_gemv.h b/frame/compat/bla_gemv.h index da4561606..22c8bf1c0 100644 --- a/frame/compat/bla_gemv.h +++ b/frame/compat/bla_gemv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ diff --git a/frame/compat/bla_ger.h b/frame/compat/bla_ger.h index d37511c0a..a31548f61 100644 --- a/frame/compat/bla_ger.h +++ b/frame/compat/bla_ger.h @@ -39,7 +39,7 @@ #undef GENTPROTDOT #define GENTPROTDOT( ftype, chxy, chc, blasname ) \ \ -void PASTEF772(chxy,blasname,chc) \ +BLIS_EXPORT_BLAS void PASTEF772(chxy,blasname,chc) \ ( \ const f77_int* m, \ const f77_int* n, \ diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index 88e9c8b55..6bfb13e18 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ @@ -82,7 +86,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ @@ -116,6 +120,110 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + m, \ + n, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. 
*/ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const conj_t conja = BLIS_NO_CONJUGATE; \ + const trans_t transb = BLIS_NO_TRANSPOSE; \ + const struc_t struca = BLIS_HERMITIAN; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ + bli_set_dims_with_trans( transb, m0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_conj( conja, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( hemm, hemm ) #endif diff --git a/frame/compat/bla_hemm.h b/frame/compat/bla_hemm.h index 712fc611a..711877ede 100644 --- a/frame/compat/bla_hemm.h +++ b/frame/compat/bla_hemm.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_hemv.h b/frame/compat/bla_hemv.h index 01c25e324..4e8230114 100644 --- a/frame/compat/bla_hemv.h +++ b/frame/compat/bla_hemv.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_her.h b/frame/compat/bla_her.h index f93f49ac0..b9ae30d90 100644 --- a/frame/compat/bla_her.h +++ b/frame/compat/bla_her.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_her2.h b/frame/compat/bla_her2.h index c96374130..7cf0bb867 100644 --- a/frame/compat/bla_her2.h +++ b/frame/compat/bla_her2.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_her2k.c b/frame/compat/bla_her2k.c index 0bbe98e1c..df5121975 100644 --- a/frame/compat/bla_her2k.c +++ b/frame/compat/bla_her2k.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. 
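The non-TAPI branches added to bla_gemm.c and bla_hemm.c above, and to bla_her2k.c and bla_herk.c below, all follow one pattern: wrap the caller's column-major arrays in obj_t headers without copying, then invoke the object-based expert interface directly. A condensed sketch of that pattern for double-precision gemm, using the BLIS calls visible in this diff (demo_gemm_wrap itself is hypothetical; parameter checking, transposes, and the alpha/beta objects are elided by reusing BLIS_ONE):

#include "blis.h"

static void demo_gemm_wrap( dim_t m, dim_t n, dim_t k,
                            double* a, inc_t lda,
                            double* b, inc_t ldb,
                            double* c, inc_t ldc )
{
    obj_t ao = BLIS_OBJECT_INITIALIZER;
    obj_t bo = BLIS_OBJECT_INITIALIZER;
    obj_t co = BLIS_OBJECT_INITIALIZER;

    /* BLAS arrays are column-major: unit row stride, cs = leading dim. */
    bli_obj_init_finish( BLIS_DOUBLE, m, k, a, 1, lda, &ao );
    bli_obj_init_finish( BLIS_DOUBLE, k, n, b, 1, ldb, &bo );
    bli_obj_init_finish( BLIS_DOUBLE, m, n, c, 1, ldc, &co );

    /* C := 1.0*A*B + 1.0*C via the expert object API. */
    bli_gemm_ex( &BLIS_ONE, &ao, &bo, &BLIS_ONE, &co, NULL, NULL );
}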
// + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ @@ -82,7 +86,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ @@ -132,6 +136,126 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* We emulate the BLAS early return behavior with the following + conditional, which returns if one of the following is true: + - matrix C is empty + - the rank-2k product is empty (either because alpha is zero or k + is zero) AND matrix C is not scaled. */ \ + if ( m0 == 0 || \ + ( ( PASTEMAC(ch,eq0)( *alpha ) || k0 == 0 ) \ + && PASTEMAC(chr,eq1)( *beta ) \ + ) \ + ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const trans_t transb = blis_transa; \ + const struc_t strucc = BLIS_HERMITIAN; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( transb, m0, k0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype* )alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, (ftype_r*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. 
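(As in every wrapper in this compatibility layer, bli_init_auto() at entry and bli_finalize_auto() at exit let BLIS initialize itself lazily on first use, so BLAS callers need no explicit setup call.)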
*/ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( her2k, her2k ) #endif diff --git a/frame/compat/bla_her2k.h b/frame/compat/bla_her2k.h index e04b11755..c771f78d4 100644 --- a/frame/compat/bla_her2k.h +++ b/frame/compat/bla_her2k.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_herk.c b/frame/compat/bla_herk.c index 88185de0b..d9c47f5af 100644 --- a/frame/compat/bla_herk.c +++ b/frame/compat/bla_herk.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ @@ -79,7 +83,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ @@ -125,6 +129,115 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNCCO +#define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype_r* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* We emulate the BLAS early return behavior with the following + conditional, which returns if one of the following is true: + - matrix C is empty + - the rank-k product is empty (either because alpha is zero or k + is zero) AND matrix C is not scaled. */ \ + if ( m0 == 0 || \ + ( ( PASTEMAC(chr,eq0)( *alpha ) || k0 == 0 ) \ + && PASTEMAC(chr,eq1)( *beta ) \ + ) \ + ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* Set the row and column strides of the matrix operands. 
*/ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt_r = PASTEMAC(chr,type); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t strucc = BLIS_HERMITIAN; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ +\ + bli_obj_init_finish_1x1( dt_r, (ftype_r*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt_r, (ftype_r*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( herk, herk ) #endif diff --git a/frame/compat/bla_herk.h b/frame/compat/bla_herk.h index 6b3ebd38c..e649a74ab 100644 --- a/frame/compat/bla_herk.h +++ b/frame/compat/bla_herk.h @@ -39,7 +39,7 @@ #undef GENTPROTCO #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_nrm2.h b/frame/compat/bla_nrm2.h index af18d6ec3..a8bc25ef4 100644 --- a/frame/compat/bla_nrm2.h +++ b/frame/compat/bla_nrm2.h @@ -39,7 +39,7 @@ #undef GENTPROTR2 #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ -ftype_r PASTEF772(chr,chx,blasname) \ +BLIS_EXPORT_BLAS ftype_r PASTEF772(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_scal.h b/frame/compat/bla_scal.h index a615ea13f..c8e898b6b 100644 --- a/frame/compat/bla_scal.h +++ b/frame/compat/bla_scal.h @@ -39,7 +39,7 @@ #undef GENTPROTSCAL #define GENTPROTSCAL( ftype_a, ftype_x, cha, chx, blasname ) \ \ -void PASTEF772(chx,cha,blasname) \ +BLIS_EXPORT_BLAS void PASTEF772(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ diff --git a/frame/compat/bla_swap.h b/frame/compat/bla_swap.h index 4943a6504..54c0613a9 100644 --- a/frame/compat/bla_swap.h +++ b/frame/compat/bla_swap.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ diff --git a/frame/compat/bla_symm.c b/frame/compat/bla_symm.c index 02d3a3b27..b4f0b66d0 100644 --- a/frame/compat/bla_symm.c +++ b/frame/compat/bla_symm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -82,7 +86,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ \ - /* Convert/typecast negative values of m and n to zero. 
*/ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ @@ -116,6 +120,110 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + m, \ + n, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const conj_t conja = BLIS_NO_CONJUGATE; \ + const trans_t transb = BLIS_NO_TRANSPOSE; \ + const struc_t struca = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ + bli_set_dims_with_trans( transb, m0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_conj( conja, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. 
*/ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( symm, symm ) #endif diff --git a/frame/compat/bla_symm.h b/frame/compat/bla_symm.h index 6bfdec35e..b186e4b43 100644 --- a/frame/compat/bla_symm.h +++ b/frame/compat/bla_symm.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_symv.h b/frame/compat/bla_symv.h index d003f1124..9d1662fad 100644 --- a/frame/compat/bla_symv.h +++ b/frame/compat/bla_symv.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr.h b/frame/compat/bla_syr.h index 4a1d79d3e..0d2a1e031 100644 --- a/frame/compat/bla_syr.h +++ b/frame/compat/bla_syr.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr2.h b/frame/compat/bla_syr2.h index 06e4c2d91..b45876794 100644 --- a/frame/compat/bla_syr2.h +++ b/frame/compat/bla_syr2.h @@ -39,7 +39,7 @@ #undef GENTPROTRO #define GENTPROTRO( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr2k.c b/frame/compat/bla_syr2k.c index 7e611b1d6..35cfca9a3 100644 --- a/frame/compat/bla_syr2k.c +++ b/frame/compat/bla_syr2k.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -91,7 +95,7 @@ void PASTEF77(ch,blasname) \ blis_transa = BLIS_TRANSPOSE; \ } \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ @@ -124,6 +128,117 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* The real domain ssyr2k and dsyr2k in netlib BLAS treat a trans value + of 'C' (conjugate-transpose) as 'T' (transpose only). 
So, we have + to go out of our way a little to support this behavior. */ \ + if ( bli_is_real( PASTEMAC(ch,type) ) && \ + bli_is_conjtrans( blis_transa ) ) \ + { \ + blis_transa = BLIS_TRANSPOSE; \ + } \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const trans_t transb = blis_transa; \ + const struc_t strucc = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( transb, m0, k0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( syr2k, syr2k ) #endif diff --git a/frame/compat/bla_syr2k.h b/frame/compat/bla_syr2k.h index f1eb8e127..91d9a3acf 100644 --- a/frame/compat/bla_syr2k.h +++ b/frame/compat/bla_syr2k.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_syrk.c b/frame/compat/bla_syrk.c index 9c08dd06b..82ce2f166 100644 --- a/frame/compat/bla_syrk.c +++ b/frame/compat/bla_syrk.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -88,7 +92,7 @@ void PASTEF77(ch,blasname) \ blis_transa = BLIS_TRANSPOSE; \ } \ \ - /* Convert/typecast negative values of m and k to zero. */ \ + /* Typecast BLAS integers to BLIS integers. 
*/ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *k, k0 ); \ \ @@ -117,6 +121,106 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc; \ + trans_t blis_transa; \ + dim_t m0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + m, \ + k, \ + lda, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* The real domain ssyrk and dsyrk in netlib BLAS treat a trans value + of 'C' (conjugate-transpose) as 'T' (transpose only). So, we have + to go out of our way a little to support this behavior. */ \ + if ( bli_is_real( PASTEMAC(ch,type) ) && \ + bli_is_conjtrans( blis_transa ) ) \ + { \ + blis_transa = BLIS_TRANSPOSE; \ + } \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t strucc = BLIS_SYMMETRIC; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, m0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_uplo( blis_uploc, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( strucc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( syrk, syrk ) #endif diff --git a/frame/compat/bla_syrk.h b/frame/compat/bla_syrk.h index 9b2e49c5a..b6ca938a6 100644 --- a/frame/compat/bla_syrk.h +++ b/frame/compat/bla_syrk.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_trmm.c b/frame/compat/bla_trmm.c index 116d2b8c4..ce099dc59 100644 --- a/frame/compat/bla_trmm.c +++ b/frame/compat/bla_trmm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. 
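As with the preceding level-3 wrappers, the trmm hunk below is split on BLIS_BLAS3_CALLS_TAPI: one branch forwards to the typed API, while the new #else branch wraps the raw BLAS pointers in obj_t descriptors and calls the expert object API directly. The following condensed sketch (illustrative names, not the literal macro expansion) shows the shape of that object path for a double-precision, left-side, lower-triangular trmm:

    /* Sketch only; assumes blis.h is available and A is m-by-m. */
    #include "blis.h"

    void trmm_obj_path_d( dim_t m, dim_t n, double* alpha,
                          double* a, inc_t lda, double* b, inc_t ldb )
    {
        obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
        obj_t ao     = BLIS_OBJECT_INITIALIZER;
        obj_t bo     = BLIS_OBJECT_INITIALIZER;

        /* Unit row strides; leading dimensions become column strides. */
        bli_obj_init_finish_1x1( BLIS_DOUBLE, alpha, &alphao );
        bli_obj_init_finish( BLIS_DOUBLE, m, m, a, 1, lda, &ao );
        bli_obj_init_finish( BLIS_DOUBLE, m, n, b, 1, ldb, &bo );

        /* Tag A as lower triangular for this example. */
        bli_obj_set_uplo( BLIS_LOWER, &ao );
        bli_obj_set_struc( BLIS_TRIANGULAR, &ao );

        /* Expert object API; the NULLs select the default context and
           runtime state. */
        bli_trmm_ex( BLIS_LEFT, &alphao, &ao, &bo, NULL, NULL );
    }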
// + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -86,7 +90,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ @@ -116,6 +120,103 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. 
*/ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( trmm, trmm ) #endif diff --git a/frame/compat/bla_trmm.h b/frame/compat/bla_trmm.h index 7c800f9eb..4f0c20b1b 100644 --- a/frame/compat/bla_trmm.h +++ b/frame/compat/bla_trmm.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_trmv.h b/frame/compat/bla_trmv.h index 4faec098b..4096ffe79 100644 --- a/frame/compat/bla_trmv.h +++ b/frame/compat/bla_trmv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 70597cc93..c0d8e4b3e 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +39,9 @@ // // Define BLAS-to-BLIS interfaces. // + +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ @@ -86,7 +90,7 @@ void PASTEF77(ch,blasname) \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ \ - /* Convert/typecast negative values of m and n to zero. */ \ + /* Typecast BLAS integers to BLIS integers. */ \ bli_convert_blas_dim1( *m, m0 ); \ bli_convert_blas_dim1( *n, n0 ); \ \ @@ -116,6 +120,103 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. 
*/ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif diff --git a/frame/compat/bla_trsm.h b/frame/compat/bla_trsm.h index a2c2222b0..5694db52a 100644 --- a/frame/compat/bla_trsm.h +++ b/frame/compat/bla_trsm.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_trsv.h b/frame/compat/bla_trsv.h index cec3976be..6edb435f1 100644 --- a/frame/compat/bla_trsv.h +++ b/frame/compat/bla_trsv.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype, ch, blasname ) \ \ -void PASTEF77(ch,blasname) \ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ diff --git a/frame/compat/bli_blas.h b/frame/compat/bli_blas.h index f2c3f9495..e1a7321a4 100644 --- a/frame/compat/bli_blas.h +++ b/frame/compat/bli_blas.h @@ -40,10 +40,31 @@ #endif #endif // BLIS_ENABLE_CBLAS +// By default, if the BLAS compatibility layer is enabled, we define +// (include) all of the BLAS prototypes. However, if the user is +// #including "blis.h" and also #including another header that also +// declares the BLAS functions, then we provide an opportunity to +// #undefine the BLIS_ENABLE_BLAS_DEFS macro (see below). +#ifdef BLIS_ENABLE_BLAS +#define BLIS_ENABLE_BLAS_DEFS +#else +#undef BLIS_ENABLE_BLAS_DEFS +#endif + // Skip prototyping all of the BLAS if the BLAS test drivers are being // compiled. -#ifndef BLIS_VIA_BLASTEST -#ifdef BLIS_ENABLE_BLAS +#ifdef BLIS_VIA_BLASTEST +#undef BLIS_ENABLE_BLAS_DEFS +#endif + +// Skip prototyping all of the BLAS if the environment has defined the +// macro BLIS_DISABLE_BLAS_DEFS. +#ifdef BLIS_DISABLE_BLAS_DEFS +#undef BLIS_ENABLE_BLAS_DEFS +#endif + +// Begin including all BLAS prototypes. +#ifdef BLIS_ENABLE_BLAS_DEFS // -- System headers needed by BLAS compatibility layer -- @@ -180,4 +201,3 @@ #endif // BLIS_ENABLE_BLAS -#endif // BLIS_VIA_BLASTEST diff --git a/frame/compat/blis/thread/b77_thread.h b/frame/compat/blis/thread/b77_thread.h index 0e87f6bb0..922ed6e13 100644 --- a/frame/compat/blis/thread/b77_thread.h +++ b/frame/compat/blis/thread/b77_thread.h @@ -37,7 +37,7 @@ // Prototype Fortran-compatible BLIS interfaces. 
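The bli_blas.h change above is the user-facing piece of this commit: an application that already obtains BLAS prototypes from another header can now suppress the declarations that blis.h would otherwise introduce. A minimal usage sketch (the application file and the second header are hypothetical):

    /* app.c -- hypothetical translation unit. */
    #define BLIS_DISABLE_BLAS_DEFS   /* must precede the include of blis.h */
    #include "vendor_blas.h"         /* hypothetical header declaring dgemm_, etc. */
    #include "blis.h"                /* BLIS APIs only; BLAS prototypes skipped */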
// -void PASTEF770(bli_thread_set_ways) +BLIS_EXPORT_BLAS void PASTEF770(bli_thread_set_ways) ( const f77_int* jc, const f77_int* pc, @@ -46,7 +46,7 @@ void PASTEF770(bli_thread_set_ways) const f77_int* ir ); -void PASTEF770(bli_thread_set_num_threads) +BLIS_EXPORT_BLAS void PASTEF770(bli_thread_set_num_threads) ( const f77_int* nt ); diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.h b/frame/compat/cblas/f77_sub/f77_amax_sub.h index 9a4ebb0af..9cd1202d2 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.h +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.h @@ -39,7 +39,7 @@ #undef GENTPROT #define GENTPROT( ftype_x, chx, blasname ) \ \ -void PASTEF773(i,chx,blasname,sub) \ +BLIS_EXPORT_BLAS void PASTEF773(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.h b/frame/compat/cblas/f77_sub/f77_asum_sub.h index 2c61e14d0..4b8634c16 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.h +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.h @@ -39,7 +39,7 @@ #undef GENTPROTR2 #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ -void PASTEF773(chr,chx,blasname,sub) \ +BLIS_EXPORT_BLAS void PASTEF773(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.h b/frame/compat/cblas/f77_sub/f77_dot_sub.h index 2ee169591..8aab2728b 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.h +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.h @@ -39,7 +39,7 @@ #undef GENTPROTDOT #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ -void PASTEF773(ch,blasname,chc,sub) \ +BLIS_EXPORT_BLAS void PASTEF773(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -53,7 +53,7 @@ INSERT_GENTPROTDOT_BLAS( dot ) // -- "Black sheep" dot product function prototypes -- -void PASTEF772(sds,dot,sub) +BLIS_EXPORT_BLAS void PASTEF772(sds,dot,sub) ( const f77_int* n, const float* sb, @@ -62,7 +62,7 @@ void PASTEF772(sds,dot,sub) float* rval ); -void PASTEF772(ds,dot,sub) +BLIS_EXPORT_BLAS void PASTEF772(ds,dot,sub) ( const f77_int* n, const float* x, const f77_int* incx, diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h index df2dd2357..c51a94292 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h @@ -39,7 +39,7 @@ #undef GENTPROTR2 #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ -void PASTEF773(chr,chx,blasname,sub) \ +BLIS_EXPORT_BLAS void PASTEF773(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/src/cblas.h b/frame/compat/cblas/src/cblas.h index 1ee6209c9..85778c8a4 100644 --- a/frame/compat/cblas/src/cblas.h +++ b/frame/compat/cblas/src/cblas.h @@ -28,52 +28,52 @@ extern "C" { * Prototypes for level 1 BLAS functions (complex are recast as routines) * =========================================================================== */ -float cblas_sdsdot(f77_int N, float alpha, const float *X, +BLIS_EXPORT_BLAS float cblas_sdsdot(f77_int N, float alpha, const float *X, f77_int incX, const float *Y, f77_int incY); -double cblas_dsdot(f77_int N, const float *X, f77_int incX, const float *Y, +BLIS_EXPORT_BLAS double cblas_dsdot(f77_int N, const float *X, f77_int incX, const float *Y, f77_int incY); -float cblas_sdot(f77_int N, const float *X, f77_int incX, +BLIS_EXPORT_BLAS float cblas_sdot(f77_int N, const float *X, f77_int incX, const float *Y, 
f77_int incY); -double cblas_ddot(f77_int N, const double *X, f77_int incX, +BLIS_EXPORT_BLAS double cblas_ddot(f77_int N, const double *X, f77_int incX, const double *Y, f77_int incY); /* * Functions having prefixes Z and C only */ -void cblas_cdotu_sub(f77_int N, const void *X, f77_int incX, +BLIS_EXPORT_BLAS void cblas_cdotu_sub(f77_int N, const void *X, f77_int incX, const void *Y, f77_int incY, void *dotu); -void cblas_cdotc_sub(f77_int N, const void *X, f77_int incX, +BLIS_EXPORT_BLAS void cblas_cdotc_sub(f77_int N, const void *X, f77_int incX, const void *Y, f77_int incY, void *dotc); -void cblas_zdotu_sub(f77_int N, const void *X, f77_int incX, +BLIS_EXPORT_BLAS void cblas_zdotu_sub(f77_int N, const void *X, f77_int incX, const void *Y, f77_int incY, void *dotu); -void cblas_zdotc_sub(f77_int N, const void *X, f77_int incX, +BLIS_EXPORT_BLAS void cblas_zdotc_sub(f77_int N, const void *X, f77_int incX, const void *Y, f77_int incY, void *dotc); /* * Functions having prefixes S D SC DZ */ -float cblas_snrm2(f77_int N, const float *X, f77_int incX); -float cblas_sasum(f77_int N, const float *X, f77_int incX); +BLIS_EXPORT_BLAS float cblas_snrm2(f77_int N, const float *X, f77_int incX); +BLIS_EXPORT_BLAS float cblas_sasum(f77_int N, const float *X, f77_int incX); -double cblas_dnrm2(f77_int N, const double *X, f77_int incX); -double cblas_dasum(f77_int N, const double *X, f77_int incX); +BLIS_EXPORT_BLAS double cblas_dnrm2(f77_int N, const double *X, f77_int incX); +BLIS_EXPORT_BLAS double cblas_dasum(f77_int N, const double *X, f77_int incX); -float cblas_scnrm2(f77_int N, const void *X, f77_int incX); -float cblas_scasum(f77_int N, const void *X, f77_int incX); +BLIS_EXPORT_BLAS float cblas_scnrm2(f77_int N, const void *X, f77_int incX); +BLIS_EXPORT_BLAS float cblas_scasum(f77_int N, const void *X, f77_int incX); -double cblas_dznrm2(f77_int N, const void *X, f77_int incX); -double cblas_dzasum(f77_int N, const void *X, f77_int incX); +BLIS_EXPORT_BLAS double cblas_dznrm2(f77_int N, const void *X, f77_int incX); +BLIS_EXPORT_BLAS double cblas_dzasum(f77_int N, const void *X, f77_int incX); /* * Functions having standard 4 prefixes (S D C Z) */ -f77_int cblas_isamax(f77_int N, const float *X, f77_int incX); -f77_int cblas_idamax(f77_int N, const double *X, f77_int incX); -f77_int cblas_icamax(f77_int N, const void *X, f77_int incX); -f77_int cblas_izamax(f77_int N, const void *X, f77_int incX); +BLIS_EXPORT_BLAS f77_int cblas_isamax(f77_int N, const float *X, f77_int incX); +BLIS_EXPORT_BLAS f77_int cblas_idamax(f77_int N, const double *X, f77_int incX); +BLIS_EXPORT_BLAS f77_int cblas_icamax(f77_int N, const void *X, f77_int incX); +BLIS_EXPORT_BLAS f77_int cblas_izamax(f77_int N, const void *X, f77_int incX); /* * =========================================================================== @@ -84,62 +84,62 @@ f77_int cblas_izamax(f77_int N, const void *X, f77_int incX); /* * Routines with standard 4 prefixes (s, d, c, z) */ -void cblas_sswap(f77_int N, float *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_sswap(f77_int N, float *X, f77_int incX, float *Y, f77_int incY); -void cblas_scopy(f77_int N, const float *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_scopy(f77_int N, const float *X, f77_int incX, float *Y, f77_int incY); -void cblas_saxpy(f77_int N, float alpha, const float *X, +void BLIS_EXPORT_BLAS cblas_saxpy(f77_int N, float alpha, const float *X, f77_int incX, float *Y, f77_int incY); -void cblas_dswap(f77_int N, double *X, f77_int incX, +void BLIS_EXPORT_BLAS 
cblas_dswap(f77_int N, double *X, f77_int incX, double *Y, f77_int incY); -void cblas_dcopy(f77_int N, const double *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_dcopy(f77_int N, const double *X, f77_int incX, double *Y, f77_int incY); -void cblas_daxpy(f77_int N, double alpha, const double *X, +void BLIS_EXPORT_BLAS cblas_daxpy(f77_int N, double alpha, const double *X, f77_int incX, double *Y, f77_int incY); -void cblas_cswap(f77_int N, void *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_cswap(f77_int N, void *X, f77_int incX, void *Y, f77_int incY); -void cblas_ccopy(f77_int N, const void *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_ccopy(f77_int N, const void *X, f77_int incX, void *Y, f77_int incY); -void cblas_caxpy(f77_int N, const void *alpha, const void *X, +void BLIS_EXPORT_BLAS cblas_caxpy(f77_int N, const void *alpha, const void *X, f77_int incX, void *Y, f77_int incY); -void cblas_zswap(f77_int N, void *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_zswap(f77_int N, void *X, f77_int incX, void *Y, f77_int incY); -void cblas_zcopy(f77_int N, const void *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_zcopy(f77_int N, const void *X, f77_int incX, void *Y, f77_int incY); -void cblas_zaxpy(f77_int N, const void *alpha, const void *X, +void BLIS_EXPORT_BLAS cblas_zaxpy(f77_int N, const void *alpha, const void *X, f77_int incX, void *Y, f77_int incY); /* * Routines with S and D prefix only */ -void cblas_srotg(float *a, float *b, float *c, float *s); -void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); -void cblas_srot(f77_int N, float *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_srotg(float *a, float *b, float *c, float *s); +void BLIS_EXPORT_BLAS cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); +void BLIS_EXPORT_BLAS cblas_srot(f77_int N, float *X, f77_int incX, float *Y, f77_int incY, const float c, const float s); -void cblas_srotm(f77_int N, float *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_srotm(f77_int N, float *X, f77_int incX, float *Y, f77_int incY, const float *P); -void cblas_drotg(double *a, double *b, double *c, double *s); -void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); -void cblas_drot(f77_int N, double *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_drotg(double *a, double *b, double *c, double *s); +void BLIS_EXPORT_BLAS cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); +void BLIS_EXPORT_BLAS cblas_drot(f77_int N, double *X, f77_int incX, double *Y, f77_int incY, const double c, const double s); -void cblas_drotm(f77_int N, double *X, f77_int incX, +void BLIS_EXPORT_BLAS cblas_drotm(f77_int N, double *X, f77_int incX, double *Y, f77_int incY, const double *P); /* * Routines with S D C Z CS and ZD prefixes */ -void cblas_sscal(f77_int N, float alpha, float *X, f77_int incX); -void cblas_dscal(f77_int N, double alpha, double *X, f77_int incX); -void cblas_cscal(f77_int N, const void *alpha, void *X, f77_int incX); -void cblas_zscal(f77_int N, const void *alpha, void *X, f77_int incX); -void cblas_csscal(f77_int N, float alpha, void *X, f77_int incX); -void cblas_zdscal(f77_int N, double alpha, void *X, f77_int incX); +void BLIS_EXPORT_BLAS cblas_sscal(f77_int N, float alpha, float *X, f77_int incX); +void BLIS_EXPORT_BLAS cblas_dscal(f77_int N, double alpha, double *X, f77_int incX); +void BLIS_EXPORT_BLAS cblas_cscal(f77_int N, const void *alpha, void *X, f77_int incX); +void BLIS_EXPORT_BLAS cblas_zscal(f77_int N, const void *alpha, void *X, f77_int 
incX); +void BLIS_EXPORT_BLAS cblas_csscal(f77_int N, float alpha, void *X, f77_int incX); +void BLIS_EXPORT_BLAS cblas_zdscal(f77_int N, double alpha, void *X, f77_int incX); /* * =========================================================================== @@ -150,135 +150,135 @@ void cblas_zdscal(f77_int N, double alpha, void *X, f77_int incX); /* * Routines with standard 4 prefixes (S, D, C, Z) */ -void cblas_sgemv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_sgemv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, float alpha, const float *A, f77_int lda, const float *X, f77_int incX, float beta, float *Y, f77_int incY); -void cblas_sgbmv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_sgbmv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, f77_int KL, f77_int KU, float alpha, const float *A, f77_int lda, const float *X, f77_int incX, float beta, float *Y, f77_int incY); -void cblas_strmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_strmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const float *A, f77_int lda, float *X, f77_int incX); -void cblas_stbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_stbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const float *A, f77_int lda, float *X, f77_int incX); -void cblas_stpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_stpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const float *Ap, float *X, f77_int incX); -void cblas_strsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_strsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const float *A, f77_int lda, float *X, f77_int incX); -void cblas_stbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_stbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const float *A, f77_int lda, float *X, f77_int incX); -void cblas_stpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_stpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const float *Ap, float *X, f77_int incX); -void cblas_dgemv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_dgemv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, double alpha, const double *A, f77_int lda, const double *X, f77_int incX, double beta, double *Y, f77_int incY); -void cblas_dgbmv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_dgbmv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, f77_int KL, f77_int KU, double alpha, const double *A, f77_int lda, const double *X, f77_int incX, double beta, double *Y, f77_int incY); -void cblas_dtrmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dtrmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const double *A, f77_int lda, double *X, f77_int incX); -void cblas_dtbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dtbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, 
const double *A, f77_int lda, double *X, f77_int incX); -void cblas_dtpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dtpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const double *Ap, double *X, f77_int incX); -void cblas_dtrsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dtrsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const double *A, f77_int lda, double *X, f77_int incX); -void cblas_dtbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dtbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const double *A, f77_int lda, double *X, f77_int incX); -void cblas_dtpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dtpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const double *Ap, double *X, f77_int incX); -void cblas_cgemv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_cgemv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_cgbmv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_cgbmv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, f77_int KL, f77_int KU, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_ctrmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ctrmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ctbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ctbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ctpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ctpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *Ap, void *X, f77_int incX); -void cblas_ctrsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ctrsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ctbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ctbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ctpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ctpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *Ap, void *X, f77_int incX); -void cblas_zgemv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS cblas_zgemv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_zgbmv(enum CBLAS_ORDER order, +void BLIS_EXPORT_BLAS 
cblas_zgbmv(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, f77_int M, f77_int N, f77_int KL, f77_int KU, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_ztrmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ztrmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ztbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ztbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ztpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ztpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *Ap, void *X, f77_int incX); -void cblas_ztrsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ztrsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ztbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ztbsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, f77_int K, const void *A, f77_int lda, void *X, f77_int incX); -void cblas_ztpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ztpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int N, const void *Ap, void *X, f77_int incX); @@ -286,61 +286,61 @@ void cblas_ztpsv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, /* * Routines with S and D prefixes only */ -void cblas_ssymv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ssymv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const float *A, f77_int lda, const float *X, f77_int incX, float beta, float *Y, f77_int incY); -void cblas_ssbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ssbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, f77_int K, float alpha, const float *A, f77_int lda, const float *X, f77_int incX, float beta, float *Y, f77_int incY); -void cblas_sspmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_sspmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const float *Ap, const float *X, f77_int incX, float beta, float *Y, f77_int incY); -void cblas_sger(enum CBLAS_ORDER order, f77_int M, f77_int N, +void BLIS_EXPORT_BLAS cblas_sger(enum CBLAS_ORDER order, f77_int M, f77_int N, float alpha, const float *X, f77_int incX, const float *Y, f77_int incY, float *A, f77_int lda); -void cblas_ssyr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ssyr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const float *X, f77_int incX, float *A, f77_int lda); -void cblas_sspr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_sspr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const float *X, f77_int incX, float *Ap); -void cblas_ssyr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ssyr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const float 
*X, f77_int incX, const float *Y, f77_int incY, float *A, f77_int lda); -void cblas_sspr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_sspr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const float *X, f77_int incX, const float *Y, f77_int incY, float *A); -void cblas_dsymv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dsymv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const double *A, f77_int lda, const double *X, f77_int incX, double beta, double *Y, f77_int incY); -void cblas_dsbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dsbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, f77_int K, double alpha, const double *A, f77_int lda, const double *X, f77_int incX, double beta, double *Y, f77_int incY); -void cblas_dspmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dspmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const double *Ap, const double *X, f77_int incX, double beta, double *Y, f77_int incY); -void cblas_dger(enum CBLAS_ORDER order, f77_int M, f77_int N, +void BLIS_EXPORT_BLAS cblas_dger(enum CBLAS_ORDER order, f77_int M, f77_int N, double alpha, const double *X, f77_int incX, const double *Y, f77_int incY, double *A, f77_int lda); -void cblas_dsyr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dsyr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const double *X, f77_int incX, double *A, f77_int lda); -void cblas_dspr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dspr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const double *X, f77_int incX, double *Ap); -void cblas_dsyr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dsyr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const double *X, f77_int incX, const double *Y, f77_int incY, double *A, f77_int lda); -void cblas_dspr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dspr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const double *X, f77_int incX, const double *Y, f77_int incY, double *A); @@ -348,65 +348,65 @@ void cblas_dspr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, /* * Routines with C and Z prefixes only */ -void cblas_chemv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_chemv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_chbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_chbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_chpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_chpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *Ap, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_cgeru(enum CBLAS_ORDER order, f77_int M, f77_int N, +void BLIS_EXPORT_BLAS cblas_cgeru(enum CBLAS_ORDER order, f77_int M, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *A, f77_int lda); -void cblas_cgerc(enum CBLAS_ORDER order, f77_int M, 
f77_int N, +void BLIS_EXPORT_BLAS cblas_cgerc(enum CBLAS_ORDER order, f77_int M, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *A, f77_int lda); -void cblas_cher(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_cher(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const void *X, f77_int incX, void *A, f77_int lda); -void cblas_chpr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_chpr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, float alpha, const void *X, f77_int incX, void *A); -void cblas_cher2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, +void BLIS_EXPORT_BLAS cblas_cher2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *A, f77_int lda); -void cblas_chpr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, +void BLIS_EXPORT_BLAS cblas_chpr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *Ap); -void cblas_zhemv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zhemv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_zhbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zhbmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_zhpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zhpmv(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *Ap, const void *X, f77_int incX, const void *beta, void *Y, f77_int incY); -void cblas_zgeru(enum CBLAS_ORDER order, f77_int M, f77_int N, +void BLIS_EXPORT_BLAS cblas_zgeru(enum CBLAS_ORDER order, f77_int M, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *A, f77_int lda); -void cblas_zgerc(enum CBLAS_ORDER order, f77_int M, f77_int N, +void BLIS_EXPORT_BLAS cblas_zgerc(enum CBLAS_ORDER order, f77_int M, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *A, f77_int lda); -void cblas_zher(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zher(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const void *X, f77_int incX, void *A, f77_int lda); -void cblas_zhpr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zhpr(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, double alpha, const void *X, f77_int incX, void *A); -void cblas_zher2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, +void BLIS_EXPORT_BLAS cblas_zher2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *A, f77_int lda); -void cblas_zhpr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, +void BLIS_EXPORT_BLAS cblas_zhpr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, const void *alpha, const void *X, f77_int incX, const void *Y, f77_int incY, void *Ap); @@ -419,121 +419,121 @@ void cblas_zhpr2(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, f77_int N, /* * Routines with standard 4 prefixes (S, D, C, Z) */ 
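Every declaration in cblas.h gains the BLIS_EXPORT_BLAS annotation so that these symbols remain visible when the library is built with symbols hidden by default. The macro's actual definition lives elsewhere in the build; a representative sketch of how export macros of this kind are commonly defined (illustrative names, not BLIS's own definition):

    #if defined(_WIN32) && defined(BUILDING_SHARED_LIB)
      #define EXPORT_SYM __declspec(dllexport)
    #elif defined(__GNUC__) && __GNUC__ >= 4
      #define EXPORT_SYM __attribute__((visibility("default")))
    #else
      #define EXPORT_SYM   /* static builds: annotation expands to nothing */
    #endif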
-void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, +void BLIS_EXPORT_BLAS cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, f77_int K, float alpha, const float *A, f77_int lda, const float *B, f77_int ldb, float beta, float *C, f77_int ldc); -void cblas_ssymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_ssymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, f77_int M, f77_int N, float alpha, const float *A, f77_int lda, const float *B, f77_int ldb, float beta, float *C, f77_int ldc); -void cblas_ssyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ssyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, float alpha, const float *A, f77_int lda, float beta, float *C, f77_int ldc); -void cblas_ssyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_ssyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, float alpha, const float *A, f77_int lda, const float *B, f77_int ldb, float beta, float *C, f77_int ldc); -void cblas_strmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_strmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, float alpha, const float *A, f77_int lda, float *B, f77_int ldb); -void cblas_strsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_strsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, float alpha, const float *A, f77_int lda, float *B, f77_int ldb); -void cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, +void BLIS_EXPORT_BLAS cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, f77_int K, double alpha, const double *A, f77_int lda, const double *B, f77_int ldb, double beta, double *C, f77_int ldc); -void cblas_dsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_dsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, f77_int M, f77_int N, double alpha, const double *A, f77_int lda, const double *B, f77_int ldb, double beta, double *C, f77_int ldc); -void cblas_dsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, double alpha, const double *A, f77_int lda, double beta, double *C, f77_int ldc); -void cblas_dsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_dsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, double alpha, const double *A, f77_int lda, const double *B, f77_int ldb, double beta, double *C, f77_int ldc); -void cblas_dtrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_dtrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, double alpha, const double *A, f77_int lda, double *B, f77_int ldb); -void cblas_dtrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_dtrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG 
Diag, f77_int M, f77_int N, double alpha, const double *A, f77_int lda, double *B, f77_int ldb); -void cblas_cgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, +void BLIS_EXPORT_BLAS cblas_cgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_csymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_csymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_csyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_csyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *beta, void *C, f77_int ldc); -void cblas_csyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_csyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_ctrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_ctrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, void *B, f77_int ldb); -void cblas_ctrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_ctrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, void *B, f77_int ldb); -void cblas_zgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, +void BLIS_EXPORT_BLAS cblas_zgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_zsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_zsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_zsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *beta, void *C, f77_int ldc); -void cblas_zsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_ztrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_ztrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, void *B, f77_int ldb); -void cblas_ztrsm(enum CBLAS_ORDER Order, enum 
CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_ztrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, @@ -543,37 +543,37 @@ void cblas_ztrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, /* * Routines with prefixes C and Z only */ -void cblas_chemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_chemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_cherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_cherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, float alpha, const void *A, f77_int lda, float beta, void *C, f77_int ldc); -void cblas_cher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_cher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, float beta, void *C, f77_int ldc); -void cblas_zhemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, +void BLIS_EXPORT_BLAS cblas_zhemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, const void *beta, void *C, f77_int ldc); -void cblas_zherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, double alpha, const void *A, f77_int lda, double beta, void *C, f77_int ldc); -void cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, +void BLIS_EXPORT_BLAS cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE Trans, f77_int N, f77_int K, const void *alpha, const void *A, f77_int lda, const void *B, f77_int ldb, double beta, void *C, f77_int ldc); -void cblas_xerbla(f77_int p, const char *rout, const char *form, ...); +void BLIS_EXPORT_BLAS cblas_xerbla(f77_int p, const char *rout, const char *form, ...); #ifdef __cplusplus } diff --git a/frame/compat/f2c/bla_cabs1.h b/frame/compat/f2c/bla_cabs1.h index edec8c86e..753765a1d 100644 --- a/frame/compat/f2c/bla_cabs1.h +++ b/frame/compat/f2c/bla_cabs1.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -bla_real PASTEF77(s,cabs1)(bla_scomplex *z); -bla_double PASTEF77(d,cabs1)(bla_dcomplex *z); +BLIS_EXPORT_BLAS bla_real PASTEF77(s,cabs1)(bla_scomplex *z); +BLIS_EXPORT_BLAS bla_double PASTEF77(d,cabs1)(bla_dcomplex *z); #endif diff --git a/frame/compat/f2c/bla_gbmv.h b/frame/compat/f2c/bla_gbmv.h index b0fd7f30b..eb8ce2534 100644 --- a/frame/compat/f2c/bla_gbmv.h +++ b/frame/compat/f2c/bla_gbmv.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); -int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const 
bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); -int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy); -int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hbmv.h b/frame/compat/f2c/bla_hbmv.h index 8a10c75da..1ddb83807 100644 --- a/frame/compat/f2c/bla_hbmv.h +++ b/frame/compat/f2c/bla_hbmv.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); -int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hpmv.h 
b/frame/compat/f2c/bla_hpmv.h index fc744f985..26d055eff 100644 --- a/frame/compat/f2c/bla_hpmv.h +++ b/frame/compat/f2c/bla_hpmv.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); -int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hpr.h b/frame/compat/f2c/bla_hpr.h index 1c1a96fc8..cfce9e177 100644 --- a/frame/compat/f2c/bla_hpr.h +++ b/frame/compat/f2c/bla_hpr.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap); -int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap); #endif diff --git a/frame/compat/f2c/bla_hpr2.h b/frame/compat/f2c/bla_hpr2.h index 766974eaf..16f929d61 100644 --- a/frame/compat/f2c/bla_hpr2.h +++ b/frame/compat/f2c/bla_hpr2.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap); -int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap); #endif diff --git a/frame/compat/f2c/bla_lsame.h b/frame/compat/f2c/bla_lsame.h index 738ce08b8..656032688 100644 --- a/frame/compat/f2c/bla_lsame.h +++ b/frame/compat/f2c/bla_lsame.h @@ -37,7 +37,7 @@ #ifdef LAPACK_ILP64 long PASTEF770(lsame)(const char *ca, const char *cb, long ca_len, long cb_len); #else -int PASTEF770(lsame)(const char *ca, const char *cb, int ca_len, int 
cb_len); +BLIS_EXPORT_BLAS int PASTEF770(lsame)(const char *ca, const char *cb, int ca_len, int cb_len); #endif #endif diff --git a/frame/compat/f2c/bla_rot.h b/frame/compat/f2c/bla_rot.h index 1532a7cfc..609355560 100644 --- a/frame/compat/f2c/bla_rot.h +++ b/frame/compat/f2c/bla_rot.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s); -int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s); -int PASTEF77(cs,rot)(const bla_integer *n, bla_scomplex *cx, const bla_integer *incx, bla_scomplex *cy, const bla_integer *incy, const bla_real *c__, const bla_real *s); -int PASTEF77(zd,rot)(const bla_integer *n, bla_dcomplex *zx, const bla_integer *incx, bla_dcomplex *zy, const bla_integer *incy, const bla_double *c__, const bla_double *s); +BLIS_EXPORT_BLAS int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s); +BLIS_EXPORT_BLAS int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s); +BLIS_EXPORT_BLAS int PASTEF77(cs,rot)(const bla_integer *n, bla_scomplex *cx, const bla_integer *incx, bla_scomplex *cy, const bla_integer *incy, const bla_real *c__, const bla_real *s); +BLIS_EXPORT_BLAS int PASTEF77(zd,rot)(const bla_integer *n, bla_dcomplex *zx, const bla_integer *incx, bla_dcomplex *zy, const bla_integer *incy, const bla_double *c__, const bla_double *s); #endif diff --git a/frame/compat/f2c/bla_rotg.h b/frame/compat/f2c/bla_rotg.h index c89f0279b..b968ebbea 100644 --- a/frame/compat/f2c/bla_rotg.h +++ b/frame/compat/f2c/bla_rotg.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s); -int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s); -int PASTEF77(c,rotg)(bla_scomplex *ca, bla_scomplex *cb, bla_real *c__, bla_scomplex *s); -int PASTEF77(z,rotg)(bla_dcomplex *ca, bla_dcomplex *cb, bla_double *c__, bla_dcomplex *s); +BLIS_EXPORT_BLAS int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s); +BLIS_EXPORT_BLAS int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s); +BLIS_EXPORT_BLAS int PASTEF77(c,rotg)(bla_scomplex *ca, bla_scomplex *cb, bla_real *c__, bla_scomplex *s); +BLIS_EXPORT_BLAS int PASTEF77(z,rotg)(bla_dcomplex *ca, bla_dcomplex *cb, bla_double *c__, bla_dcomplex *s); #endif diff --git a/frame/compat/f2c/bla_rotm.h b/frame/compat/f2c/bla_rotm.h index d28f0919b..21906358b 100644 --- a/frame/compat/f2c/bla_rotm.h +++ b/frame/compat/f2c/bla_rotm.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam); -int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam); +BLIS_EXPORT_BLAS int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam); +BLIS_EXPORT_BLAS int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const 
bla_integer *incy, const bla_double *dparam); #endif diff --git a/frame/compat/f2c/bla_rotmg.h b/frame/compat/f2c/bla_rotmg.h index 29d42a90f..63e9710da 100644 --- a/frame/compat/f2c/bla_rotmg.h +++ b/frame/compat/f2c/bla_rotmg.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam); -int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam); +BLIS_EXPORT_BLAS int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam); +BLIS_EXPORT_BLAS int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam); #endif diff --git a/frame/compat/f2c/bla_sbmv.h b/frame/compat/f2c/bla_sbmv.h index 75442fc7c..c3f3fc24f 100644 --- a/frame/compat/f2c/bla_sbmv.h +++ b/frame/compat/f2c/bla_sbmv.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); -int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_spmv.h b/frame/compat/f2c/bla_spmv.h index 8e0d0c1ff..7db7d4a8b 100644 --- a/frame/compat/f2c/bla_spmv.h +++ b/frame/compat/f2c/bla_spmv.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); -int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_spr.h b/frame/compat/f2c/bla_spr.h index af63cea52..6712d7c16 100644 --- a/frame/compat/f2c/bla_spr.h +++ b/frame/compat/f2c/bla_spr.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, 
const bla_double *x, const bla_integer *incx, bla_double *ap); -int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap); +BLIS_EXPORT_BLAS int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap); +BLIS_EXPORT_BLAS int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap); #endif diff --git a/frame/compat/f2c/bla_spr2.h b/frame/compat/f2c/bla_spr2.h index b6322cc37..5a1d60747 100644 --- a/frame/compat/f2c/bla_spr2.h +++ b/frame/compat/f2c/bla_spr2.h @@ -34,7 +34,7 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap); -int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap); +BLIS_EXPORT_BLAS int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap); +BLIS_EXPORT_BLAS int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap); #endif diff --git a/frame/compat/f2c/bla_tbmv.h b/frame/compat/f2c/bla_tbmv.h index c524f0ee2..f34654762 100644 --- a/frame/compat/f2c/bla_tbmv.h +++ b/frame/compat/f2c/bla_tbmv.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); -int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); -int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); -int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS 
int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tbsv.h b/frame/compat/f2c/bla_tbsv.h index e48de07e1..5e84f5c36 100644 --- a/frame/compat/f2c/bla_tbsv.h +++ b/frame/compat/f2c/bla_tbsv.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); -int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); -int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); -int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tpmv.h b/frame/compat/f2c/bla_tpmv.h index 095d7d414..2376ecfe3 100644 --- a/frame/compat/f2c/bla_tpmv.h +++ b/frame/compat/f2c/bla_tpmv.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx); -int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); -int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); -int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex 
*ap, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tpsv.h b/frame/compat/f2c/bla_tpsv.h index 9c3de2ea0..77bd55979 100644 --- a/frame/compat/f2c/bla_tpsv.h +++ b/frame/compat/f2c/bla_tpsv.h @@ -34,9 +34,9 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx); -int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); -int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); -int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_xerbla.h b/frame/compat/f2c/bla_xerbla.h index 4110cf281..44c168e58 100644 --- a/frame/compat/f2c/bla_xerbla.h +++ b/frame/compat/f2c/bla_xerbla.h @@ -34,6 +34,6 @@ #ifdef BLIS_ENABLE_BLAS -int PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); +BLIS_EXPORT_BLAS int PASTEF770(xerbla)(const bla_character *srname, const bla_integer *info, ftnlen srname_len); #endif diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 9f9eee19b..12cc01852 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -177,6 +177,9 @@ CNTX_INIT_PROTS( generic ) // -- ARM architectures -- +#ifdef BLIS_FAMILY_THUNDERX2 +#include "bli_family_thunderx2.h" +#endif #ifdef BLIS_FAMILY_CORTEXA57 #include "bli_family_cortexa57.h" #endif diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index 46f78c27f..cef0b8432 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -5,6 
+5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,11 +46,11 @@ // internally within BLIS as well as those exposed in the native BLAS-like BLIS // interface. #ifndef BLIS_INT_TYPE_SIZE -#ifdef BLIS_ARCH_64 -#define BLIS_INT_TYPE_SIZE 64 -#else -#define BLIS_INT_TYPE_SIZE 32 -#endif + #ifdef BLIS_ARCH_64 + #define BLIS_INT_TYPE_SIZE 64 + #else + #define BLIS_INT_TYPE_SIZE 32 + #endif #endif @@ -157,7 +158,19 @@ // C99 type "long int". Note that this ONLY affects integers used within the // BLAS compatibility layer. #ifndef BLIS_BLAS_INT_TYPE_SIZE -#define BLIS_BLAS_INT_TYPE_SIZE 32 + #define BLIS_BLAS_INT_TYPE_SIZE 32 +#endif + +// By default, the level-3 BLAS routines are implemented by directly calling +// the BLIS object API. Alternatively, they may first call the typed BLIS +// API, which will then call the object API. +//#define BLIS_BLAS3_CALLS_TAPI +#ifdef BLIS_BLAS3_CALLS_TAPI + #undef BLIS_BLAS3_CALLS_OAPI +#else + // Default behavior is to call the object API directly. + #undef BLIS_BLAS3_CALLS_OAPI // In case the user explicitly enabled it. + #define BLIS_BLAS3_CALLS_OAPI #endif @@ -176,5 +189,41 @@ #endif +// -- SHARED LIBRARY SYMBOL EXPORT --------------------------------------------- + +// When building shared libraries, we can control which symbols are exported for +// linking by external applications. BLIS annotates all function prototypes that +// are meant to be "public" with BLIS_EXPORT_BLIS (with BLIS_EXPORT_BLAS playing +// a similar role for BLAS compatibility routines). Which symbols are exported +// is controlled by the default symbol visibility, as specified by the gcc option +// -fvisibility=[default|hidden]. The default for this option is 'default' (i.e. +// "public"), which, if allowed to stand, causes all symbols in BLIS to be +// linkable from the outside. But when compiling with -fvisibility=hidden, all +// symbols start out hidden (that is, restricted only for internal use by BLIS), +// with that setting overridden only for function prototypes or variable +// declarations that are annotated with BLIS_EXPORT_BLIS. + +#ifndef BLIS_EXPORT + #if !defined(BLIS_ENABLE_SHARED) + #define BLIS_EXPORT + #else + #if defined(_WIN32) || defined(__CYGWIN__) + #ifdef BLIS_IS_BUILDING_LIBRARY + #define BLIS_EXPORT __declspec(dllexport) + #else + #define BLIS_EXPORT __declspec(dllimport) + #endif + #elif defined(__GNUC__) && __GNUC__ >= 4 + #define BLIS_EXPORT __attribute__ ((visibility ("default"))) + #else + #define BLIS_EXPORT + #endif + #endif +#endif + +#define BLIS_EXPORT_BLIS BLIS_EXPORT +#define BLIS_EXPORT_BLAS BLIS_EXPORT + + #endif diff --git a/frame/include/bli_extern_defs.h b/frame/include/bli_extern_defs.h index d577de4bf..9773e5e69 100644 --- a/frame/include/bli_extern_defs.h +++ b/frame/include/bli_extern_defs.h @@ -35,28 +35,16 @@ #ifndef BLIS_EXTERN_DEFS_H #define BLIS_EXTERN_DEFS_H -#if !defined(BLIS_ENABLE_SHARED) || !defined(_MSC_VER) -#define BLIS_EXPORT -#else -// Windows builds require us to explicitly identify global variable symbols -// to be imported from the .dll.
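/*
 * Aside (hypothetical sketch, not part of the patch): the visibility
 * mechanism that the new BLIS_EXPORT macro above relies on for GCC-style
 * toolchains. Compiled into a shared object with -fvisibility=hidden, every
 * symbol is hidden unless it carries the "default" visibility attribute,
 * which is what BLIS_EXPORT_BLIS/BLIS_EXPORT_BLAS expand to. All names here
 * are illustrative.
 */
#define DEMO_EXPORT __attribute__ ((visibility ("default")))

DEMO_EXPORT int demo_public ( void ) { return 1; }  /* kept in the dynamic symbol table  */
int             demo_private( void ) { return 2; }  /* hidden by -fvisibility=hidden     */

/* Build:   cc -shared -fPIC -fvisibility=hidden demo.c -o libdemo.so
   Verify:  nm -D libdemo.so   -- lists demo_public but not demo_private.    */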
-#ifdef BLIS_IS_BUILDING_LIBRARY -#define BLIS_EXPORT __declspec(dllexport) -#else -#define BLIS_EXPORT __declspec(dllimport) -#endif -#endif +BLIS_EXPORT_BLIS extern obj_t BLIS_TWO; +BLIS_EXPORT_BLIS extern obj_t BLIS_ONE; +//BLIS_EXPORT_BLIS extern obj_t BLIS_ONE_HALF; +BLIS_EXPORT_BLIS extern obj_t BLIS_ZERO; +//BLIS_EXPORT_BLIS extern obj_t BLIS_MINUS_ONE_HALF; +BLIS_EXPORT_BLIS extern obj_t BLIS_MINUS_ONE; +BLIS_EXPORT_BLIS extern obj_t BLIS_MINUS_TWO; -BLIS_EXPORT extern obj_t BLIS_TWO; -BLIS_EXPORT extern obj_t BLIS_ONE; -//BLIS_EXPORT extern obj_t BLIS_ONE_HALF; -BLIS_EXPORT extern obj_t BLIS_ZERO; -//BLIS_EXPORT extern obj_t BLIS_MINUS_ONE_HALF; -BLIS_EXPORT extern obj_t BLIS_MINUS_ONE; -BLIS_EXPORT extern obj_t BLIS_MINUS_TWO; - -BLIS_EXPORT extern thrcomm_t BLIS_SINGLE_COMM; -BLIS_EXPORT extern thrinfo_t BLIS_PACKM_SINGLE_THREADED; -BLIS_EXPORT extern thrinfo_t BLIS_GEMM_SINGLE_THREADED; +BLIS_EXPORT_BLIS extern thrcomm_t BLIS_SINGLE_COMM; +BLIS_EXPORT_BLIS extern thrinfo_t BLIS_PACKM_SINGLE_THREADED; +BLIS_EXPORT_BLIS extern thrinfo_t BLIS_GEMM_SINGLE_THREADED; #endif diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index ff23597d1..907a5a26c 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -140,6 +140,9 @@ #define PASTEBLACHK_(op) bla_ ## op ## _check #define PASTEBLACHK(op) PASTEBLACHK_(op) +#define PASTECH0_(op) op +#define PASTECH0(op) PASTECH0_(op) + #define PASTECH_(ch,op) ch ## op #define PASTECH(ch,op) PASTECH_(ch,op) @@ -153,10 +156,10 @@ #define STRINGIFY_INT( s ) MKSTR( s ) // Fortran-77 name-mangling macros. -#define PASTEF770(name) name ## _ -#define PASTEF77(ch1,name) ch1 ## name ## _ -#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ -#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _ +#define PASTEF770(name) name ## _ +#define PASTEF77(ch1,name) ch1 ## name ## _ +#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ +#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _ // -- Include other groups of macros diff --git a/frame/include/bli_obj_macro_defs.h b/frame/include/bli_obj_macro_defs.h index e3eb2b874..ccfcf096b 100644 --- a/frame/include/bli_obj_macro_defs.h +++ b/frame/include/bli_obj_macro_defs.h @@ -6,6 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1127,6 +1128,93 @@ static void bli_obj_set_panel_stride( inc_t ps, obj_t* obj ) obj->ps = ps; } +// stor3_t-related + +static stor3_t bli_obj_stor3_from_strides( obj_t* c, obj_t* a, obj_t* b ) +{ + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + inc_t rs_a, cs_a; + inc_t rs_b, cs_b; + + if ( bli_obj_has_notrans( a ) ) + { + rs_a = bli_obj_row_stride( a ); + cs_a = bli_obj_col_stride( a ); + } + else + { + rs_a = bli_obj_col_stride( a ); + cs_a = bli_obj_row_stride( a ); + } + + if ( bli_obj_has_notrans( b ) ) + { + rs_b = bli_obj_row_stride( b ); + cs_b = bli_obj_col_stride( b ); + } + else + { + rs_b = bli_obj_col_stride( b ); + cs_b = bli_obj_row_stride( b ); + } + + return bli_stor3_from_strides( rs_c, cs_c, + rs_a, cs_a, + rs_b, cs_b ); +} + + +// -- Initialization-related macros -- + +// Finish the initialization started by the matrix-specific static initializer +// (e.g. 
BLIS_OBJECT_PREINITIALIZER) +// NOTE: This is intended only for use in the BLAS compatibility API and typed +// BLIS API. + +static void bli_obj_init_finish( num_t dt, dim_t m, dim_t n, void* p, inc_t rs, inc_t cs, obj_t* obj ) +{ + bli_obj_set_as_root( obj ); + + bli_obj_set_dt( dt, obj ); + bli_obj_set_target_dt( dt, obj ); + bli_obj_set_exec_dt( dt, obj ); + bli_obj_set_comp_dt( dt, obj ); + + bli_obj_set_dims( m, n, obj ); + bli_obj_set_strides( rs, cs, obj ); + + siz_t elem_size = sizeof( float ); + if ( bli_dt_prec_is_double( dt ) ) elem_size *= 2; + if ( bli_dt_dom_is_complex( dt ) ) elem_size *= 2; + bli_obj_set_elem_size( elem_size, obj ); + + bli_obj_set_buffer( p, obj ); + + bli_obj_set_scalar_dt( dt, obj ); + void* restrict s = bli_obj_internal_scalar_buffer( obj ); + + if ( bli_dt_prec_is_single( dt ) ) { (( scomplex* )s)->real = 1.0F; + (( scomplex* )s)->imag = 0.0F; } + else if ( bli_dt_prec_is_double( dt ) ) { (( dcomplex* )s)->real = 1.0; + (( dcomplex* )s)->imag = 0.0; } +} + +// Finish the initialization started by the 1x1-specific static initializer +// (e.g. BLIS_OBJECT_PREINITIALIZER_1X1) +// NOTE: This is intended only for use in the BLAS compatibility API and typed +// BLIS API. + +static void bli_obj_init_finish_1x1( num_t dt, void* p, obj_t* obj ) +{ + bli_obj_set_as_root( obj ); + + bli_obj_set_dt( dt, obj ); + + bli_obj_set_buffer( p, obj ); +} + // -- Miscellaneous object macros -- // Toggle the region referenced (or "stored"). @@ -1158,38 +1246,6 @@ static void bli_obj_set_defaults( obj_t* obj ) obj->info = obj->info | BLIS_BITVAL_DENSE | BLIS_BITVAL_GENERAL; } -// Initializors for global scalar constants. -// NOTE: These must remain cpp macros since they are initializor -// expressions, not functions. - -#define bli_obj_init_const( buffer0 ) \ -{ \ - .root = NULL, \ -\ - .off = { 0, 0 }, \ - .dim = { 1, 1 }, \ - .diag_off = 0, \ -\ - .info = 0x0 | BLIS_BITVAL_CONST_TYPE | \ - BLIS_BITVAL_DENSE | \ - BLIS_BITVAL_GENERAL, \ - .elem_size = sizeof( constdata_t ), \ -\ - .buffer = buffer0, \ - .rs = 1, \ - .cs = 1, \ - .is = 1 \ -} - -#define bli_obj_init_constdata( val ) \ -{ \ - .s = ( float )val, \ - .d = ( double )val, \ - .c = { .real = ( float )val, .imag = 0.0f }, \ - .z = { .real = ( double )val, .imag = 0.0 }, \ - .i = ( gint_t )val, \ -} - // Acquire buffer at object's submatrix offset (offset-aware buffer query). static void* bli_obj_buffer_at_off( obj_t* obj ) @@ -1401,7 +1457,32 @@ static void bli_obj_induce_trans( obj_t* obj ) bli_obj_set_panel_dims( n_panel, m_panel, obj ); // Note that this macro DOES NOT touch the transposition bit! If - // the calling code is using this macro to handle an object whose + // the calling code is using this function to handle an object whose + // transposition bit is set prior to computation, that code needs + // to manually clear or toggle the bit, via + // bli_obj_set_onlytrans() or bli_obj_toggle_trans(), + // respectively. +} + +static void bli_obj_induce_fast_trans( obj_t* obj ) +{ + // NOTE: This function is only used in situations where the matrices + // are guaranteed to not have structure or be packed. + + // Induce transposition among basic fields. 
+ dim_t m = bli_obj_length( obj ); + dim_t n = bli_obj_width( obj ); + inc_t rs = bli_obj_row_stride( obj ); + inc_t cs = bli_obj_col_stride( obj ); + dim_t offm = bli_obj_row_off( obj ); + dim_t offn = bli_obj_col_off( obj ); + + bli_obj_set_dims( n, m, obj ); + bli_obj_set_strides( cs, rs, obj ); + bli_obj_set_offs( offn, offm, obj ); + + // Note that this function DOES NOT touch the transposition bit! If + // the calling code is using this function to handle an object whose // transposition bit is set prior to computation, that code needs // to manually clear or toggle the bit, via // bli_obj_set_onlytrans() or bli_obj_toggle_trans(), // respectively. diff --git a/frame/include/bli_param_macro_defs.h b/frame/include/bli_param_macro_defs.h index b22949c07..cc1737e91 100644 --- a/frame/include/bli_param_macro_defs.h +++ b/frame/include/bli_param_macro_defs.h @@ -132,12 +132,36 @@ static dom_t bli_dt_domain( num_t dt ) ( dt & BLIS_DOMAIN_BIT ); } +static bool_t bli_dt_dom_is_real( num_t dt ) +{ + return ( bool_t ) + ( ( dt & BLIS_DOMAIN_BIT ) == BLIS_REAL ); +} + +static bool_t bli_dt_dom_is_complex( num_t dt ) +{ + return ( bool_t ) + ( ( dt & BLIS_DOMAIN_BIT ) == BLIS_COMPLEX ); +} + static prec_t bli_dt_prec( num_t dt ) { return ( prec_t ) ( dt & BLIS_PRECISION_BIT ); } +static bool_t bli_dt_prec_is_single( num_t dt ) +{ + return ( bool_t ) + ( ( dt & BLIS_PRECISION_BIT ) == BLIS_SINGLE_PREC ); +} + +static bool_t bli_dt_prec_is_double( num_t dt ) +{ + return ( bool_t ) + ( ( dt & BLIS_PRECISION_BIT ) == BLIS_DOUBLE_PREC ); +} + static num_t bli_dt_proj_to_real( num_t dt ) { return ( num_t ) @@ -765,6 +789,97 @@ static void bli_toggle_dim( mdim_t* mdim ) } + +// stor3_t-related + +static stor3_t bli_stor3_from_strides( inc_t rs_c, inc_t cs_c, + inc_t rs_a, inc_t cs_a, + inc_t rs_b, inc_t cs_b ) +{ + // If any matrix is general-stored, return the stor3_t id for the + // general-purpose sup microkernel. + if ( bli_is_gen_stored( rs_c, cs_c ) || + bli_is_gen_stored( rs_a, cs_a ) || + bli_is_gen_stored( rs_b, cs_b ) ) return BLIS_XXX; + + // Otherwise, compute and return the stor3_t id as follows.
+ const bool_t c_is_col = bli_is_col_stored( rs_c, cs_c ); + const bool_t a_is_col = bli_is_col_stored( rs_a, cs_a ); + const bool_t b_is_col = bli_is_col_stored( rs_b, cs_b ); + + return ( stor3_t )( 4 * c_is_col + + 2 * a_is_col + + 1 * b_is_col ); +} + +static stor3_t bli_stor3_trans( stor3_t id ) +{ +#if 1 + stor3_t map[ BLIS_NUM_3OP_RC_COMBOS ] + = + { + ( stor3_t )7, // BLIS_RRR = 0 -> BLIS_CCC = 7 + ( stor3_t )5, // BLIS_RRC = 1 -> BLIS_CRC = 5 + ( stor3_t )6, // BLIS_RCR = 2 -> BLIS_CCR = 6 + ( stor3_t )4, // BLIS_RCC = 3 -> BLIS_CRR = 4 + ( stor3_t )3, // BLIS_CRR = 4 -> BLIS_RCC = 3 + ( stor3_t )1, // BLIS_CRC = 5 -> BLIS_RRC = 1 + ( stor3_t )2, // BLIS_CCR = 6 -> BLIS_RCR = 2 + ( stor3_t )0, // BLIS_CCC = 7 -> BLIS_RRR = 0 + }; + + return map[id]; +#else + return ( ( id & 0x4 ) ^ 0x4 ) | // flip c bit + ( ( ( id & 0x1 ) ^ 0x1 ) << 1 ) | // flip b bit and move to a position + ( ( ( id & 0x2 ) ^ 0x2 ) >> 1 ); // flip a bit and move to b position +#endif +} + +static stor3_t bli_stor3_transa( stor3_t id ) +{ +#if 0 + stor3_t map[ BLIS_NUM_3OP_RC_COMBOS ] + = + { + ( stor3_t )1, // BLIS_RRR = 0 -> BLIS_RRC = 1 + ( stor3_t )0, // BLIS_RRC = 1 -> BLIS_RRR = 0 + ( stor3_t )3, // BLIS_RCR = 2 -> BLIS_RCC = 3 + ( stor3_t )2, // BLIS_RCC = 3 -> BLIS_RCR = 2 + ( stor3_t )5, // BLIS_CRR = 4 -> BLIS_CRC = 5 + ( stor3_t )4, // BLIS_CRC = 5 -> BLIS_CRR = 4 + ( stor3_t )7, // BLIS_CCR = 6 -> BLIS_CCC = 7 + ( stor3_t )6, // BLIS_CCC = 7 -> BLIS_CCR = 6 + }; + + return map[id]; +#else + return ( stor3_t )( id ^ 0x1 ); +#endif +} + +static stor3_t bli_stor3_transb( stor3_t id ) +{ +#if 0 + stor3_t map[ BLIS_NUM_3OP_RC_COMBOS ] + = + { + ( stor3_t )2, // BLIS_RRR = 0 -> BLIS_RCR = 2 + ( stor3_t )3, // BLIS_RRC = 1 -> BLIS_RCC = 3 + ( stor3_t )0, // BLIS_RCR = 2 -> BLIS_RRR = 0 + ( stor3_t )1, // BLIS_RCC = 3 -> BLIS_RRC = 1 + ( stor3_t )6, // BLIS_CRR = 4 -> BLIS_CCR = 6 + ( stor3_t )7, // BLIS_CRC = 5 -> BLIS_CCC = 7 + ( stor3_t )4, // BLIS_CCR = 6 -> BLIS_CRR = 4 + ( stor3_t )5, // BLIS_CCC = 7 -> BLIS_CRC = 5 + }; + + return map[id]; +#else + return ( stor3_t )( id ^ 0x2 ); +#endif +} + + // index-related @@ -938,30 +1053,19 @@ static guint_t bli_pack_schema_index( pack_t schema ) // Increment a pointer by an integer fraction: // p0 + (num/den) // where p0 is a pointer to a datatype of size sizeof_p0. -static void* bli_ptr_inc_by_frac( void* p0, siz_t sizeof_p0, dim_t num, dim_t den ) +static void_fp bli_ptr_inc_by_frac( void_fp p0, siz_t sizeof_p0, dim_t num, dim_t den ) { - return ( void* ) + return ( void_fp ) ( ( char* )p0 + ( ( num * ( dim_t )sizeof_p0 ) / den ) ); } -static bool_t bli_is_null( void* p ) -{ - return ( bool_t ) - ( p == NULL ); -} - -static bool_t bli_is_nonnull( void* p ) -{ - return ( bool_t ) - ( p != NULL ); -} // Set dimensions, increments, effective uplo/diagoff, etc for ONE matrix // argument. static -void bli_set_dims_incs_uplo_1m +void bli_set_dims_incs_uplo_1m ( doff_t diagoffa, diag_t diaga, uplo_t uploa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1056,7 +1160,7 @@ void bli_set_dims_incs_uplo_1m // argument (without column-wise stride optimization). static -void bli_set_dims_incs_uplo_1m_noswap +void bli_set_dims_incs_uplo_1m_noswap ( doff_t diagoffa, diag_t diaga, uplo_t uploa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1142,7 +1246,7 @@ void bli_set_dims_incs_uplo_1m_noswap // Set dimensions and increments for TWO matrix arguments.
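/*
 * Aside (illustrative check, not part of the patch): the stor3_t encoding
 * computed by bli_stor3_from_strides() above is
 * id = 4*c_is_col + 2*a_is_col + 1*b_is_col, which lines up with the enum
 * ordering (BLIS_RRR = 0 ... BLIS_CCC = 7) and with the transpose map,
 * e.g. bli_stor3_trans() sends BLIS_RRC = 1 to BLIS_CRC = 5. A tiny
 * hypothetical self-check:
 */
#include <assert.h>

static int stor3_id_demo( int c_is_col, int a_is_col, int b_is_col )
{
    return 4 * c_is_col + 2 * a_is_col + 1 * b_is_col;
}

static void stor3_demo( void )
{
    assert( stor3_id_demo( 0, 0, 0 ) == 0 );  /* row-stored C, A, B -> BLIS_RRR */
    assert( stor3_id_demo( 1, 0, 1 ) == 5 );  /* col-stored C and B -> BLIS_CRC */
    assert( stor3_id_demo( 1, 1, 1 ) == 7 );  /* col-stored C, A, B -> BLIS_CCC */
}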
static -void bli_set_dims_incs_2m +void bli_set_dims_incs_2m ( trans_t transa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1178,7 +1282,7 @@ void bli_set_dims_incs_2m // arguments. static -void bli_set_dims_incs_uplo_2m +void bli_set_dims_incs_uplo_2m ( doff_t diagoffa, diag_t diaga, trans_t transa, uplo_t uploa, dim_t m, dim_t n, inc_t rs_a, inc_t cs_a, @@ -1286,7 +1390,7 @@ void bli_set_dims_incs_uplo_2m // on the diagonal. static -void bli_set_dims_incs_1d +void bli_set_dims_incs_1d ( doff_t diagoffx, dim_t m, dim_t n, inc_t rs_x, inc_t cs_x, @@ -1310,7 +1414,7 @@ void bli_set_dims_incs_1d // Set dimensions, increments, etc for TWO matrix arguments when operating // on diagonals. static -void bli_set_dims_incs_2d +void bli_set_dims_incs_2d ( doff_t diagoffx, trans_t transx, dim_t m, dim_t n, inc_t rs_x, inc_t cs_x, diff --git a/frame/include/bli_system.h b/frame/include/bli_system.h index 084999ab5..d91df6803 100644 --- a/frame/include/bli_system.h +++ b/frame/include/bli_system.h @@ -36,6 +36,12 @@ #ifndef BLIS_SYSTEM_H #define BLIS_SYSTEM_H +// NOTE: If not yet defined, we define _POSIX_C_SOURCE to make sure that +// various parts of POSIX are defined and made available. +#ifndef _POSIX_C_SOURCE +#define _POSIX_C_SOURCE 200809L +#endif + #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -111,7 +117,8 @@ #elif BLIS_OS_OSX #include <mach/mach_time.h> #else - #include <sys/time.h> + //#include <sys/time.h> + #include <time.h> #endif diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 2b778e663..ea06a11ec 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -197,6 +197,16 @@ typedef double f77_double; typedef scomplex f77_scomplex; typedef dcomplex f77_dcomplex; +// -- Void function pointer types -- + +// Note: This type should be used in any situation where the address of a +// *function* will be conveyed or stored prior to it being typecast back +// to the correct function type. It does not need to be used when conveying +// or storing the address of *data* (such as an array of float or double).
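/*
 * Aside (hypothetical sketch, not part of the patch): the usage pattern the
 * note above prescribes for the void_fp type defined just below. A function
 * address is conveyed through void_fp and cast back to its true
 * function-pointer type before the call. All demo_* names are illustrative.
 */
typedef void* demo_void_fp;                        /* stands in for void_fp     */
typedef int (*demo_add_ft)( int, int );            /* the real pointer type     */

static int demo_add( int a, int b ) { return a + b; }

static int demo_call_via_void_fp( void )
{
    demo_void_fp fp = ( demo_void_fp )demo_add;    /* store the address         */
    demo_add_ft  f  = ( demo_add_ft )fp;           /* cast back before calling  */
    return f( 2, 3 );                              /* returns 5                 */
}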
+ +//typedef void (*void_fp)( void ); +typedef void* void_fp; + // // -- BLIS info bit field offsets ---------------------------------------------- @@ -801,6 +811,80 @@ typedef enum #define BLIS_NUM_UKR_IMPL_TYPES 4 +#if 0 +typedef enum +{ + // RV = row-stored, contiguous vector-loading + // RG = row-stored, non-contiguous gather-loading + // CV = column-stored, contiguous vector-loading + // CG = column-stored, non-contiguous gather-loading + + // RD = row-stored, dot-based + // CD = col-stored, dot-based + + // RC = row-stored, column-times-column + // CR = column-stored, row-times-row + + // GX = general-stored generic implementation + + BLIS_GEMMSUP_RV_UKR = 0, + BLIS_GEMMSUP_RG_UKR, + BLIS_GEMMSUP_CV_UKR, + BLIS_GEMMSUP_CG_UKR, + + BLIS_GEMMSUP_RD_UKR, + BLIS_GEMMSUP_CD_UKR, + + BLIS_GEMMSUP_RC_UKR, + BLIS_GEMMSUP_CR_UKR, + + BLIS_GEMMSUP_GX_UKR, +} l3sup_t; + +#define BLIS_NUM_LEVEL3_SUP_UKRS 9 +#endif + + +typedef enum +{ + // 3-operand storage combinations + BLIS_RRR = 0, + BLIS_RRC, // 1 + BLIS_RCR, // 2 + BLIS_RCC, // 3 + BLIS_CRR, // 4 + BLIS_CRC, // 5 + BLIS_CCR, // 6 + BLIS_CCC, // 7 + BLIS_XXX, // 8 + +#if 0 + BLIS_RRG, + BLIS_RCG, + BLIS_RGR, + BLIS_RGC, + BLIS_RGG, + BLIS_CRG, + BLIS_CCG, + BLIS_CGR, + BLIS_CGC, + BLIS_CGG, + BLIS_GRR, + BLIS_GRC, + BLIS_GRG, + BLIS_GCR, + BLIS_GCC, + BLIS_GCG, + BLIS_GGR, + BLIS_GGC, + BLIS_GGG, +#endif +} stor3_t; + +#define BLIS_NUM_3OP_RC_COMBOS 9 +//#define BLIS_NUM_3OP_RCG_COMBOS 27 + + #if 0 typedef enum { @@ -863,8 +947,10 @@ typedef enum BLIS_MC, BLIS_KC, BLIS_NC, + BLIS_M2, // level-2 blocksize in m dimension BLIS_N2, // level-2 blocksize in n dimension + BLIS_AF, // level-1f axpyf fusing factor BLIS_DF, // level-1f dotxf fusing factor BLIS_XF, // level-1f dotxaxpyf fusing factor @@ -875,6 +961,19 @@ typedef enum #define BLIS_NUM_BLKSZS 11 +// -- Threshold ID type -- + +typedef enum +{ + BLIS_MT = 0, // level-3 small/unpacked matrix threshold in m dimension + BLIS_NT, // level-3 small/unpacked matrix threshold in n dimension + BLIS_KT // level-3 small/unpacked matrix threshold in k dimension + +} threshid_t; + +#define BLIS_NUM_THRESH 3 + + // -- Architecture ID type -- // NOTE: This typedef enum must be kept up-to-date with the arch_t @@ -1015,7 +1114,7 @@ struct cntl_s // Basic fields (usually required). opid_t family; bszid_t bszid; - void* var_func; + void_fp var_func; struct cntl_s* sub_prenode; struct cntl_s* sub_node; @@ -1048,7 +1147,7 @@ typedef struct blksz_s typedef struct func_s { // Kernel function address. - void* ptr[BLIS_NUM_FP_TYPES]; + void_fp ptr[BLIS_NUM_FP_TYPES]; } func_t; @@ -1139,6 +1238,71 @@ typedef struct obj_s dim_t n_panel; // n dimension of a "full" panel } obj_t; +// Pre-initializors. Things that must be set afterwards: +// - root object pointer +// - info bitfields: dt, target_dt, exec_dt, comp_dt +// - info2 bitfields: scalar_dt +// - elem_size +// - dims, strides +// - buffer +// - internal scalar buffer (must always set imaginary component) + +#define BLIS_OBJECT_INITIALIZER \ +{ \ + .root = NULL, \ +\ + .off = { 0, 0 }, \ + .dim = { 0, 0 }, \ + .diag_off = 0, \ +\ + .info = 0x0 | BLIS_BITVAL_DENSE | \ + BLIS_BITVAL_GENERAL, \ + .info2 = 0x0, \ + .elem_size = sizeof( float ), /* this is changed later. 
*/ \ +\ + .buffer = NULL, \ + .rs = 0, \ + .cs = 0, \ + .is = 1, \ +\ + .scalar = { 0.0, 0.0 }, \ +\ + .m_padded = 0, \ + .n_padded = 0, \ + .ps = 0, \ + .pd = 0, \ + .m_panel = 0, \ + .n_panel = 0 \ +} + +#define BLIS_OBJECT_INITIALIZER_1X1 \ +{ \ + .root = NULL, \ +\ + .off = { 0, 0 }, \ + .dim = { 1, 1 }, \ + .diag_off = 0, \ +\ + .info = 0x0 | BLIS_BITVAL_DENSE | \ + BLIS_BITVAL_GENERAL, \ + .info2 = 0x0, \ + .elem_size = sizeof( float ), /* this is changed later. */ \ +\ + .buffer = NULL, \ + .rs = 0, \ + .cs = 0, \ + .is = 1, \ +\ + .scalar = { 0.0, 0.0 }, \ +\ + .m_padded = 0, \ + .n_padded = 0, \ + .ps = 0, \ + .pd = 0, \ + .m_panel = 0, \ + .n_panel = 0 \ +} + // Define these macros here since they must be updated if contents of // obj_t changes. @@ -1205,6 +1369,39 @@ static void bli_obj_init_subpart_from( obj_t* a, obj_t* b ) b->n_panel = a->n_panel; } +// Initializors for global scalar constants. +// NOTE: These must remain cpp macros since they are initializor +// expressions, not functions. + +#define bli_obj_init_const( buffer0 ) \ +{ \ + .root = NULL, \ +\ + .off = { 0, 0 }, \ + .dim = { 1, 1 }, \ + .diag_off = 0, \ +\ + .info = 0x0 | BLIS_BITVAL_CONST_TYPE | \ + BLIS_BITVAL_DENSE | \ + BLIS_BITVAL_GENERAL, \ + .info2 = 0x0, \ + .elem_size = sizeof( constdata_t ), \ +\ + .buffer = buffer0, \ + .rs = 1, \ + .cs = 1, \ + .is = 1 \ +} + +#define bli_obj_init_constdata( val ) \ +{ \ + .s = ( float )val, \ + .d = ( double )val, \ + .c = { .real = ( float )val, .imag = 0.0f }, \ + .z = { .real = ( double )val, .imag = 0.0 }, \ + .i = ( gint_t )val, \ +} + // -- Context type -- @@ -1217,6 +1414,12 @@ typedef struct cntx_s func_t l3_nat_ukrs[ BLIS_NUM_LEVEL3_UKRS ]; mbool_t l3_nat_ukrs_prefs[ BLIS_NUM_LEVEL3_UKRS ]; + blksz_t l3_sup_thresh[ BLIS_NUM_THRESH ]; + void* l3_sup_handlers[ BLIS_NUM_LEVEL3_OPS ]; + blksz_t l3_sup_blkszs[ BLIS_NUM_BLKSZS ]; + func_t l3_sup_kers[ BLIS_NUM_3OP_RC_COMBOS ]; + mbool_t l3_sup_kers_prefs[ BLIS_NUM_3OP_RC_COMBOS ]; + func_t l1f_kers[ BLIS_NUM_LEVEL1F_KERS ]; func_t l1v_kers[ BLIS_NUM_LEVEL1V_KERS ]; @@ -1247,6 +1450,9 @@ typedef struct rntm_s // The packing block allocator, which is attached in the l3 thread decorator. membrk_t* membrk; + // A switch to enable/disable small/unpacked matrix handling in level-3 ops. + bool_t l3_sup; + } rntm_t; diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index d329a2c3a..eca0b6959 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -643,6 +644,7 @@ #define XOR(_0, _1) INSTR_(xor, _0, _1) #define ADD(_0, _1) INSTR_(add, _0, _1) #define SUB(_0, _1) INSTR_(sub, _0, _1) +#define IMUL(_0, _1) INSTR_(imul, _0, _1) #define SAL(...) INSTR_(sal, __VA_ARGS__) #define SAR(...) INSTR_(sar, __VA_ARGS__) #define SHLX(_0, _1, _2) INSTR_(shlx, _0, _1, _2) @@ -656,6 +658,7 @@ #define xor(_0, _1) XOR(_0, _1) #define add(_0, _1) ADD(_0, _1) #define sub(_0, _1) SUB(_0, _1) +#define imul(_0, _1) IMUL(_0, _1) #define sal(...) SAL(__VA_ARGS__) #define sar(...) 
SAR(__VA_ARGS__) #define shlx(_0, _1, _2) SHLX(_0, _1, _2) @@ -780,13 +783,13 @@ #define VPBROADCASTQ(_0, _1) INSTR_(vpbroadcastq, _0, _1) #define VBROADCASTF128(_0, _1) INSTR_(vbroadcastf128, _0, _1) #define VBROADCASTF64X4(_0, _1) INSTR_(vbroadcastf64x4, _0, _1) -#define VGATHERDPS(_0, _1) INSTR_(vgatherdps, _0, _1) +#define VGATHERDPS(...) INSTR_(vgatherdps, __VA_ARGS__) #define VSCATTERDPS(_0, _1) INSTR_(vscatterdps, _0, _1) -#define VGATHERDPD(_0, _1) INSTR_(vgatherdpd, _0, _1) +#define VGATHERDPD(...) INSTR_(vgatherdpd, __VA_ARGS__) #define VSCATTERDPD(_0, _1) INSTR_(vscatterdpd, _0, _1) -#define VGATHERQPS(_0, _1) INSTR_(vgatherqps, _0, _1) +#define VGATHERQPS(...) INSTR_(vgatherqps, __VA_ARGS__) #define VSCATTERQPS(_0, _1) INSTR_(vscatterqps, _0, _1) -#define VGATHERQPD(_0, _1) INSTR_(vgatherqpd, _0, _1) +#define VGATHERQPD(...) INSTR_(vgatherqpd, __VA_ARGS__) #define VSCATTERQPD(_0, _1) INSTR_(vscatterqpd, _0, _1) #define vmovddup(_0, _1) VMOVDDUP(_0, _1) @@ -809,19 +812,41 @@ #define vmovdqa64(_0, _1) VMOVDQA64(_0, _1) #define vbroadcastss(_0, _1) VBROADCASTSS(_0, _1) #define vbroadcastsd(_0, _1) VBROADCASTSD(_0, _1) -#define vpbraodcastd(_0, _1) VPBROADCASTD(_0, _1) +#define vpbroadcastd(_0, _1) VPBROADCASTD(_0, _1) #define vpbroadcastq(_0, _1) VPBROADCASTQ(_0, _1) #define vbroadcastf128(_0, _1) VBROADCASTF128(_0, _1) #define vbroadcastf64x4(_0, _1) VBROADCASTF64X4(_0, _1) -#define vgatherdps(_0, _1) VGATHERDPS(_0, _1) +#define vgatherdps(...) VGATHERDPS(__VA_ARGS__) #define vscatterdps(_0, _1) VSCATTERDPS(_0, _1) -#define vgatherdpd(_0, _1) VGATHERDPD(_0, _1) +#define vgatherdpd(...) VGATHERDPD(__VA_ARGS__) #define vscatterdpd(_0, _1) VSCATTERDPD(_0, _1) -#define vgatherqps(_0, _1) VGATHERQPS(_0, _1) +#define vgatherqps(...) VGATHERQPS(__VA_ARGS__) #define vscatterqps(_0, _1) VSCATTERQPS(_0, _1) -#define vgatherqpd(_0, _1) VGATHERQPD(_0, _1) +#define vgatherqpd(...) 
VGATHERQPD(__VA_ARGS__) #define vscatterqpd(_0, _1) VSCATTERQPD(_0, _1) +// Vector comparisons + +#define VPCMPEQB(_0, _1, _2) INSTR_(vpcmpeqb, _0, _1, _2) +#define VPCMPEQW(_0, _1, _2) INSTR_(vpcmpeqw, _0, _1, _2) +#define VPCMPEQD(_0, _1, _2) INSTR_(vpcmpeqd, _0, _1, _2) + +#define vpcmpeqb(_0, _1, _2) VPCMPEQB(_0, _1, _2) +#define vpcmpeqw(_0, _1, _2) VPCMPEQW(_0, _1, _2) +#define vpcmpeqd(_0, _1, _2) VPCMPEQD(_0, _1, _2) + +// Vector integer math + +#define VPADDB(_0, _1, _2) INSTR_(vpaddb, _0, _1, _2) +#define VPADDW(_0, _1, _2) INSTR_(vpaddw, _0, _1, _2) +#define VPADDD(_0, _1, _2) INSTR_(vpaddd, _0, _1, _2) +#define VPADDQ(_0, _1, _2) INSTR_(vpaddq, _0, _1, _2) + +#define vpaddb(_0, _1, _2) VPADDB(_0, _1, _2) +#define vpaddw(_0, _1, _2) VPADDW(_0, _1, _2) +#define vpaddd(_0, _1, _2) VPADDD(_0, _1, _2) +#define vpaddq(_0, _1, _2) VPADDQ(_0, _1, _2) + // Vector math #define ADDPS(_0, _1) INSTR_(addps, _0, _1) @@ -852,6 +877,8 @@ #define VADDSUBPS(_0, _1, _2) INSTR_(vaddsubps, _0, _1, _2) #define VADDSUBPD(_0, _1, _2) INSTR_(vaddsubpd, _0, _1, _2) +#define VHADDPD(_0, _1, _2) INSTR_(vhaddpd, _0, _1, _2) +#define VHADDPS(_0, _1, _2) INSTR_(vhaddps, _0, _1, _2) #define VUCOMISS(_0, _1) INSTR_(vucomiss, _0, _1) #define VUCOMISD(_0, _1) INSTR_(vucomisd, _0, _1) #define VCOMISS(_0, _1) INSTR_(vcomiss, _0, _1) @@ -974,6 +1001,8 @@ #define vaddsubps(_0, _1, _2) VADDSUBPS(_0, _1, _2) #define vaddsubpd(_0, _1, _2) VADDSUBPD(_0, _1, _2) +#define vhaddpd(_0, _1, _2) VHADDPD(_0, _1, _2) +#define vhaddps(_0, _1, _2) VHADDPS(_0, _1, _2) #define vucomiss(_0, _1) VUCOMISS(_0, _1) #define vucomisd(_0, _1) VUCOMISD(_0, _1) #define vcomiss(_0, _1) VCOMISS(_0, _1) diff --git a/frame/include/blis.h b/frame/include/blis.h index 95f9bc5b0..00789c231 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -88,6 +88,7 @@ extern "C" { #include "bli_l1f_ker_prot.h" #include "bli_l1m_ker_prot.h" #include "bli_l3_ukr_prot.h" +#include "bli_l3_sup_ker_prot.h" #include "bli_arch_config_pre.h" #include "bli_arch_config.h" diff --git a/frame/ind/bli_ind.c b/frame/ind/bli_ind.c index 41419c6ce..09393e611 100644 --- a/frame/ind/bli_ind.c +++ b/frame/ind/bli_ind.c @@ -168,9 +168,9 @@ bool_t bli_ind_oper_has_avail( opid_t oper, num_t dt ) } #endif -void* bli_ind_oper_get_avail( opid_t oper, num_t dt ) +void_fp bli_ind_oper_get_avail( opid_t oper, num_t dt ) { - void* func_p; + void_fp func_p; if ( bli_opid_is_level3( oper ) ) { diff --git a/frame/ind/bli_ind.h b/frame/ind/bli_ind.h index 9618acd8e..645829d8b 100644 --- a/frame/ind/bli_ind.h +++ b/frame/ind/bli_ind.h @@ -51,21 +51,21 @@ void bli_ind_init( void ); void bli_ind_finalize( void ); -void bli_ind_enable( ind_t method ); -void bli_ind_disable( ind_t method ); -void bli_ind_disable_all( void ); +BLIS_EXPORT_BLIS void bli_ind_enable( ind_t method ); +BLIS_EXPORT_BLIS void bli_ind_disable( ind_t method ); +BLIS_EXPORT_BLIS void bli_ind_disable_all( void ); -void bli_ind_enable_dt( ind_t method, num_t dt ); -void bli_ind_disable_dt( ind_t method, num_t dt ); -void bli_ind_disable_all_dt( num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_enable_dt( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_disable_dt( ind_t method, num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_disable_all_dt( num_t dt ); -void bli_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); +BLIS_EXPORT_BLIS void bli_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); -bool_t bli_ind_oper_is_impl( opid_t oper, ind_t method ); +BLIS_EXPORT_BLIS bool_t bli_ind_oper_is_impl( opid_t 
oper, ind_t method ); //bool_t bli_ind_oper_has_avail( opid_t oper, num_t dt ); -void* bli_ind_oper_get_avail( opid_t oper, num_t dt ); -ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ); -char* bli_ind_oper_get_avail_impl_string( opid_t oper, num_t dt ); +BLIS_EXPORT_BLIS void_fp bli_ind_oper_get_avail( opid_t oper, num_t dt ); +BLIS_EXPORT_BLIS ind_t bli_ind_oper_find_avail( opid_t oper, num_t dt ); +BLIS_EXPORT_BLIS char* bli_ind_oper_get_avail_impl_string( opid_t oper, num_t dt ); char* bli_ind_get_impl_string( ind_t method ); num_t bli_ind_map_cdt_to_index( num_t dt ); diff --git a/frame/ind/bli_l3_ind.c b/frame/ind/bli_l3_ind.c index 2cb75cae8..02a7668b8 100644 --- a/frame/ind/bli_l3_ind.c +++ b/frame/ind/bli_l3_ind.c @@ -35,7 +35,7 @@ #include "blis.h" -static void* bli_l3_ind_oper_fp[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = +static void_fp bli_l3_ind_oper_fp[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = { /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm */ /* 3mh */ { bli_gemm3mh, bli_hemm3mh, bli_herk3mh, bli_her2k3mh, bli_symm3mh, @@ -87,7 +87,7 @@ bool_t bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] = #undef GENFUNC #define GENFUNC( opname, optype ) \ \ -void* PASTEMAC(opname,ind_get_avail)( num_t dt ) \ +void_fp PASTEMAC(opname,ind_get_avail)( num_t dt ) \ { \ return bli_ind_oper_get_avail( optype, dt ); \ } @@ -114,8 +114,8 @@ GENFUNC( trsm, BLIS_TRSM ) #if 0 bool_t bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ) { - void* func; - bool_t stat; + void_fp func; + bool_t stat; // If the datatype is real, it is never available. if ( !bli_is_complex( dt ) ) return FALSE; @@ -146,7 +146,7 @@ ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ) // current operation and datatype. for ( im = 0; im < BLIS_NUM_IND_METHODS; ++im ) { - void* func = bli_l3_ind_oper_get_func( oper, im ); + void_fp func = bli_l3_ind_oper_get_func( oper, im ); bool_t stat = bli_l3_ind_oper_get_enable( oper, im, dt ); if ( func != NULL && @@ -256,7 +256,7 @@ bool_t bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ) // ----------------------------------------------------------------------------- -void* bli_l3_ind_oper_get_func( opid_t oper, ind_t method ) +void_fp bli_l3_ind_oper_get_func( opid_t oper, ind_t method ) { return bli_l3_ind_oper_fp[ method ][ oper ]; } diff --git a/frame/ind/bli_l3_ind.h b/frame/ind/bli_l3_ind.h index 0b9f49ec3..216eaedcb 100644 --- a/frame/ind/bli_l3_ind.h +++ b/frame/ind/bli_l3_ind.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void* PASTEMAC(opname,ind_get_avail)( num_t dt ); +void_fp PASTEMAC(opname,ind_get_avail)( num_t dt ); /*bool_t PASTEMAC(opname,ind_has_avail)( num_t dt ); */ GENPROT( gemm ) @@ -58,17 +58,17 @@ GENPROT( trsm ) //bool_t bli_l3_ind_oper_is_avail( opid_t oper, ind_t method, num_t dt ); -ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ); +ind_t bli_l3_ind_oper_find_avail( opid_t oper, num_t dt ); -void bli_l3_ind_set_enable_dt( ind_t method, num_t dt, bool_t status ); +void bli_l3_ind_set_enable_dt( ind_t method, num_t dt, bool_t status ); -void bli_l3_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); -void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool_t status ); +void bli_l3_ind_oper_enable_only( opid_t oper, ind_t method, num_t dt ); +void bli_l3_ind_oper_set_enable_all( opid_t oper, num_t dt, bool_t status ); -void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool_t status ); -bool_t bli_l3_ind_oper_get_enable( 
opid_t oper, ind_t method, num_t dt ); +void bli_l3_ind_oper_set_enable( opid_t oper, ind_t method, num_t dt, bool_t status ); +bool_t bli_l3_ind_oper_get_enable( opid_t oper, ind_t method, num_t dt ); -void* bli_l3_ind_oper_get_func( opid_t oper, ind_t method ); +void_fp bli_l3_ind_oper_get_func( opid_t oper, ind_t method ); #endif diff --git a/frame/ind/cntx/bli_cntx_ind_stage.c b/frame/ind/cntx/bli_cntx_ind_stage.c index 671be681d..b5c15d5d7 100644 --- a/frame/ind/cntx/bli_cntx_ind_stage.c +++ b/frame/ind/cntx/bli_cntx_ind_stage.c @@ -36,7 +36,7 @@ typedef void (*cntx_stage_ft)( dim_t stage, cntx_t* cntx ); -static void* bli_cntx_ind_stage_fp[BLIS_NUM_IND_METHODS] = +static void_fp bli_cntx_ind_stage_fp[BLIS_NUM_IND_METHODS] = { /* 3mh */ bli_cntx_3mh_stage, /* 3m1 */ bli_cntx_3m1_stage, diff --git a/frame/ind/cntx/bli_cntx_ind_stage.h b/frame/ind/cntx/bli_cntx_ind_stage.h index affaa84f4..124421665 100644 --- a/frame/ind/cntx/bli_cntx_ind_stage.h +++ b/frame/ind/cntx/bli_cntx_ind_stage.h @@ -32,13 +32,13 @@ */ -void bli_cntx_ind_stage( ind_t method, dim_t stage, cntx_t* cntx ); +void bli_cntx_ind_stage( ind_t method, dim_t stage, cntx_t* cntx ); -void bli_cntx_3mh_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_3m1_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4mh_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4mb_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_4m1_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_1m_stage( dim_t stage, cntx_t* cntx ); -void bli_cntx_nat_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_3mh_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_3m1_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_4mh_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_4mb_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_4m1_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_1m_stage( dim_t stage, cntx_t* cntx ); +void bli_cntx_nat_stage( dim_t stage, cntx_t* cntx ); diff --git a/frame/ind/oapi/bli_l3_ind_oapi.h b/frame/ind/oapi/bli_l3_ind_oapi.h index c8370a2a5..d4767925d 100644 --- a/frame/ind/oapi/bli_l3_ind_oapi.h +++ b/frame/ind/oapi/bli_l3_ind_oapi.h @@ -40,16 +40,16 @@ #undef GENPROT #define GENPROT( imeth ) \ \ -void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(trmm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(trsm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); +BLIS_EXPORT_BLIS void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(hemm,imeth) ( 
side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(trmm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(trsm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); GENPROT( nat ) GENPROT( ind ) @@ -65,14 +65,14 @@ GENPROT( 1m ) #undef GENPROT_NO2OP #define GENPROT_NO2OP( imeth ) \ \ -void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ -void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); +BLIS_EXPORT_BLIS void PASTEMAC(gemm,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(hemm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(herk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(her2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(symm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); GENPROT_NO2OP( 3mh ) GENPROT_NO2OP( 4mh ) @@ -88,7 +88,7 @@ GENPROT_NO2OP( 4mb ) #undef GENPROT #define GENPROT( imeth, alg ) \ \ -void PASTEMAC2(gemm,imeth,alg) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, 
obj_t* c ); \ +BLIS_EXPORT_BLIS void PASTEMAC2(gemm,imeth,alg) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c ); \ */ //GENPROT( 1m, bp ) diff --git a/frame/thread/bli_pthread.h b/frame/thread/bli_pthread.h index 337eadc33..56ede89b5 100644 --- a/frame/thread/bli_pthread.h +++ b/frame/thread/bli_pthread.h @@ -41,39 +41,82 @@ // This branch defines a pthread-like API, bli_pthread_*(), and implements it // in terms of Windows API calls. +// -- pthread_mutex_*() -- + typedef SRWLOCK bli_pthread_mutex_t; typedef void bli_pthread_mutexattr_t; #define BLIS_PTHREAD_MUTEX_INITIALIZER SRWLOCK_INIT -int bli_pthread_mutex_init( bli_pthread_mutex_t* mutex, const bli_pthread_mutexattr_t *attr ); +BLIS_EXPORT_BLIS int bli_pthread_mutex_init + ( + bli_pthread_mutex_t* mutex, + const bli_pthread_mutexattr_t* attr + ); -int bli_pthread_mutex_destroy( bli_pthread_mutex_t* mutex ); +BLIS_EXPORT_BLIS int bli_pthread_mutex_destroy + ( + bli_pthread_mutex_t* mutex + ); -int bli_pthread_mutex_lock( bli_pthread_mutex_t* mutex ); +BLIS_EXPORT_BLIS int bli_pthread_mutex_lock + ( + bli_pthread_mutex_t* mutex + ); -int bli_pthread_mutex_trylock( bli_pthread_mutex_t* mutex ); +BLIS_EXPORT_BLIS int bli_pthread_mutex_trylock + ( + bli_pthread_mutex_t* mutex + ); -int bli_pthread_mutex_unlock( bli_pthread_mutex_t* mutex ); +BLIS_EXPORT_BLIS int bli_pthread_mutex_unlock + ( + bli_pthread_mutex_t* mutex + ); + +// -- pthread_once_*() -- typedef INIT_ONCE bli_pthread_once_t; #define BLIS_PTHREAD_ONCE_INIT INIT_ONCE_STATIC_INIT -void bli_pthread_once( bli_pthread_once_t* once, void (*init)( void ) ); +BLIS_EXPORT_BLIS void bli_pthread_once + ( + bli_pthread_once_t* once, + void (*init)(void) + ); + +// -- pthread_cond_*() -- typedef CONDITION_VARIABLE bli_pthread_cond_t; typedef void bli_pthread_condattr_t; #define BLIS_PTHREAD_COND_INITIALIZER CONDITION_VARIABLE_INIT -int bli_pthread_cond_init( bli_pthread_cond_t* cond, const bli_pthread_condattr_t* attr ); +BLIS_EXPORT_BLIS int bli_pthread_cond_init + ( + bli_pthread_cond_t* cond, + const bli_pthread_condattr_t* attr + ); -int bli_pthread_cond_destroy( bli_pthread_cond_t* cond ); +BLIS_EXPORT_BLIS int bli_pthread_cond_destroy + ( + bli_pthread_cond_t* cond + ); -int bli_pthread_cond_wait( bli_pthread_cond_t* cond, bli_pthread_mutex_t* mutex ); +BLIS_EXPORT_BLIS int bli_pthread_cond_wait + ( + bli_pthread_cond_t* cond, + bli_pthread_mutex_t* mutex + ); + +BLIS_EXPORT_BLIS int bli_pthread_cond_broadcast + ( + bli_pthread_cond_t* cond + ); + +// -- pthread_create(), pthread_join() -- -int bli_pthread_cond_broadcast( bli_pthread_cond_t* cond ); typedef struct { HANDLE handle; @@ -82,11 +125,21 @@ typedef struct typedef void bli_pthread_attr_t; -int bli_pthread_create( bli_pthread_t *thread, const bli_pthread_attr_t *attr, void* (*start_routine)( void* ), void *arg ); +BLIS_EXPORT_BLIS int bli_pthread_create + ( + bli_pthread_t* thread, + const bli_pthread_attr_t* attr, + void* (*start_routine)(void*), + void* arg + ); -int bli_pthread_join( bli_pthread_t thread, void **retval ); +BLIS_EXPORT_BLIS int bli_pthread_join + ( + bli_pthread_t thread, + void** retval + ); -// barrier-related definitions +// -- pthread_barrier_*() -- typedef void bli_pthread_barrierattr_t; @@ -98,11 +151,22 @@ typedef struct int tripCount; } bli_pthread_barrier_t; -int bli_pthread_barrier_init( bli_pthread_barrier_t *barrier, const bli_pthread_barrierattr_t *attr, unsigned int count ); +BLIS_EXPORT_BLIS int bli_pthread_barrier_init + ( + bli_pthread_barrier_t* barrier, + const 
bli_pthread_barrierattr_t* attr, + unsigned int count + ); -int bli_pthread_barrier_destroy( bli_pthread_barrier_t *barrier ); +BLIS_EXPORT_BLIS int bli_pthread_barrier_destroy + ( + bli_pthread_barrier_t* barrier + ); -int bli_pthread_barrier_wait( bli_pthread_barrier_t *barrier ); +BLIS_EXPORT_BLIS int bli_pthread_barrier_wait + ( + bli_pthread_barrier_t* barrier + ); #else // !defined(_MSC_VER) @@ -155,7 +219,7 @@ typedef pthread_barrierattr_t bli_pthread_barrierattr_t; // -- pthread_create(), pthread_join() -- -int bli_pthread_create +BLIS_EXPORT_BLIS int bli_pthread_create ( bli_pthread_t* thread, const bli_pthread_attr_t* attr, @@ -163,7 +227,7 @@ int bli_pthread_create void* arg ); -int bli_pthread_join +BLIS_EXPORT_BLIS int bli_pthread_join ( bli_pthread_t thread, void** retval @@ -171,59 +235,59 @@ int bli_pthread_join // -- pthread_mutex_*() -- -int bli_pthread_mutex_init +BLIS_EXPORT_BLIS int bli_pthread_mutex_init ( bli_pthread_mutex_t* mutex, const bli_pthread_mutexattr_t* attr ); -int bli_pthread_mutex_destroy +BLIS_EXPORT_BLIS int bli_pthread_mutex_destroy ( bli_pthread_mutex_t* mutex ); -int bli_pthread_mutex_lock +BLIS_EXPORT_BLIS int bli_pthread_mutex_lock ( bli_pthread_mutex_t* mutex ); -int bli_pthread_mutex_trylock +BLIS_EXPORT_BLIS int bli_pthread_mutex_trylock ( bli_pthread_mutex_t* mutex ); -int bli_pthread_mutex_unlock +BLIS_EXPORT_BLIS int bli_pthread_mutex_unlock ( bli_pthread_mutex_t* mutex ); // -- pthread_cond_*() -- -int bli_pthread_cond_init +BLIS_EXPORT_BLIS int bli_pthread_cond_init ( bli_pthread_cond_t* cond, const bli_pthread_condattr_t* attr ); -int bli_pthread_cond_destroy +BLIS_EXPORT_BLIS int bli_pthread_cond_destroy ( bli_pthread_cond_t* cond ); -int bli_pthread_cond_wait +BLIS_EXPORT_BLIS int bli_pthread_cond_wait ( bli_pthread_cond_t* cond, bli_pthread_mutex_t* mutex ); -int bli_pthread_cond_broadcast +BLIS_EXPORT_BLIS int bli_pthread_cond_broadcast ( bli_pthread_cond_t* cond ); // -- pthread_once_*() -- -void bli_pthread_once +BLIS_EXPORT_BLIS void bli_pthread_once ( bli_pthread_once_t* once, void (*init)(void) @@ -231,19 +295,19 @@ void bli_pthread_once // -- pthread_barrier_*() -- -int bli_pthread_barrier_init +BLIS_EXPORT_BLIS int bli_pthread_barrier_init ( bli_pthread_barrier_t* barrier, const bli_pthread_barrierattr_t* attr, unsigned int count ); -int bli_pthread_barrier_destroy +BLIS_EXPORT_BLIS int bli_pthread_barrier_destroy ( bli_pthread_barrier_t* barrier ); -int bli_pthread_barrier_wait +BLIS_EXPORT_BLIS int bli_pthread_barrier_wait ( bli_pthread_barrier_t* barrier ); diff --git a/frame/thread/bli_thrcomm_openmp.h b/frame/thread/bli_thrcomm_openmp.h index d655bd131..945d9a4b5 100644 --- a/frame/thread/bli_thrcomm_openmp.h +++ b/frame/thread/bli_thrcomm_openmp.h @@ -75,8 +75,8 @@ typedef struct thrcomm_s thrcomm_t; // Prototypes specific to tree barriers. 
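// (A tree barrier organizes threads into a small tree of sub-barriers:
// each leaf synchronizes up to 'arity' threads and then signals its
// parent, which typically scales better at high thread counts than
// having every thread contend on a single shared counter.)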
#ifdef BLIS_TREE_BARRIER barrier_t* bli_thrcomm_tree_barrier_create( int num_threads, int arity, barrier_t** leaves, int leaf_index ); -void bli_thrcomm_tree_barrier_free( barrier_t* barrier ); -void bli_thrcomm_tree_barrier( barrier_t* barack ); +void bli_thrcomm_tree_barrier_free( barrier_t* barrier ); +void bli_thrcomm_tree_barrier( barrier_t* barack ); #endif void bli_l3_thread_decorator_thread_check diff --git a/frame/thread/bli_thread.h b/frame/thread/bli_thread.h index d9bd3d241..d3ae01e9c 100644 --- a/frame/thread/bli_thread.h +++ b/frame/thread/bli_thread.h @@ -196,20 +196,20 @@ dim_t bli_ipow( dim_t base, dim_t power ); // ----------------------------------------------------------------------------- -dim_t bli_thread_get_env( const char* env, dim_t fallback ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_env( const char* env, dim_t fallback ); //void bli_thread_set_env( const char* env, dim_t value ); -dim_t bli_thread_get_jc_nt( void ); -dim_t bli_thread_get_pc_nt( void ); -dim_t bli_thread_get_ic_nt( void ); -dim_t bli_thread_get_jr_nt( void ); -dim_t bli_thread_get_ir_nt( void ); -dim_t bli_thread_get_num_threads( void ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_jc_nt( void ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_pc_nt( void ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_ic_nt( void ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_jr_nt( void ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_ir_nt( void ); +BLIS_EXPORT_BLIS dim_t bli_thread_get_num_threads( void ); -void bli_thread_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir ); -void bli_thread_set_num_threads( dim_t value ); +BLIS_EXPORT_BLIS void bli_thread_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir ); +BLIS_EXPORT_BLIS void bli_thread_set_num_threads( dim_t value ); -void bli_thread_init_rntm( rntm_t* rntm ); +BLIS_EXPORT_BLIS void bli_thread_init_rntm( rntm_t* rntm ); void bli_thread_init_rntm_from_env( rntm_t* rntm ); diff --git a/frame/util/bli_util_fpa.c b/frame/util/bli_util_fpa.c index b68f608eb..e46163e89 100644 --- a/frame/util/bli_util_fpa.c +++ b/frame/util/bli_util_fpa.c @@ -70,7 +70,12 @@ GENFRONT( sumsqv ) #undef GENFRONT #define GENFRONT( opname ) \ \ -GENARRAY_FPA( void*, opname ); \ +/* +GENARRAY_FPA( void_fp, opname ); \ +*/ \ +\ +GENARRAY_FPA( PASTECH(opname,_vft), \ + PASTECH0(opname) ); \ \ PASTECH(opname,_vft) \ PASTEMAC(opname,_qfp)( num_t dt ) \ diff --git a/frame/util/bli_util_oapi.c b/frame/util/bli_util_oapi.c index f9f9b4c93..128b1f92e 100644 --- a/frame/util/bli_util_oapi.c +++ b/frame/util/bli_util_oapi.c @@ -66,7 +66,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, asum ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -108,7 +108,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( a ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -152,7 +152,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, norm ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -201,7 +201,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, norm ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -257,7 +257,7 @@ void PASTEMAC(opname,EX_SUF) \ } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = \ PASTEMAC(opname,_qfp)( dt ); \ \ @@ -325,7 +325,7 @@ void PASTEMAC(opname,EX_SUF) \ } \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(opname,_vft) f = \ PASTEMAC(opname,_qfp)( dt ); \ \ @@ -401,7 +401,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -445,7 +445,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ @@ -492,7 +492,7 @@ void PASTEMAC(opname,EX_SUF) \ PASTEMAC(opname,_check)( x, scale, sumsq ); \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. 
*/ \ PASTECH2(opname,BLIS_TAPI_EX_SUF,_vft) f = \ PASTEMAC2(opname,BLIS_TAPI_EX_SUF,_qfp)( dt ); \ \ diff --git a/frame/util/bli_util_oapi.h b/frame/util/bli_util_oapi.h index b7952c822..1acce1606 100644 --- a/frame/util/bli_util_oapi.h +++ b/frame/util/bli_util_oapi.h @@ -40,7 +40,7 @@ #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* asum \ @@ -53,7 +53,7 @@ GENPROT( asumv ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* a \ BLIS_OAPI_EX_PARAMS \ @@ -67,7 +67,7 @@ GENPROT( mktrim ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* norm \ @@ -82,7 +82,7 @@ GENPROT( normiv ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* norm \ @@ -97,7 +97,7 @@ GENPROT( normim ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ FILE* file, \ char* s1, \ @@ -114,7 +114,7 @@ GENPROT( fprintm ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ char* s1, \ obj_t* x, \ @@ -130,7 +130,7 @@ GENPROT( printm ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x \ BLIS_OAPI_EX_PARAMS \ @@ -143,7 +143,7 @@ GENPROT( randnv ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x \ BLIS_OAPI_EX_PARAMS \ @@ -156,7 +156,7 @@ GENPROT( randnm ) #undef GENPROT #define GENPROT( opname ) \ \ -void PASTEMAC(opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ( \ obj_t* x, \ obj_t* scale, \ diff --git a/frame/util/bli_util_tapi.h b/frame/util/bli_util_tapi.h index f48acbd11..c35702cbc 100644 --- a/frame/util/bli_util_tapi.h +++ b/frame/util/bli_util_tapi.h @@ -40,7 +40,7 @@ #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx, \ @@ -54,7 +54,7 @@ INSERT_GENTPROTR_BASIC0( asumv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ uplo_t uploa, \ dim_t m, \ @@ -70,7 +70,7 @@ INSERT_GENTPROT_BASIC0( mktrim ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx, \ @@ -86,7 +86,7 @@ INSERT_GENTPROTR_BASIC0( normiv ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ diag_t diagx, \ @@ -106,7 +106,7 @@ INSERT_GENTPROTR_BASIC0( normim ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ char* s1, \ dim_t n, \ @@ -121,7 +121,7 @@ INSERT_GENTPROT_BASIC0_I( printv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void 
PASTEMAC2(ch,opname,EX_SUF) \ ( \ char* s1, \ dim_t m, \ @@ -137,7 +137,7 @@ INSERT_GENTPROT_BASIC0_I( printm ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx \ @@ -151,7 +151,7 @@ INSERT_GENTPROT_BASIC0( randnv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ doff_t diagoffx, \ uplo_t uplox, \ @@ -168,7 +168,7 @@ INSERT_GENTPROT_BASIC0( randnm ) #undef GENTPROTR #define GENTPROTR( ctype, ctype_r, ch, chr, opname ) \ \ -void PASTEMAC2(ch,opname,EX_SUF) \ +BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ( \ dim_t n, \ ctype* x, inc_t incx, \ diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index 32197819a..a12cdbd5c 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -466,7 +466,7 @@ void PASTEMAC(ch,varname) \ /* If the absolute value of the current element exceeds that of the previous largest, save it and its index. If NaN is encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This + value that was larger than any previously seen. This behavior mimics that of LAPACK's ?lange(). */ \ if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ { \ diff --git a/frame/util/bli_util_unb_var1.h b/frame/util/bli_util_unb_var1.h index 6f2a3fc85..3fb517eec 100644 --- a/frame/util/bli_util_unb_var1.h +++ b/frame/util/bli_util_unb_var1.h @@ -110,7 +110,7 @@ INSERT_GENTPROTR_BASIC0( normim_unb_var1 ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ FILE* file, \ char* s1, \ @@ -126,7 +126,7 @@ INSERT_GENTPROT_BASIC0_I( fprintv ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ \ -void PASTEMAC(ch,opname) \ +BLIS_EXPORT_BLIS void PASTEMAC(ch,opname) \ ( \ FILE* file, \ char* s1, \ diff --git a/kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c similarity index 100% rename from kernels/armv8a/3/bli_gemm_armv8a_opt_4x4.c rename to kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c diff --git a/kernels/haswell/3/old/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/old/bli_gemm_haswell_asm_d6x8.c index adb194f1f..e5e5a74fd 100644 --- a/kernels/haswell/3/old/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/old/bli_gemm_haswell_asm_d6x8.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8.c new file mode 100644 index 000000000..1b80af8b7 --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -0,0 +1,4565 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+   rrc:
+     --------        ------        | | | | | | | |
+     --------        ------        | | | | | | | |
+     --------   +=   ------ ...    | | | | | | | |
+     --------        ------        | | | | | | | |
+     --------        ------            :
+     --------        ------            :
+
+   Assumptions:
+   - C is row-stored and B is column-stored;
+   - A is row-stored;
+   - m0 and n0 are at most MR and NR, respectively.
+   Therefore, this (r)ow-preferential microkernel is well-suited for
+   a dot-product-based accumulation that performs vector loads from
+   both A and B.
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
+
+#if 0
+// Define parameters and variables for edge case kernel map.
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{         /*   8    4    2    1  */
+/* 6 */  { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 },
+/* 3 */  { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref_3x1 },
+/* 2 */  { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 },
+/* 1 */  { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	uint64_t n_left = n0 % 8;
+
+	// First check whether this is an edge case in the n dimension. If so,
+	// dispatch other 6x?m kernels, as needed.
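+	// (The cascade below peels n_left greedily: a 6x4 kernel handles
+	// n_left >= 4, a 6x2 kernel handles a remaining n_left >= 2, and a
+	// final single column is computed as a matrix-vector product, since
+	// C(:,j) := beta*C(:,j) + alpha*A*b_j is exactly a dgemv.)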
+ if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4 + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2 + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
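+	// (ymm4..ymm15 hold the twelve dot-product accumulators for the
+	// current 3x4 block of C: ymm4/ymm7/ymm10/ymm13 belong to row 0,
+	// ymm5/ymm8/ymm11/ymm14 to row 1, and ymm6/ymm9/ymm12/ymm15 to
+	// row 2; each register carries four partial sums along k that are
+	// reduced horizontally at .DPOSTACCUM.)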
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), 
ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 3
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_b = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+	label(.DCONSIDKITER4)
+
+	mov(var(k_iter4), rsi)             // i = k_iter4;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+	                                   // considers k_left1 loop.
+	                                   // else, we prepare to enter k_iter4 loop.
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+#if 0
+	prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
+	prefetch(0, mem(rax, r8,  4, 0*8)) // prefetch rax + 4*cs_a
+	prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
+#endif
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_b = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rax       ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	vmovsd(mem(rax, r8, 2), xmm2)
+	add(imm(1*8), rax)                 // a += 1*cs_b = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
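+	// (Each ymm accumulator above still holds four partial sums along k.
+	// The vhaddpd/vextractf128/vaddpd sequences below collapse each
+	// register to a single dot product and pack the four results of a
+	// given row into one ymm register, so that the alpha/beta update of
+	// C can proceed with full-width vector arithmetic.)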
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
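+	// (Both jj passes (columns 0..3 and 4..7) and all m_iter row blocks
+	// are complete at this point; any m_left leftover rows are handled
+	// in C after the assembly block, under consider_edge_cases.)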
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. 
+#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) 
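+	// (Bottom of the unrolled main loop: iterations 0 through 3 above
+	// each consume four values of k, so one pass through .DLOOPKITER16
+	// advances sixteen k-steps, matching k_iter16 = k0 / 16.)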
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4)
+ vxorpd(ymm7, ymm7, ymm7)
+ vxorpd(ymm10, ymm10, ymm10)
+ vxorpd(ymm13, ymm13, ymm13)
+#endif
+
+
+ lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
+ imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
+ lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c;
+
+ lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
+ imul(r11, rsi) // rsi *= cs_b;
+ lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b;
+
+ lea(mem(r14), rax) // rax = a;
+
+
+#if 0
+ prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c
+#else
+ prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
+#endif
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+
+ add(imm(4), r15) // jj += 4;
+ cmp(imm(4), r15) // compare jj to 4
+ jle(.DLOOP3X4J) // if jj <= 4, jump to beginning
+ // of jj loop; otherwise, loop ends.
+
+
+
+ label(.DRETURN)
+
+
+
+ end_asm(
+ : // output operands (none)
+ : // input operands
+ [k_iter16] "m" (k_iter16),
+ [k_iter4] "m" (k_iter4),
+ [k_left1] "m" (k_left1),
+ [a] "m" (a),
+ [rs_a] "m" (rs_a),
+ [cs_a] "m" (cs_a),
+ [b] "m" (b),
+ [rs_b] "m" (rs_b),
+ [cs_b] "m" (cs_b),
+ [alpha] "m" (alpha),
+ [beta] "m" (beta),
+ [c] "m" (c),
+ [rs_c] "m" (rs_c),
+ [cs_c] "m" (cs_c)/*,
+ [a_next] "m" (a_next),
+ [b_next] "m" (b_next)*/
+ : // register clobber list
+ "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+ "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+ "xmm0", "xmm1", "xmm2", "xmm3",
+ "xmm4", "xmm5", "xmm6", "xmm7",
+ "xmm8", "xmm9", "xmm10", "xmm11",
+ "xmm12", "xmm13", "xmm14", "xmm15",
+ "memory"
+ )
+}
+
+void bli_dgemmsup_rd_haswell_asm_6x4
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ //void* a_next = bli_auxinfo_next_a( data );
+ //void* b_next = bli_auxinfo_next_b( data );
+
+ // Typecast local copies of integers in case dim_t and inc_t are a
+ // different size than is expected by load instructions.
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
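+ // Note on the dataflow of the loops below: ymm0/ymm1/ymm2 hold four
+ // k-values each from rows 0/1/2 of a, while ymm3 is reloaded with four
+ // k-values from each of the four columns of b. Each accumulator
+ // ymm4..ymm15 thus gathers partial products of one (row of a, column
+ // of b) dot product, which are summed horizontally at .DPOSTACCUM.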
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+#if 0
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
+ prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*rs_a
+#endif
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 2
+
+#if 0
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
+ prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*rs_a
+#endif
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ vmovsd(mem(rax, r8, 2), xmm2)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
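+ // A note on the horizontal reduction performed below: in this syntax,
+ // vhaddpd( ymm7, ymm4, ymm0 ) computes
+ // ymm0 = { ymm4[0]+ymm4[1], ymm7[0]+ymm7[1], ymm4[2]+ymm4[3], ymm7[2]+ymm7[3] },
+ // so folding in the upper 128 bits (vextractf128 + vaddpd) leaves
+ // { sum(ymm4), sum(ymm7) } in xmm0; vperm2f128 then packs two such
+ // pairs into a single ymm register.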
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
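+ // Since the assembly loop above processed m in blocks of 3 rows
+ // (m_iter = m0 / 3), at most m_left = 1 or 2 rows remain, and each
+ // case is dispatched to the correspondingly sized kernel below.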
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), ymm3, ymm5)
+ vmovupd(ymm5, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
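+ // The .DBETAZERO/.DROWSTORBZ path below stores the alpha-scaled
+ // result without ever loading c, so when beta == 0 any uninitialized
+ // contents of c (including NaNs) are never read or propagated.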
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
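+ // (The vmovsd loads in the loop above write one double into the low
+ // lane and zero bits 64..255 of the full ymm register, so the
+ // ymm-width fmas merely accumulate zeros into the unused lanes.)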
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b
+ mov(var(cs_b), r11) // load cs_b
+ //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+
+
+ mov(var(c), r12) // load address of c
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+
+
+ // r12 = rcx = c
+ // r14 = rax = a
+ // rdx = rbx = b
+ // r9 = m dim index ii
+
+ mov(var(m_iter), r9) // ii = m_iter;
+
+ label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
+
+
+#if 0
+ vzeroall() // zero all xmm/ymm registers.
+#else
+ // skylake can execute 3 vxorpd ipc with
+ // a latency of 1 cycle, while vzeroall
+ // has a latency of 12 cycles.
+ vxorpd(ymm4, ymm4, ymm4)
+ vxorpd(ymm5, ymm5, ymm5)
+ vxorpd(ymm6, ymm6, ymm6)
+ vxorpd(ymm7, ymm7, ymm7)
+ vxorpd(ymm8, ymm8, ymm8)
+ vxorpd(ymm9, ymm9, ymm9)
+ vxorpd(ymm10, ymm10, ymm10)
+ vxorpd(ymm11, ymm11, ymm11)
+ vxorpd(ymm12, ymm12, ymm12)
+ vxorpd(ymm13, ymm13, ymm13)
+ vxorpd(ymm14, ymm14, ymm14)
+ vxorpd(ymm15, ymm15, ymm15)
+#endif
+
+
+ lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c;
+ lea(mem(r14), rax) // rax = a + 6*ii*rs_a;
+ lea(mem(rdx), rbx) // rbx = b;
+
+
+ lea(mem(rcx, rdi, 2), r10) // r10 = c + 2*rs_c;
+ lea(mem(r10, rdi, 1), r10) // r10 = c + 3*rs_c;
+ prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c
+ prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c
+ prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c
+ prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c
+ prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
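+ // In this kernel the load pattern is transposed relative to the
+ // kernels above: ymm0/ymm1 hold four k-values from each of the two
+ // columns of b, while ymm3 is reloaded with four k-values from each
+ // of the six rows of a, yielding the 6x2 grid of accumulators
+ // enumerated at .DPOSTACCUM.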
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. 
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovupd(mem(rax, r13, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rax, r8, 4), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovupd(mem(rax, r15, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovsd(mem(rax, r8, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovsd(mem(rax, r13, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rax, r8, 4), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovsd(mem(rax, r15, 1), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
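+ // Because each output row has only two columns here, each vhaddpd/
+ // vextractf128/vaddpd sequence below reduces one accumulator pair
+ // directly to a two-element xmm, and the subsequent stores use xmm
+ // rather than ymm moves.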
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+ // ymm4 ymm5
+ // ymm6 ymm7
+ // ymm8 ymm9
+ // ymm10 ymm11
+ // ymm12 ymm13
+ // ymm14 ymm15
+
+ vhaddpd( ymm5, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm4 ) // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+ vhaddpd( ymm7, ymm6, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm6 )
+
+ vhaddpd( ymm9, ymm8, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm8 )
+
+ vhaddpd( ymm11, ymm10, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm10 )
+
+ vhaddpd( ymm13, ymm12, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm12 )
+
+ vhaddpd( ymm15, ymm14, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm14 )
+
+ // xmm4 = sum(ymm4) sum(ymm5)
+ // xmm6 = sum(ymm6) sum(ymm7)
+ // xmm8 = sum(ymm8) sum(ymm9)
+ // xmm10 = sum(ymm10) sum(ymm11)
+ // xmm12 = sum(ymm12) sum(ymm13)
+ // xmm14 = sum(ymm14) sum(ymm15)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(xmm0, xmm4, xmm4) // scale by alpha
+ vmulpd(xmm0, xmm6, xmm6)
+ vmulpd(xmm0, xmm8, xmm8)
+ vmulpd(xmm0, xmm10, xmm10)
+ vmulpd(xmm0, xmm12, xmm12)
+ vmulpd(xmm0, xmm14, xmm14)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), xmm3, xmm4)
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm6)
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm8)
+ vmovupd(xmm8, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm10)
+ vmovupd(xmm10, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm12)
+ vmovupd(xmm12, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm14)
+ vmovupd(xmm14, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm8, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm10, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm12, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm14, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+
+ lea(mem(r12, rdi, 4), r12) // r12 += 4*rs_c (lea scales max out
+ lea(mem(r12, rdi, 2), r12) // at 8); c_ii = r12 += 6*rs_c total
+
+ lea(mem(r14, r8, 4), r14) // r14 += 4*rs_a
+ lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a total
+
+ dec(r9) // ii -= 1;
+ jne(.DLOOP3X4I) // iterate again if ii != 0.
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovsd(mem(rax, r8, 2), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+ // ymm4 ymm5
+ // ymm6 ymm7
+ // ymm8 ymm9
+
+ vhaddpd( ymm5, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm4 ) // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+ vhaddpd( ymm7, ymm6, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm6 )
+
+ vhaddpd( ymm9, ymm8, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm8 )
+
+ // xmm4 = sum(ymm4) sum(ymm5)
+ // xmm6 = sum(ymm6) sum(ymm7)
+ // xmm8 = sum(ymm8) sum(ymm9)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(xmm0, xmm4, xmm4) // scale by alpha
+ vmulpd(xmm0, xmm6, xmm6)
+ vmulpd(xmm0, xmm8, xmm8)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), xmm3, xmm4)
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm6)
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm8)
+ vmovupd(xmm8, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
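The vucomisd/je pair above implements the usual sup guard: when beta is zero, the kernel branches to a store-only path and never loads C. The same logic in scalar C terms (a sketch under a hypothetical name; drow_update is not a BLIS function):

    // Skipping the C load when beta == 0 saves memory traffic and avoids
    // reading possibly uninitialized C storage.
    static void drow_update( int n, double alpha, const double* ab,
                             double beta, double* c )
    {
        if ( beta == 0.0 )
            for ( int j = 0; j < n; ++j ) c[ j ] = alpha * ab[ j ];
        else
            for ( int j = 0; j < n; ++j ) c[ j ] = alpha * ab[ j ] + beta * c[ j ];
    }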
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
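The NOTE in the scalar edge loop above is the key subtlety: vmovsd zeroes bits 64-255 of its destination register, so the subsequent full-width ymm FMA contributes 0*0 = 0 to lanes 1-3 of each accumulator, and the vector partial sums already held there survive. A small intrinsics sketch of the same trick (illustrative; fma_scalar_tail is not a BLIS symbol):

    #include <immintrin.h>

    // Accumulate one scalar product ai*bi into lane 0 of acc while
    // leaving lanes 1-3 (live vector partial sums) untouched.
    static inline __m256d fma_scalar_tail( __m256d acc, double ai, double bi )
    {
        __m256d a = _mm256_set_pd( 0.0, 0.0, 0.0, ai );  // [ ai, 0, 0, 0 ]
        __m256d b = _mm256_set_pd( 0.0, 0.0, 0.0, bi );  // [ bi, 0, 0, 0 ]
        return _mm256_fmadd_pd( a, b, acc );             // lanes 1-3: 0*0 + acc
    }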
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. 
+ //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c new file mode 100644 index 000000000..be7ad8d9c --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -0,0 +1,1901 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+ rrc:
+ -------- ------ | | | | | | | |
+ -------- ------ | | | | | | | |
+ -------- += ------ ... | | | | | | | |
+ -------- ------ | | | | | | | |
+ -------- ------ :
+ -------- ------ :
+
+ Assumptions:
+ - C is row-stored and B is column-stored;
+ - A is row-stored;
+ - n0 is at most NR; m0 may exceed MR, since this m-varying kernel
+ iterates over 3-row panels of A and C.
+ Therefore, this (r)ow-preferential kernel is well-suited for
+ a dot-product-based accumulation that performs vector loads from
+ both A and B.
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
+
+#if 0
+// Define parameters and variables for edge case kernel map.
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{ /* 8 4 2 1 */
+/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 },
+/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref_3x1 },
+/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 },
+/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8m
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ uint64_t n_left = n0 % 8;
+
+ // First check whether this is an edge case in the n dimension. If so,
+ // dispatch other 6x?m kernels, as needed.
+ if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), 
ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
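For reference, the three k loops above (unrolled-by-4 vector, single vector, scalar) together compute one 3x4 tile of the dot-product ("rd") update. Its semantics in plain C, reusing this function's own parameter names and BLIS's dim_t (a sketch of the arithmetic, not the BLIS reference kernel):

    // C(3x4) := beta*C(3x4) + alpha * A(3xk) * B(kx4), reading A by rows
    // and B by columns.
    for ( dim_t i = 0; i < 3; ++i )
        for ( dim_t j = 0; j < 4; ++j )
        {
            double ab = 0.0;
            for ( dim_t l = 0; l < k0; ++l )
                ab += a[ i*rs_a0 + l*cs_a0 ] * b[ l*rs_b0 + j*cs_b0 ];
            c[ i*rs_c0 + j*cs_c0 ] = (*alpha) * ab + (*beta) * c[ i*rs_c0 + j*cs_c0 ];
        }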
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
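The add/cmp/jle above makes the jj loop execute exactly twice, at jj = 0 and jj = 4, so two 4-column passes cover the kernel's full 8-column footprint, while the inner ii loop steps through m in 3-row panels. Schematically, in C (compute_3x4_tile is a hypothetical helper standing in for the asm body):

    for ( dim_t jj = 0; jj <= 4; jj += 4 )        // two 4-column panels of B and C
        for ( dim_t ii = 0; ii < m_iter; ++ii )   // 3-row panels of A and C
            compute_3x4_tile( a + 3*ii*rs_a,
                              b + jj*cs_b,
                              c + 3*ii*rs_c + jj*cs_c );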
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // 
iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
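One note on the #if 1 prefetch blocks in the main loop above: r10 and rdi hold 3*rs_a and 5*rs_a (the inline comments say cs_a, but both registers were computed from rs_a), so the targets appear to be the rows that the next 3-row panel of A will stream. An intrinsic rendering of one such block (a sketch; prefetch_next_a_panel is a hypothetical name, and rs_a here is an element stride on a double pointer):

    #include <stdint.h>
    #include <xmmintrin.h>

    // prefetch(0, ...) in the asm macros emits prefetcht0.
    static inline void prefetch_next_a_panel( const double* a_k, int64_t rs_a )
    {
        _mm_prefetch( (const char*)( a_k + 3*rs_a ), _MM_HINT_T0 );
        _mm_prefetch( (const char*)( a_k + 4*rs_a ), _MM_HINT_T0 );
        _mm_prefetch( (const char*)( a_k + 5*rs_a ), _MM_HINT_T0 );
    }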
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; + lea(mem(r14), rax) // rax = a + 6*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // 
---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
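+	// Each accumulator ymm4..ymm15 now holds four partial sums of a single
+	// dot product, so a horizontal reduction follows. vhaddpd pairs two
+	// accumulators, vextractf128 pulls down the upper 128-bit lane, and
+	// vaddpd folds it into the lower lane. In scalar terms (an illustrative
+	// sketch, with acc4[]/acc5[] standing for the contents of ymm4/ymm5):
+	//
+	//   c00 = acc4[0] + acc4[1] + acc4[2] + acc4[3];   // sum(ymm4)
+	//   c01 = acc5[0] + acc5[1] + acc5[2] + acc5[3];   // sum(ymm5)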
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. 
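+	// Note that the two lea pairs above advance c and a by six rows each.
+	// lea can scale an index register only by 1, 2, 4, or 8, so the
+	// +6*rs_c (and +6*rs_a) update is composed as +4*rs_c followed by
+	// +2*rs_c, avoiding a separate multiply inside the m loop.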
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c new file mode 100644 index 000000000..9012e921b --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -0,0 +1,2402 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+   rrc:
+	 --------        ------        | | | | | | | |
+	 --------        ------        | | | | | | | |
+	 --------   +=   ------ ...    | | | | | | | |
+	 --------        ------        | | | | | | | |
+	 --------        ------            :
+	 --------        ------            :
+
+   Assumptions:
+   - C is row-stored and B is column-stored;
+   - A is row-stored;
+   - m0 and n0 are at most MR and NR, respectively.
+   Therefore, this (r)ow-preferential microkernel is well-suited for
+   a dot-product-based accumulation that performs vector loads from
+   both A and B.
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( double,   d, gemmsup_r_haswell_ref )
+
+#if 0
+// Define parameters and variables for edge case kernel map.
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{         /*                8                                 4                                 2                                 1  */
+/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8n, bli_dgemmsup_rd_haswell_asm_6x4n, bli_dgemmsup_rd_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref_6x1 },
+/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8n, bli_dgemmsup_rd_haswell_asm_3x4n, bli_dgemmsup_rd_haswell_asm_3x2n, bli_dgemmsup_r_haswell_ref_3x1 },
+/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8n, bli_dgemmsup_rd_haswell_asm_2x4n, bli_dgemmsup_rd_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref_2x1 },
+/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8n, bli_dgemmsup_rd_haswell_asm_1x4n, bli_dgemmsup_rd_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref_1x1 }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8n
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	uint64_t m_left = m0 % 6;
+
+	// First check whether this is an edge case in the m dimension. If so,
+	// dispatch other ?x8n kernels, as needed.
+	if ( m_left )
+	{
+		double* restrict cij = c;
+		double* restrict bj  = b;
+		double* restrict ai  = a;
+
+#if 1
+		// We add special handling for slightly inflated MR blocksizes
+		// at edge cases, up to a maximum of 9.
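+		// That is, an m dimension of 7, 8, or 9 is split into one 6-row
+		// panel plus a 1-, 2-, or 3-row remainder, each dispatched to the
+		// kernel matching its row count. For example, m0 = 8 resolves to
+		//
+		//   mr1 = 6;  ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n;  // rows 0..5
+		//   mr2 = 2;  ker_fp2 = bli_dgemmsup_rd_haswell_asm_2x8n;  // rows 6..7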
+ if ( 6 < m0 ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m0 == 7 ) + { + mr1 = 6; mr2 = 1; + ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n; + ker_fp2 = bli_dgemmsup_rd_haswell_asm_1x8n; + } + else if ( m0 == 8 ) + { + mr1 = 6; mr2 = 2; + ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n; + ker_fp2 = bli_dgemmsup_rd_haswell_asm_2x8n; + } + else // if ( m0 == 9 ) + { + mr1 = 6; mr2 = 3; + ker_fp1 = bli_dgemmsup_rd_haswell_asm_6x8n; + ker_fp2 = bli_dgemmsup_rd_haswell_asm_3x8n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { +#if 0 + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_TRANSPOSE, conja, k0, n0, + alpha, bj, rs_b0, cs_b0, ai, cs_a0, + beta, cij, cs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r9) // ii = 0; + + label(.DLOOP3X4I) // LOOP OVER ii = [ 0 1 ... 
] + + + + mov(var(b), r14) // load address of b + mov(var(c), r12) // load address of c + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c + lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(rdx, rsi, 1), rdx) // rax = a + 3*ii*rs_a; + + + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + 
vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
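+	// The .DBETAZERO path below stores alpha*A*B without ever loading C.
+	// Besides saving the loads, this matters for correctness under BLAS
+	// conventions: when beta == 0, C need not contain valid numbers on
+	// input, and computing beta*C anyway could turn uninitialized NaN/Inf
+	// values into NaN output (0.0 * NaN == NaN, not 0.0).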
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning + // of ii loop; otherwise, loop ends. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_6x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
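+	// The k dimension is consumed in three tiers: the main loop below
+	// retires 16 values of k per pass (4-wide vector loads, unrolled 4x),
+	// the k_iter4 loop retires 4 at a time, and a scalar loop handles the
+	// final k % 4 values. For example, k0 = 23 decomposes as
+	//
+	//   k_iter16 = 23 / 16       = 1;   // 16 values
+	//   k_iter4  = (23 % 16) / 4 = 1;   //  4 values
+	//   k_left1  = (23 % 16) % 4 = 3;   //  3 values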
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + 
vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
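+	// In the reduction below, each vhaddpd/vextractf128/vaddpd triple
+	// leaves two column sums in one xmm; vperm2f128 with immediate 0x20
+	// then concatenates the low 128-bit lanes of its two sources, so that
+	//
+	//   ymm4 = { sum(ymm4), sum(ymm7), sum(ymm10), sum(ymm13) }
+	//
+	// forms a full 4-wide row of C ready for alpha/beta scaling.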
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
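+	// (n_left = n0 % 4 gives the number of leftover columns.) For a single
+	// leftover column the code below defers to dgemv rather than a 3x1
+	// microkernel: the remaining update is
+	//
+	//   c(0:2,j) := beta * c(0:2,j) + alpha * A(0:2,0:k0-1) * b(:,j)
+	//
+	// which is exactly a non-transposed matrix-vector product.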
+ if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_3x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, 
mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
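+	// For this 2x8n kernel only two rows of accumulators are live:
+	//
+	//   ymm4  ymm7  ymm10  ymm13    <- row 0 of the 2x4 block
+	//   ymm5  ymm8  ymm11  ymm14    <- row 1 of the 2x4 block
+	//
+	// (ymm6/9/12/15 go unused, which is why only eight vxorpd clears
+	// appear at the top of the jj loop.)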
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_2x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. 
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // (only this single row of accumulators is used by this 1xN kernel)
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+
+
+
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+
+ add(imm(4*8), r12) // c_jj = r12 += 4*cs_c
+
+ lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b
+
+ dec(r15) // jj -= 1;
+ jne(.DLOOP3X4J) // iterate again if jj != 0.
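+ // NOTE: The .DPOSTACCUM sequence above reduces the four lane-wise
+ // accumulators to the four scalars of one 1x4 strip of C. A rough
+ // intrinsics rendering of that reduction (illustrative only; the
+ // kernel itself remains in asm, and y4/y7/y10/y13 name the ymm
+ // accumulators):
+ //
+ //   __m256d t = _mm256_hadd_pd( y4, y7 ); // pairwise sums of y4, y7
+ //   __m128d s = _mm_add_pd( _mm256_castpd256_pd128( t ),
+ //                           _mm256_extractf128_pd( t, 1 ) );
+ //   // s = { sum(y4), sum(y7) }; the same steps reduce y10/y13, and
+ //   // vperm2f128 then packs both xmm halves into a single ymm.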
+
+
+
+
+ label(.DRETURN)
+
+
+
+ end_asm(
+ : // output operands (none)
+ : // input operands
+   [n_iter] "m" (n_iter),
+   [k_iter16] "m" (k_iter16),
+   [k_iter4] "m" (k_iter4),
+   [k_left1] "m" (k_left1),
+   [a] "m" (a),
+   [rs_a] "m" (rs_a),
+   [cs_a] "m" (cs_a),
+   [b] "m" (b),
+   [rs_b] "m" (rs_b),
+   [cs_b] "m" (cs_b),
+   [alpha] "m" (alpha),
+   [beta] "m" (beta),
+   [c] "m" (c),
+   [rs_c] "m" (rs_c),
+   [cs_c] "m" (cs_c)/*,
+   [a_next] "m" (a_next),
+   [b_next] "m" (b_next)*/
+ : // register clobber list
+   "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+   "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+   "xmm0", "xmm1", "xmm2", "xmm3",
+   "xmm4", "xmm5", "xmm6", "xmm7",
+   "xmm8", "xmm9", "xmm10", "xmm11",
+   "xmm12", "xmm13", "xmm14", "xmm15",
+   "memory"
+ )
+
+ consider_edge_cases:
+
+ // Handle edge cases in the n dimension, if they exist.
+ if ( n_left )
+ {
+ const dim_t mr_cur = 1;
+ const dim_t j_edge = n0 - ( dim_t )n_left;
+
+ double* restrict cij = c + j_edge*cs_c;
+ double* restrict ai = a;
+ double* restrict bj = b + j_edge*cs_b;
+
+ if ( 2 <= n_left )
+ {
+ const dim_t nr_cur = 2;
+
+ bli_dgemmsup_rd_haswell_asm_1x2
+ (
+   conja, conjb, mr_cur, nr_cur, k0,
+   alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+   beta, cij, rs_c0, cs_c0, data, cntx
+ );
+ cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur;
+ }
+ if ( 1 == n_left )
+ {
+#if 0
+ const dim_t nr_cur = 1;
+
+ bli_dgemmsup_r_haswell_ref_1x1
+ (
+   conja, conjb, mr_cur, nr_cur, k0,
+   alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+   beta, cij, rs_c0, cs_c0, data, cntx
+ );
+#else
+ bli_ddotxv_ex
+ (
+   conja, conjb, k0,
+   alpha, ai, cs_a0, bj, rs_b0,
+   beta, cij, cntx, NULL
+ );
+#endif
+ }
+ }
+}
+
diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8.c
new file mode 100644
index 000000000..ebe396317
--- /dev/null
+++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8.c
@@ -0,0 +1,11047 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2014, The University of Texas at Austin
+   Copyright (C) 2019, Advanced Micro Devices, Inc.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8, bli_dgemmsup_rv_haswell_asm_6x4, bli_dgemmsup_rv_haswell_asm_6x2, bli_dgemmsup_r_haswell_ref_6x1 }, +/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8, bli_dgemmsup_rv_haswell_asm_4x4, bli_dgemmsup_rv_haswell_asm_4x2, bli_dgemmsup_r_haswell_ref_4x1 }, +/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8, bli_dgemmsup_rv_haswell_asm_2x4, bli_dgemmsup_rv_haswell_asm_2x2, bli_dgemmsup_r_haswell_ref_2x1 }, +/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8, bli_dgemmsup_rv_haswell_asm_1x4, bli_dgemmsup_rv_haswell_asm_1x2, bli_dgemmsup_r_haswell_ref_1x1 }, +}; + + +void bli_dgemmsup_rv_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + // Use a reference kernel if this is an edge case in the m or n + // dimensions. 
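+ // Rather than handing the whole edge case to the reference kernel,
+ // the branch below walks the kernel map (kmap) defined above: for
+ // each entry of nrs that fits in the remaining n, it sweeps down mrs
+ // and dispatches the matching kernel. For example, an m0 = 5, n0 = 6
+ // edge case decomposes into 4x4, 1x4, 4x2, and 1x2 kernel calls.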
+ if ( m0 < 6 || n0 < 8 ) + { +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. + if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
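+ // NOTE: Each iteration above is one rank-1 update of the 6x8
+ // microtile: two ymm loads fetch a full 8-element row of B, and six
+ // broadcasts supply one column of A. In (not compiled) scalar form,
+ // with p fixed at the current k index:
+ //
+ //   for ( i = 0; i < 6; ++i )
+ //       for ( j = 0; j < 8; ++j )
+ //           ab[i][j] += a[ i*rs_a + p*cs_a ] * b[ p*rs_b + j ];
+ //
+ // (cs_b == 1 is assumed by this rv kernel, hence the bare j.)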
+ + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
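+ // The column-stored path below cannot write whole rows of the
+ // microtile with single stores, so it first transposes the
+ // accumulators in registers: vunpcklpd/vunpckhpd interleave pairs of
+ // rows, then vinsertf128/vperm2f128 assemble 4x4 blocks of columns,
+ // after which each column of C is updated with one ymm load/store.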
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
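+ // NOTE: In the beta == 0 cases (.DROWSTORBZ above, .DCOLSTORBZ
+ // below), C is overwritten without ever being read. This is required
+ // for correctness, not just speed: with beta == 0, any NaN or Inf
+ // values already present in C must not propagate into the result.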
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
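+ // For the column-stored case below, rows 0-3 are transposed in
+ // 4x4 register blocks as in the 6x8 kernel, but the lone fifth row
+ // (ymm12/ymm13) cannot fill a 4x4 block. Its elements are instead
+ // gathered from C with vmovlpd/vmovhpd, updated with a single fused
+ // vfmadd213pd, and scattered back one element at a time.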
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm13, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(ymm4, mem(rcx))
+ vmovupd(ymm5, mem(rcx, rsi, 4))
+ add(rdi, rcx)
+
+
+ vmovupd(ymm6, mem(rcx))
+ vmovupd(ymm7, mem(rcx, rsi, 4))
+ add(rdi, rcx)
+
+
+ vmovupd(ymm8, mem(rcx))
+ vmovupd(ymm9, mem(rcx, rsi, 4))
+ add(rdi, rcx)
+
+
+ vmovupd(ymm10, mem(rcx))
+ vmovupd(ymm11, mem(rcx, rsi, 4))
+ add(rdi, rcx)
+
+
+ vmovupd(ymm12, mem(rcx))
+ vmovupd(ymm13, mem(rcx, rsi, 4))
+ //add(rdi, rcx)
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+ label(.DCOLSTORBZ)
+
+
+ vunpcklpd(ymm6, ymm4, ymm0)
+ vunpckhpd(ymm6, ymm4, ymm1)
+ vunpcklpd(ymm10, ymm8, ymm2)
+ vunpckhpd(ymm10, ymm8, ymm3)
+ vinsertf128(imm(0x1), xmm2, ymm0, ymm4)
+ vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
+ vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
+ vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
+
+ vmovupd(ymm4, mem(rcx))
+ vmovupd(ymm6, mem(rcx, rsi, 1))
+ vmovupd(ymm8, mem(rcx, rsi, 2))
+ vmovupd(ymm10, mem(rcx, rax, 1))
+
+ lea(mem(rcx, rsi, 4), rcx)
+
+#if 0
+ vunpcklpd(ymm14, ymm12, ymm0)
+ vunpckhpd(ymm14, ymm12, ymm1)
+ vextractf128(imm(0x1), ymm0, xmm2)
+ vextractf128(imm(0x1), ymm1, xmm4)
+
+ vmovupd(xmm0, mem(rdx))
+ vmovupd(xmm1, mem(rdx, rsi, 1))
+ vmovupd(xmm2, mem(rdx, rsi, 2))
+ vmovupd(xmm4, mem(rdx, rax, 1))
+#else
+ vmovupd(ymm12, ymm0)
+
+ // (row 4 of C lives at rdx = c + 4*rs_c)
+ vextractf128(imm(1), ymm0, xmm1)
+ vmovlpd(xmm0, mem(rdx))
+ vmovhpd(xmm0, mem(rdx, rsi, 1))
+ vmovlpd(xmm1, mem(rdx, rsi, 2))
+ vmovhpd(xmm1, mem(rdx, rax, 1))
+#endif
+
+ lea(mem(rdx, rsi, 4), rdx)
+
+
+ vunpcklpd(ymm7, ymm5, ymm0)
+ vunpckhpd(ymm7, ymm5, ymm1)
+ vunpcklpd(ymm11, ymm9, ymm2)
+ vunpckhpd(ymm11, ymm9, ymm3)
+ vinsertf128(imm(0x1), xmm2, ymm0, ymm5)
+ vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
+ vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
+ vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
+
+ vmovupd(ymm5, mem(rcx))
+ vmovupd(ymm7, mem(rcx, rsi, 1))
+ vmovupd(ymm9, mem(rcx, rsi, 2))
+ vmovupd(ymm11, mem(rcx, rax, 1))
+
+ //lea(mem(rcx, rsi, 4), rcx)
+
+#if 0
+ vunpcklpd(ymm15, ymm13, ymm0)
+ vunpckhpd(ymm15, ymm13, ymm1)
+ vextractf128(imm(0x1), ymm0, xmm2)
+ vextractf128(imm(0x1), ymm1, xmm4)
+
+ vmovupd(xmm0, mem(rdx))
+ vmovupd(xmm1, mem(rdx, rsi, 1))
+ vmovupd(xmm2, mem(rdx, rsi, 2))
+ vmovupd(xmm4, mem(rdx, rax, 1))
+#else
+ vmovupd(ymm13, ymm0)
+
+ vextractf128(imm(1), ymm0, xmm1)
+ vmovlpd(xmm0, mem(rdx))
+ vmovhpd(xmm0, mem(rdx, rsi, 1))
+ vmovlpd(xmm1, mem(rdx, rsi, 2))
+ vmovhpd(xmm1, mem(rdx, rax, 1))
+#endif
+
+ //lea(mem(rdx, rsi, 4), rdx)
+
+
+
+
+ label(.DDONE)
+
+
+
+ end_asm(
+ : // output operands (none)
+ : // input operands
+   [k_iter] "m" (k_iter),
+   [k_left] "m" (k_left),
+   [a] "m" (a),
+   [rs_a] "m" (rs_a),
+   [cs_a] "m" (cs_a),
+   [b] "m" (b),
+   [rs_b] "m" (rs_b),
+   [cs_b] "m" (cs_b),
+   [alpha] "m" (alpha),
+   [beta] "m" (beta),
+   [c] "m" (c),
+   [rs_c] "m" (rs_c),
+   [cs_c] "m" (cs_c)/*,
+   [a_next] "m" (a_next),
+   [b_next] "m" (b_next)*/
+ : // register clobber list
+   "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+   "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+   "xmm0", "xmm1", "xmm2", "xmm3",
+   "xmm4", "xmm5", "xmm6", "xmm7",
+   "xmm8", "xmm9", "xmm10", "xmm11",
+   "xmm12", "xmm13", "xmm14", "xmm15",
+   "memory"
+ )
+}
+
+void bli_dgemmsup_rv_haswell_asm_4x8
+ (
+   conj_t conja,
+   conj_t conjb,
+   dim_t m0,
+   dim_t n0,
+   dim_t k0,
+   double* restrict alpha,
+   double* restrict a, inc_t rs_a0, inc_t cs_a0,
+   double* restrict b, inc_t rs_b0, inc_t cs_b0,
+   double* restrict beta,
+   double* restrict c, inc_t rs_c0, inc_t cs_c0,
+   auxinfo_t* restrict data,
+   cntx_t* restrict cntx
+ )
+{
+
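+ // This kernel computes a 4x8 microtile of C as C := beta*C +
+ // alpha*A*B, accumulating into eight ymm registers (ymm4-ymm11),
+ // i.e. one ymm pair per row. It mirrors the structure of the 6x8
+ // kernel above, minus the fifth and sixth rows of A.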
//void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
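+ // NOTE: The edge loop above retires one k iteration at a time using
+ // the same broadcast/FMA pattern as the unrolled loop; since
+ // k_left = k0 % 4, at most three trips are needed to finish the
+ // remaining rank-1 updates.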
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
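+	// NOTE: On exit from the k loops, the register micro-tile holds
+	// ab = A*B for the current panel. The remainder of the kernel
+	// computes c := beta*c + alpha*ab, performing the alpha scaling
+	// in registers and skipping all loads of c whenever beta == 0.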
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + 
vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local 
copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
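+	// NOTE: This micro-kernel follows the same structure as its larger
+	// siblings: a 4x-unrolled main loop over k followed by an edge loop
+	// that consumes the k0 % 4 remainder one step at a time. Only two
+	// rows are computed here, so each step issues two broadcasts from a
+	// and accumulates into ymm4-ymm7.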
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
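+	// NOTE: rdi holds rs_c scaled by sizeof(double), so the comparison
+	// above sets ZF exactly when rs_c == 1, i.e. when c is column-stored;
+	// otherwise execution falls through to the row-stored path.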
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
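+	// NOTE: The beta == 0 column-stored case below performs the same
+	// in-register transpose as .DCOLSTORED above, but writes c without
+	// ever loading it, so that NaNs and Infs residing in an
+	// uninitialized c are not propagated into the result.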
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
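+	// NOTE: With a single row there is nothing to transpose; the
+	// column-stored case below instead gathers four strided elements of
+	// c into one ymm register (vmovlpd/vmovhpd plus vperm2f128), folds
+	// in beta and alpha*ab via fma, and scatters the result back the
+	// same way.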
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm5, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
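+	// NOTE: When beta == 0, the store path below simply scatters the
+	// alpha-scaled accumulators to c column by column; no gather or fma
+	// is needed.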
+ + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovupd(ymm5, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
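+	// NOTE: c is prefetched along whichever dimension is stored
+	// contiguously: each of the six rows of the micro-tile when c is
+	// row-stored, each of the six columns when it is column-stored.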
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, 
ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
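+	// NOTE: In the 6-column kernels the left four columns of each row
+	// live in a ymm register and the right two columns in an xmm
+	// register, which is why the alpha scaling above alternates between
+	// full-width and half-width vmulpd.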
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + //vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + //vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(xmm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(xmm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
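+	// NOTE: Five rows of six columns are accumulated here, each row as a
+	// (ymm, xmm) pair spanning ymm4/ymm5 through ymm12/ymm13. Each k
+	// step loads the six-element row of b as ymm0 plus xmm1 and issues
+	// five broadcasts from a.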
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
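+	// NOTE: In the column-stored case below, rows 0 through 3 are
+	// transposed as a 4x4 block in ymm registers, while the fifth row
+	// (ymm12/ymm13) is handled by the gather/fma/scatter sequences
+	// selected under the #else branches.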
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + //vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + //vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + //vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + //vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + //vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(xmm13, xmm3, xmm0) + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + //vmovlpd(xmm1, mem(rdx, rsi, 2)) + //vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(xmm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm13, ymm0) + + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + //vmovlpd(xmm1, mem(rdx, rsi, 2)) + //vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* 
restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
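+
+	// NOTE (illustrative): the main loop below is unrolled by a factor
+	// of 4 -- k_iter = k0/4 full passes, with the remaining k0 % 4
+	// updates handled one at a time by the edge loop at .DLOOPKLEFT.
+	// Each step performs one rank-1 update of the 4x6 microtile; a
+	// scalar sketch of the computation (i, j, l are hypothetical
+	// indices, not variables in this code):
+	//
+	//   for ( l = 0; l < k0; ++l )
+	//   for ( i = 0; i < 4;  ++i )
+	//   for ( j = 0; j < 6;  ++j )
+	//     ab( i, j ) += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j*cs_b ];
+	//
+	// Row i of ab lives in ymm(4+2i) (columns 0-3) and xmm(5+2i)
+	// (columns 4-5).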
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
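+
+	// The post-accumulation code below applies alpha and beta. vxorpd
+	// zeroes ymm0 and vucomisd compares beta against it, so the je
+	// takes the .DBETAZERO branch when beta == 0 and C is never read.
+	// In scalar terms (a sketch, not the code itself):
+	//
+	//   if ( beta == 0.0 ) c( i, j ) =                   alpha * ab( i, j );
+	//   else               c( i, j ) = beta * c( i, j ) + alpha * ab( i, j );
+	//
+	// Skipping the load when beta == 0 follows BLAS convention: C is
+	// write-only in that case, so NaNs or Infs already present in C
+	// must not propagate into the result.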
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
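+
+	// The loop below follows the same pattern as the 4x6 case, minus
+	// the fourth row: each k step broadcasts a(0,l), a(1,l), a(2,l)
+	// and accumulates rows 0-2 into the ymm4/xmm5, ymm6/xmm7 and
+	// ymm8/xmm9 pairs.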
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
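+
+	// In the column-stored case below (.DCOLSTORED), the row-oriented
+	// accumulators are first transposed in registers. The idiom for a
+	// 4x4 block of doubles held as rows r0-r3 is, step by step (r*, t*
+	// and col* are illustrative names, in the operand order used here):
+	//
+	//   vunpcklpd(r1, r0, t0)           // t0   = r0[0] r1[0] r0[2] r1[2]
+	//   vunpckhpd(r1, r0, t1)           // t1   = r0[1] r1[1] r0[3] r1[3]
+	//   vunpcklpd(r3, r2, t2)           // t2   = r2[0] r3[0] r2[2] r3[2]
+	//   vunpckhpd(r3, r2, t3)           // t3   = r2[1] r3[1] r2[3] r3[3]
+	//   vinsertf128(1, t2, t0, col0)    // col0 = t0.lo | t2.lo
+	//   vinsertf128(1, t3, t1, col1)    // col1 = t1.lo | t3.lo
+	//   vperm2f128(0x31, t2, t0, col2)  // col2 = t0.hi | t2.hi
+	//   vperm2f128(0x31, t3, t1, col3)  // col3 = t1.hi | t3.hi
+	//
+	// Since m = 3 here, only xmm-width column fragments are stored
+	// through rcx; the third element of each column is peeled off with
+	// vextractf128 and written with vmovsd through rdx = c + 2*rs_c.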
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + //vextractf128(imm(0x1), ymm9, xmm14) + //vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(xmm9, mem(rcx, rsi, 2)) + //vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + 
vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + //vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + //vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + //vextractf128(imm(0x1), ymm9, xmm14) + //vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(xmm9, mem(rcx, rsi, 2)) + //vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + 
// Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
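+
+	// Same accumulation pattern once more, now with only two rows live:
+	// each k step broadcasts a(0,l) and a(1,l) and updates the
+	// ymm4/xmm5 and ymm6/xmm7 pairs; ymm8 and above remain zero from
+	// the vzeroall above.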
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
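+
+	// The .DCOLSTORBZ case below repeats the transpose from .DCOLSTORED
+	// but with the C loads and vfmadd blends elided, since beta == 0
+	// leaves nothing to accumulate into.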
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + +#if 1 + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
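+
+	// With m = 1 there is nothing to transpose. In the column-stored
+	// case below, four strided elements of the single row of C are
+	// gathered with vmovlpd/vmovhpd into one ymm, blended with a single
+	// vfmadd213pd (beta times the gathered c plus the alpha-scaled ab),
+	// and scattered back the same way.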
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + //vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + //vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + //vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(xmm5, xmm3, xmm0) + + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + //vmovlpd(xmm1, mem(rcx, rsi, 2)) + //vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovupd(xmm5, xmm0) + + //vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + //vmovlpd(xmm1, mem(rcx, rsi, 2)) + //vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + 
vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
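+
+	// For this 6x4 shape each row of ab fits a single ymm register, so
+	// the row-stored path above is a plain load/fma/store per row. The
+	// column-stored path below transposes rows 0-3 with the usual
+	// unpack/insert/permute idiom and stores rows 4-5 as xmm-width
+	// column fragments through rdx = c + 4*rs_c.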
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( 
data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + +#if 0 + lea(mem(rax, r9, 8), rdx) // use rdx for prefetching b. + lea(mem(rdx, r9, 8), rdx) // rdx = b + 16*rs_b; +#else + #if 1 + mov(r9, rsi) // rsi = rs_b; + sal(imm(5), rsi) // rsi = 16*rs_b; + lea(mem(rax, rsi, 1), rdx) // rdx = b + 16*rs_b; + #endif +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
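+
+	// Five rows, four columns: each k step loads one ymm of b and
+	// broadcasts five elements of a (at offsets 0, rs_a, 2*rs_a,
+	// 3*rs_a via r13, and 4*rs_a), accumulating full rows into ymm4,
+	// ymm6, ymm8, ymm10 and ymm12.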
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
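+	// For reference, the kernel as a whole computes the equivalent of the
+	// following scalar sketch (except that the beta == 0 branch below
+	// stores without ever reading c):
+	//
+	//   for ( dim_t j = 0; j < 4; ++j )
+	//   for ( dim_t i = 0; i < 5; ++i )
+	//   {
+	//       double ab = 0.0;
+	//       for ( dim_t p = 0; p < k0; ++p )
+	//           ab += a[ i*rs_a + p*cs_a ] * b[ p*rs_b + j*cs_b ];
+	//       c[ i*rs_c + j*cs_c ] = (*alpha)*ab + (*beta)*c[ i*rs_c + j*cs_c ];
+	//   }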
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
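+	// The column-stored beta == 0 path below transposes the 4x4 block in
+	// ymm4/ymm6/ymm8/ymm10 entirely in registers: vunpcklpd/vunpckhpd
+	// interleave pairs of rows within each 128-bit lane, then
+	// vinsertf128/vperm2f128 exchange the lanes, leaving one column of
+	// the tile per register. The fifth row (ymm12) does not fit the 4x4
+	// transpose, so its elements are scattered with vmovlpd/vmovhpd
+	// stores through rdx = c + 4*rs_c.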
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
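+	// Load the address of c and its row stride (scaled to bytes), then
+	// prefetch c according to how it is stored: if 8*rs_c == 8, i.e.
+	// rs_c == 1, c is column-stored and each of its four columns is
+	// prefetched; otherwise the start of each of its four rows is
+	// prefetched.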
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
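+	// The main loop above retires four iterations of the k dimension per
+	// pass; any remainder (k_left = k0 % 4) is handled one iteration at
+	// a time by the edge loop that follows.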
+ + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
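+	// With m = 4 the tile fits the 4x4 in-register transpose exactly, so
+	// unlike the 5x4 and 6x4 cases there are no leftover rows to write
+	// through rdx.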
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
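+	// Note: the column-stored paths below feed ymm10 into the 4x4
+	// transpose as a nonexistent fourth row. vzeroall() left ymm10
+	// zeroed, and the lanes it contributes are never stored, so this is
+	// harmless.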
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
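+	// In the column-stored beta == 0 path below, each column of the
+	// three-row tile is written as one 16-byte vmovupd (rows 0 and 1)
+	// plus one 8-byte vmovsd (row 2) through rdx = c + 2*rs_c.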
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
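+	// With only two rows, each column of the tile is a single xmm pair,
+	// so the column-stored paths below need just vunpcklpd/vunpckhpd and
+	// vextractf128; no 256-bit lane permutes are required.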
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
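+	// Single-row beta == 0 case: the four doubles of ymm4 are scattered
+	// to column-stored c with vmovlpd/vmovhpd element stores, one per
+	// column, taken from the two 128-bit halves of the register.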
+ + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) 
// iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
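+	// In these n = 2 kernels each k iteration loads a two-element row of
+	// b into xmm0; vbroadcastsd still fills whole ymm registers, but the
+	// 128-bit vfmadd231pd instructions consume only their low halves.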
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, 
xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // r13 = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm12, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) +#else + vmovupd(xmm12, xmm0) + + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
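+
+	                                   // Note: the vbroadcastsd instructions in
+	                                   // the loops below name full ymm registers
+	                                   // as destinations, but only the low
+	                                   // 128-bit (xmm) lanes participate in the
+	                                   // two-column FMAs and stores.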
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. 
+ // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
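+
+	                                   // In the column-stored case below, the
+	                                   // four 1x2 rows of C held in xmm4/6/8/10
+	                                   // are transposed in registers. A sketch
+	                                   // of the unpack semantics (low lane
+	                                   // first):
+	                                   //   vunpcklpd(xmm6, xmm4, xmm0)
+	                                   //     -> xmm0 = { c00, c10 }
+	                                   //   vunpckhpd(xmm6, xmm4, xmm1)
+	                                   //     -> xmm1 = { c01, c11 }
+	                                   // after which vinsertf128 glues on the
+	                                   // row 2/3 halves so that each ymm
+	                                   // register holds one full column of C.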
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
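+
+	                                   // The main loop above is unrolled by a
+	                                   // factor of 4 and runs k_iter = k0/4
+	                                   // times; the edge loop just completed
+	                                   // retires the remaining k_left = k0 % 4
+	                                   // rank-1 updates one at a time (e.g.
+	                                   // k0 = 10 -> 2 unrolled passes + 2 edge
+	                                   // iterations).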
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + //vextractf128(imm(0x1), ymm8, xmm14) + //vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + //vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + //vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + //vmovupd(xmm8, mem(rcx, rsi, 2)) + //vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + //vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + //vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
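+
+	                                   // As in the column-stored case above,
+	                                   // each column of C holds three doubles
+	                                   // for this m = 3 kernel: the first two
+	                                   // move with a 128-bit vmovupd at rcx,
+	                                   // and the third (extracted into
+	                                   // xmm12/xmm13) moves with a scalar
+	                                   // vmovsd at rdx = c + 2*rs_c.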
+ + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + //vextractf128(imm(0x1), ymm8, xmm14) + //vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + //vmovupd(xmm8, mem(rcx, rsi, 2)) + //vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + //vmovsd(xmm14, mem(rdx, rsi, 2)) + //vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. 
+ + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rcx, rsi, 2), rdx) // + //lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
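+
+	                                   // Below, the column-stored IO case
+	                                   // gathers the 1x2 row of C one element
+	                                   // per column (vmovlpd/vmovhpd), forms
+	                                   // beta*c + alpha*ab with a single
+	                                   // vfmadd213pd, and scatters the result
+	                                   // back the same way.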
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm4, xmm3, xmm0) + + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vmovlpd(xmm4, mem(rcx)) + vmovhpd(xmm4, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +// ----------------------------------------------------------------------------- + +// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel. +// However, at least one other subconfiguration (zen) uses this kernel set, so +// we need to be able to call a set of "?x1" kernels that we know will actually +// exist regardless of which subconfiguration these kernels were used by. Thus, +// the compromise employed here is to inline the reference kernel so it gets +// compiled as part of the haswell kernel set, and hence can unconditionally be +// called by other kernels within that kernel set. 
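+
+// As a reading aid for the macro template below: each GENTFUNC invocation
+// that follows it emits one plain-C kernel. Assuming the usual PASTEMAC
+// name-pasting (ch = d yields bli_d* names), the mdim = 1 instantiation
+// behaves roughly like:
+//
+//   void bli_dgemmsup_r_haswell_ref_1x1( ... )
+//   {
+//       double ab = 0.0;
+//       for ( dim_t l = 0; l < k; ++l )
+//           ab += a[ l*cs_a ] * b[ l*rs_b ];
+//
+//       if      ( *beta == 1.0 ) *c += *alpha * ab;
+//       else if ( *beta == 0.0 ) *c  = *alpha * ab;
+//       else                     *c  = *alpha * ab + *beta * (*c);
+//   }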
+
+#undef  GENTFUNC
+#define GENTFUNC( ctype, ch, opname, mdim ) \
+\
+void PASTEMAC(ch,opname) \
+     ( \
+       conj_t              conja, \
+       conj_t              conjb, \
+       dim_t               m, \
+       dim_t               n, \
+       dim_t               k, \
+       ctype*     restrict alpha, \
+       ctype*     restrict a, inc_t rs_a, inc_t cs_a, \
+       ctype*     restrict b, inc_t rs_b, inc_t cs_b, \
+       ctype*     restrict beta, \
+       ctype*     restrict c, inc_t rs_c, inc_t cs_c, \
+       auxinfo_t* restrict data, \
+       cntx_t*    restrict cntx \
+     ) \
+{ \
+	for ( dim_t i = 0; i < mdim; ++i ) \
+	{ \
+		ctype* restrict ci = &c[ i*rs_c ]; \
+		ctype* restrict ai = &a[ i*rs_a ]; \
+\
+		/* for ( dim_t j = 0; j < 1; ++j ) */ \
+		{ \
+			ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \
+			ctype* restrict bj  = b  /*[ j*cs_b ]*/ ; \
+			ctype           ab; \
+\
+			PASTEMAC(ch,set0s)( ab ); \
+\
+			/* Perform a dot product to update the (i,j) element of c. */ \
+			for ( dim_t l = 0; l < k; ++l ) \
+			{ \
+				ctype* restrict aij = &ai[ l*cs_a ]; \
+				ctype* restrict bij = &bj[ l*rs_b ]; \
+\
+				PASTEMAC(ch,dots)( *aij, *bij, ab ); \
+			} \
+\
+			/* If beta is one, add ab into c. If beta is zero, overwrite c
+			   with the result in ab. Otherwise, scale c by beta and then
+			   accumulate ab into it. */ \
+			if ( PASTEMAC(ch,eq1)( *beta ) ) \
+			{ \
+				PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \
+			} \
+			else if ( PASTEMAC(ch,eq0)( *beta ) ) \
+			{ \
+				PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \
+			} \
+			else \
+			{ \
+				PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \
+			} \
+		} \
+	} \
+}
+
+GENTFUNC( double, d, gemmsup_r_haswell_ref_6x1, 6 )
+GENTFUNC( double, d, gemmsup_r_haswell_ref_5x1, 5 )
+GENTFUNC( double, d, gemmsup_r_haswell_ref_4x1, 4 )
+GENTFUNC( double, d, gemmsup_r_haswell_ref_3x1, 3 )
+GENTFUNC( double, d, gemmsup_r_haswell_ref_2x1, 2 )
+GENTFUNC( double, d, gemmsup_r_haswell_ref_1x1, 1 )
+
diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c
new file mode 100644
index 000000000..24df267f0
--- /dev/null
+++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c
@@ -0,0 +1,3265 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2014, The University of Texas at Austin
+   Copyright (C) 2019, Advanced Micro Devices, Inc.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+   rrr:
+     --------        ------        --------
+     --------        ------        --------
+     --------   +=   ------ ...    --------
+     --------        ------        --------
+     --------        ------            :
+     --------        ------            :
+
+   rcr:
+     --------        | | | |       --------
+     --------        | | | |       --------
+     --------   +=   | | | | ...   --------
+     --------        | | | |       --------
+     --------        | | | |           :
+     --------        | | | |           :
+
+   Assumptions:
+   - B is row-stored;
+   - A is row- or column-stored;
+   - m0 and n0 are at most MR and NR, respectively.
+   Therefore, this (r)ow-preferential kernel is well-suited for contiguous
+   (v)ector loads on B and single-element broadcasts from A.
+
+   NOTE: These kernels explicitly support column-oriented IO, implemented
+   via an in-register transpose. Thus, they also support the crr and ccr
+   cases, though only crr is ever utilized (because ccr is handled by
+   transposing the operation and executing rcr, which does not incur the
+   cost of the in-register transpose).
+
+   crr:
+     | | | | | | | |       ------        --------
+     | | | | | | | |       ------        --------
+     | | | | | | | |  +=   ------ ...    --------
+     | | | | | | | |       ------        --------
+     | | | | | | | |       ------            :
+     | | | | | | | |       ------            :
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
+
+#if 0
+// Define parameters and variables for edge case kernel map.
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{         /*               8                                 4                                 2                                1   */
+/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8m, bli_dgemmsup_rv_haswell_asm_6x4m, bli_dgemmsup_rv_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 },
+/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8m, bli_dgemmsup_rv_haswell_asm_4x4m, bli_dgemmsup_rv_haswell_asm_4x2m, bli_dgemmsup_r_haswell_ref_4x1 },
+/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8m, bli_dgemmsup_rv_haswell_asm_2x4m, bli_dgemmsup_rv_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 },
+/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8m, bli_dgemmsup_rv_haswell_asm_1x4m, bli_dgemmsup_rv_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 },
+};
+#endif
+
+
+void bli_dgemmsup_rv_haswell_asm_6x8m
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	uint64_t n_left = n0 % 8;
+
+	// First check whether this is an edge case in the n dimension. If so,
+	// dispatch other 6x?m kernels, as needed.
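+	// For example: n0 = 7 yields n_left = 7, which is handled as one 6x6m call
+	// for six columns followed by one dgemv call for the final column; n0 = 5
+	// decomposes as 4 + 1. Only n0 = 8 (n_left = 0) falls through to the
+	// assembly below. (This sketch relies on the assumption, stated at the top
+	// of this file, that n0 is at most NR = 8.)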
+ if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 6 <= n_left ) + { + const dim_t nr_cur = 6; + + bli_dgemmsup_rv_haswell_asm_6x6m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
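+	                                   // (At roughly 3 per cycle, the twelve
+	                                   // vxorpd instructions below cost about
+	                                   // 4 cycles of issue, versus ~12 cycles
+	                                   // for a single vzeroall.)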
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. +#else + lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; + //mov(r9, rsi) // rsi = cs_a; + //sal(imm(4), rsi) // rsi = 16*cs_a; + //lea(mem(rax, rsi, 1), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
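+
+	                                   // Note: within the main loop below, rdx
+	                                   // walks the next upanel of a (it was set
+	                                   // above to a + 6*rs_a) and advances by
+	                                   // 4*cs_a per unrolled pass, so that each
+	                                   // of the four iterations issues one
+	                                   // prefetch into rows this kernel will
+	                                   // read on the next .DLOOP6X8I pass.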
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) + //prefetch(0, mem(rax, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + //prefetch(0, mem(rax, 5*8)) +#else +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) + //prefetch(0, mem(rax, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) + //prefetch(0, mem(rdx, r9, 2)) + //lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + //prefetch(0, mem(rax, 5*8)) +#else + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + 
vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r11) // ii -= 1; + jne(.DLOOP6X8I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + i_edge*rs_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. 
+ if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x8; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x8; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x8; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + +#if 1 + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x8, + bli_dgemmsup_rv_haswell_asm_2x8, + bli_dgemmsup_rv_haswell_asm_3x8, + bli_dgemmsup_rv_haswell_asm_4x8, + bli_dgemmsup_rv_haswell_asm_5x8 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; +#else + if ( 5 <= m_left ) + { + const dim_t mr_cur = 5; + + bli_dgemmsup_rv_haswell_asm_5x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 4 <= m_left ) + { + const dim_t mr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rv_haswell_asm_3x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rv_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } +#endif + } +} + +void bli_dgemmsup_rv_haswell_asm_6x6m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
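+	// (Annotation: the asm block below reads these variables into
+	// 64-bit registers, e.g. mov(var(k_iter), rsi), so each local
+	// copy must be exactly eight bytes wide regardless of how dim_t
+	// and inc_t happen to be configured.)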
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower + vxorpd(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us + vxorpd(ymm5, ymm5, ymm5) // down. + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. +#else + lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; + //mov(r9, rsi) // rsi = cs_a; + //sal(imm(4), rsi) // rsi = 16*cs_a; + //lea(mem(rax, rsi, 1), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
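+	// NOTE: in this 6x6 variant each row i of the C microtile is
+	// accumulated in a register pair: ymm(4+2i) holds columns 0..3,
+	// while only the lower half (xmm) of ymm(5+2i) is used, holding
+	// columns 4..5 -- hence the paired ymm0/xmm1 loads from b below.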
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) + //prefetch(0, mem(rax, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + //prefetch(0, mem(rax, 5*8)) +#else +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) + //prefetch(0, mem(rax, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) + //prefetch(0, mem(rdx, r9, 2)) + //lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + //prefetch(0, mem(rax, 5*8)) +#else + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + 
vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + //vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + //vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(xmm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(xmm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(xmm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(xmm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(xmm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + //vmovupd(ymm9, mem(rcx, rsi, 2)) + //vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + //vextractf128(imm(0x1), ymm0, xmm2) + //vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + //vmovupd(xmm2, mem(rdx, rsi, 2)) + //vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r11) // ii -= 1; + jne(.DLOOP6X8I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
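+	// (Annotation: any m_left of 1..5 is dispatched through the
+	// function pointer table in the active #if 1 branch below to the
+	// corresponding fixed-m 1x6 ... 5x6 kernel; the disabled branches
+	// are retained only as alternative edge-case schemes.)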
+ if ( m_left ) + { + const dim_t nr_cur = 6; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + i_edge*rs_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x6; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x6; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x6; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x6; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + +#if 1 + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x6, + bli_dgemmsup_rv_haswell_asm_2x6, + bli_dgemmsup_rv_haswell_asm_3x6, + bli_dgemmsup_rv_haswell_asm_4x6, + bli_dgemmsup_rv_haswell_asm_5x6 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; +#else + if ( 5 <= m_left ) + { + const dim_t mr_cur = 5; + + bli_dgemmsup_rv_haswell_asm_5x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 4 <= m_left ) + { + const dim_t mr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rv_haswell_asm_3x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rv_haswell_asm_1x6 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } +#endif + } +} + +void bli_dgemmsup_rv_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is 
expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) + + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + + +#if 1 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + //lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. + //lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + 
prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) // rdx += cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
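+	// NOTE: the column-stored case below first performs a 4x4
+	// in-register transpose of rows 0..3: vunpcklpd/vunpckhpd
+	// interleave elements within each 128-bit lane, then
+	// vinsertf128/vperm2f128 exchange the lanes, so rows 0..3 of
+	// each column can be updated with a single ymm access. Rows
+	// 4..5 are transposed separately into xmm pairs via
+	// vextractf128.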
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r11) // ii -= 1; + jne(.DLOOP6X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
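+	// (Annotation: same edge-case dispatch pattern as the wider
+	// variants, here with nr_cur = 4: an m_left of 1..5 is routed
+	// via the table in the #if 1 branch below to the fixed-m
+	// 1x4 ... 5x4 kernel.)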
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + i_edge*rs_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. + if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x4; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x4; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x4; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x4; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + +#if 1 + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x4, + bli_dgemmsup_rv_haswell_asm_2x4, + bli_dgemmsup_rv_haswell_asm_3x4, + bli_dgemmsup_rv_haswell_asm_4x4, + bli_dgemmsup_rv_haswell_asm_5x4 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; +#else + if ( 5 <= m_left ) + { + const dim_t mr_cur = 5; + + bli_dgemmsup_rv_haswell_asm_5x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 4 <= m_left ) + { + const dim_t mr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rv_haswell_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rv_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } +#endif + } +} + +void bli_dgemmsup_rv_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is 
expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(xmm4, xmm4, xmm4) + vxorpd(xmm6, xmm6, xmm6) + vxorpd(xmm8, xmm8, xmm8) + vxorpd(xmm10, xmm10, xmm10) + vxorpd(xmm12, xmm12, xmm12) + vxorpd(xmm14, xmm14, xmm14) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) + + +#if 0 + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c +#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c +#endif + + +#if 1 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + //lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. + //lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + 
vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) // rdx += cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
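+	// NOTE: with only two columns, the in-register transpose below
+	// needs just vunpcklpd/vunpckhpd plus vinsertf128 to merge rows
+	// 0..3 into ymm column vectors; no vperm2f128 is required, and
+	// rows 4..5 remain as xmm pairs stored with xmm accesses.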
+ + + + label(.DCOLSTORED) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r11) // ii -= 1; + jne(.DLOOP6X2I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + i_edge*rs_a; + double* restrict bj = b; + +#if 0 + // We add special handling for slightly inflated MR blocksizes + // at edge cases, up to a maximum of 9. 
+ if ( 6 < m_left ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m_left == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x2; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x2; + } + else if ( m_left == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x2; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x2; + } + else // if ( m_left == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x2; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x2; + } + + ker_fp1 + ( + conja, conjb, mr1, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + +#if 1 + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x2, + bli_dgemmsup_rv_haswell_asm_2x2, + bli_dgemmsup_rv_haswell_asm_3x2, + bli_dgemmsup_rv_haswell_asm_4x2, + bli_dgemmsup_rv_haswell_asm_5x2 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; +#else + if ( 5 <= m_left ) + { + const dim_t mr_cur = 5; + + bli_dgemmsup_rv_haswell_asm_5x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 4 <= m_left ) + { + const dim_t mr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rv_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rv_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } +#endif + } +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c new file mode 100644 index 000000000..abce959a2 --- /dev/null +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c @@ -0,0 +1,4117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+   - Redistributions in binary form must reproduce the above copyright
+     notice, this list of conditions and the following disclaimer in the
+     documentation and/or other materials provided with the distribution.
+   - Neither the name(s) of the copyright holder(s) nor the names of its
+     contributors may be used to endorse or promote products derived
+     from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+   rrr:
+     --------        ------        --------
+     --------        ------        --------
+     --------   +=   ------ ...    --------
+     --------        ------        --------
+     --------        ------            :
+     --------        ------            :
+
+   rcr:
+     --------        | | | |       --------
+     --------        | | | |       --------
+     --------   +=   | | | | ...   --------
+     --------        | | | |       --------
+     --------        | | | |           :
+     --------        | | | |           :
+
+   Assumptions:
+   - B is row-stored;
+   - A is row- or column-stored;
+   - m0 is at most MR (slightly inflated values up to 9 are handled by
+     the edge-case dispatch below), while n0 may be arbitrarily large,
+     since these "n" kernels iterate over n in NR-sized chunks.
+   Therefore, this (r)ow-preferential kernel is well-suited for contiguous
+   (v)ector loads on B and single-element broadcasts from A.
+
+   NOTE: These kernels explicitly support column-oriented IO, implemented
+   via an in-register transpose. And thus they also support the crr and
+   ccr cases, though only crr is ever utilized (because ccr is handled by
+   transposing the operation and executing rcr, which does not incur the
+   cost of the in-register transpose).
+
+   crr:
+     | | | | | | | |       ------        --------
+     | | | | | | | |       ------        --------
+     | | | | | | | |  +=   ------ ...    --------
+     | | | | | | | |       ------        --------
+     | | | | | | | |       ------            :
+     | | | | | | | |       ------            :
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( double,   d, gemmsup_r_haswell_ref )
+
+#if 0
+// Define parameters and variables for edge case kernel map.
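+// (Currently disabled.) The table below would map each supported (MR, NR)
+// pair to its fixed-size kernel, allowing edge cases to be dispatched by
+// table lookup rather than by the if/else ladders used further below.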
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{         /*               8                                 4                                 2                               1  */
+/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8n, bli_dgemmsup_rv_haswell_asm_6x4n, bli_dgemmsup_rv_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref_6x1 },
+/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8n, bli_dgemmsup_rv_haswell_asm_4x4n, bli_dgemmsup_rv_haswell_asm_4x2n, bli_dgemmsup_r_haswell_ref_4x1 },
+/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8n, bli_dgemmsup_rv_haswell_asm_2x4n, bli_dgemmsup_rv_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref_2x1 },
+/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8n, bli_dgemmsup_rv_haswell_asm_1x4n, bli_dgemmsup_rv_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref_1x1 },
+};
+#endif
+
+
+void bli_dgemmsup_rv_haswell_asm_6x8n
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	uint64_t m_left = m0 % 6;
+
+#if 0
+	bli_dgemmsup_r_haswell_ref
+	(
+	  conja, conjb, m0, n0, k0,
+	  alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0,
+	  beta, c, rs_c0, cs_c0, data, cntx
+	); return;
+#endif
+
+//printf( "rv_6x8n: %d %d %d\n", (int)m0, (int)n0, (int)k0 );
+	// First check whether this is an edge case in the m dimension. If so,
+	// dispatch other ?x8n kernels, as needed.
+	if ( m_left )
+	{
+		double* restrict cij = c;
+		double* restrict bj  = b;
+		double* restrict ai  = a;
+
+#if 1
+		// We add special handling for slightly inflated MR blocksizes
+		// at edge cases, up to a maximum of 9.
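+		// Splitting m0 = 7, 8, or 9 into two sub-kernel calls (4+3, 4+4,
+		// or 4+5 rows) presumably balances the two calls better than a
+		// 6-row call followed by a 1-, 2-, or 3-row remainder would.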
+ if ( 6 < m0 ) + { + dgemmsup_ker_ft ker_fp1 = NULL; + dgemmsup_ker_ft ker_fp2 = NULL; + dim_t mr1, mr2; + + if ( m0 == 7 ) + { + mr1 = 4; mr2 = 3; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8n; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x8n; + } + else if ( m0 == 8 ) + { + mr1 = 4; mr2 = 4; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8n; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x8n; + } + else // if ( m0 == 9 ) + { + mr1 = 4; mr2 = 5; + ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8n; + ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x8n; + } + + ker_fp1 + ( + conja, conjb, mr1, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr1*rs_c0; ai += mr1*rs_a0; + + ker_fp2 + ( + conja, conjb, mr2, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +#endif + +#if 1 + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x8n, + bli_dgemmsup_rv_haswell_asm_2x8n, + bli_dgemmsup_rv_haswell_asm_3x8n, + bli_dgemmsup_rv_haswell_asm_4x8n, + bli_dgemmsup_rv_haswell_asm_5x8n + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; +#else + if ( 5 <= m_left ) + { + const dim_t mr_cur = 5; + + bli_dgemmsup_rv_haswell_asm_5x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 4 <= m_left ) + { + const dim_t mr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rv_haswell_asm_3x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { +#if 1 + const dim_t mr_cur = 1; + + bli_dgemmsup_rv_haswell_asm_1x8n + ( + conja, conjb, mr_cur, n0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_TRANSPOSE, conja, k0, n0, + alpha, bj, rs_b0, cs_b0, ai, cs_a0, + beta, cij, cs_c0, cntx, NULL + ); +#endif + } + return; +#endif + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. 
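+	// Each stride is loaded below in units of elements and immediately
+	// scaled to units of bytes (sizeof(double) == 8) via lea.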
+	mov(var(rs_a), r8)                 // load rs_a
+	mov(var(cs_a), r9)                 // load cs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+	lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+	lea(mem(r8, r8, 4), r15)           // r15 = 5*rs_a
+
+	mov(var(b), r14)                   // load address of b.
+	mov(var(rs_b), r10)                // load rs_b
+	//mov(var(cs_b), r11)              // load cs_b
+	lea(mem(, r10, 8), r10)            // rs_b *= sizeof(double)
+	//lea(mem(, r11, 8), r11)          // cs_b *= sizeof(double)
+
+	                                   // NOTE: We cannot pre-load elements of a or b
+	                                   // because doing so could eventually, in the last
+	                                   // unrolled iter or the cleanup loop, result in
+	                                   // reading beyond the bounds of the allocated
+	                                   // memory (the likely result: a segmentation
+	                                   // fault).
+
+	mov(var(c), r12)                   // load address of c
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+	// During preamble and loops:
+	// r12 = rcx = c
+	// r14 = rbx = b
+	// read rax from var(a) near beginning of loop
+	// r11 = n dim index jj
+
+	mov(var(n_iter), r11)              // jj = n_iter;
+
+	label(.DLOOP6X8J)                  // LOOP OVER jj = [ n_iter ... 1 0 ]
+
+
+
+#if 0
+	vzeroall()                         // zero all xmm/ymm registers.
+#else
+	                                   // Skylake can execute 3 vxorpd ipc with
+	                                   // a latency of 1 cycle, while vzeroall
+	                                   // has a latency of 12 cycles.
+	vxorpd(ymm4, ymm4, ymm4)
+	vxorpd(ymm5, ymm5, ymm5)
+	vxorpd(ymm6, ymm6, ymm6)
+	vxorpd(ymm7, ymm7, ymm7)
+	vxorpd(ymm8, ymm8, ymm8)
+	vxorpd(ymm9, ymm9, ymm9)
+	vxorpd(ymm10, ymm10, ymm10)
+	vxorpd(ymm11, ymm11, ymm11)
+	vxorpd(ymm12, ymm12, ymm12)
+	vxorpd(ymm13, ymm13, ymm13)
+	vxorpd(ymm14, ymm14, ymm14)
+	vxorpd(ymm15, ymm15, ymm15)
+#endif
+
+	mov(var(a), rax)                   // load address of a.
+	//mov(r12, rcx)                    // reset rcx to current utile of c.
+	mov(r14, rbx)                      // reset rbx to current upanel of b.
+
+
+
+	cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+	jz(.DCOLPFETCH)                    // jump to column storage case
+	label(.DROWPFETCH)                 // row-stored prefetching of c
+
+	lea(mem(r12, rdi, 2), rdx)         //
+	lea(mem(rdx, rdi, 1), rdx)         // rdx = c + 3*rs_c;
+	prefetch(0, mem(r12, 7*8))         // prefetch c + 0*rs_c
+	prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c
+	prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c
+	prefetch(0, mem(rdx, 7*8))         // prefetch c + 3*rs_c
+	prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c
+	prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c
+
+	jmp(.DPOSTPFETCH)                  // jump to end of prefetching c
+	label(.DCOLPFETCH)                 // column-stored prefetching of c
+
+	mov(var(cs_c), rsi)                // load cs_c to rsi (temporarily)
+	lea(mem(, rsi, 8), rsi)            // cs_c *= sizeof(double)
+	lea(mem(r12, rsi, 2), rdx)         //
+	lea(mem(rdx, rsi, 1), rdx)         // rdx = c + 3*cs_c;
+	prefetch(0, mem(r12, 5*8))         // prefetch c + 0*cs_c
+	prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c
+	prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c
+	prefetch(0, mem(rdx, 5*8))         // prefetch c + 3*cs_c
+	prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c
+	prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c
+	lea(mem(rdx, rsi, 2), rdx)         // rdx = c + 5*cs_c;
+	prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c
+	prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c
+
+	label(.DPOSTPFETCH)                // done prefetching c
+
+#if 1
+	                                   // use byte offsets from rbx to
+	                                   // prefetch lines from next upanel
+	                                   // of b.
+#endif
+
+
+
+
+	mov(var(k_iter), rsi)              // i = k_iter;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT)                  // if i == 0, jump to code that
+	                                   // contains the k_left loop.
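+	// The main loop below is unrolled by a factor of four. Each iteration
+	// loads one 1x8 row of b into ymm0/ymm1, broadcasts the six elements
+	// of the current column of a, and issues twelve fused multiply-adds
+	// into the accumulators ymm4-ymm15.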
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + //prefetch(0, mem(rdx, r10, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + //prefetch(0, mem(rdx, r10, 2, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + //prefetch(0, mem(rdx, rcx, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) + //prefetch(0, mem(rdx, r9, 1, 7*8)) + //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + 
vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+	jz(.DCOLSTORBZ)                    // jump to column storage case
+
+
+
+	label(.DROWSTORBZ)
+
+
+	vmovupd(ymm4, mem(rcx))
+	vmovupd(ymm5, mem(rcx, rsi, 4))
+	add(rdi, rcx)
+
+
+	vmovupd(ymm6, mem(rcx))
+	vmovupd(ymm7, mem(rcx, rsi, 4))
+	add(rdi, rcx)
+
+
+	vmovupd(ymm8, mem(rcx))
+	vmovupd(ymm9, mem(rcx, rsi, 4))
+	add(rdi, rcx)
+
+
+	vmovupd(ymm10, mem(rcx))
+	vmovupd(ymm11, mem(rcx, rsi, 4))
+	add(rdi, rcx)
+
+
+	vmovupd(ymm12, mem(rcx))
+	vmovupd(ymm13, mem(rcx, rsi, 4))
+	add(rdi, rcx)
+
+
+	vmovupd(ymm14, mem(rcx))
+	vmovupd(ymm15, mem(rcx, rsi, 4))
+	//add(rdi, rcx)
+
+
+	jmp(.DDONE)                        // jump to end.
+
+
+
+	label(.DCOLSTORBZ)
+
+
+	vunpcklpd(ymm6, ymm4, ymm0)
+	vunpckhpd(ymm6, ymm4, ymm1)
+	vunpcklpd(ymm10, ymm8, ymm2)
+	vunpckhpd(ymm10, ymm8, ymm3)
+	vinsertf128(imm(0x1), xmm2, ymm0, ymm4)
+	vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
+	vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
+	vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
+
+	vmovupd(ymm4, mem(rcx))
+	vmovupd(ymm6, mem(rcx, rsi, 1))
+	vmovupd(ymm8, mem(rcx, rsi, 2))
+	vmovupd(ymm10, mem(rcx, rax, 1))
+
+	lea(mem(rcx, rsi, 4), rcx)
+
+	vunpcklpd(ymm14, ymm12, ymm0)
+	vunpckhpd(ymm14, ymm12, ymm1)
+	vextractf128(imm(0x1), ymm0, xmm2)
+	vextractf128(imm(0x1), ymm1, xmm4)
+
+	vmovupd(xmm0, mem(rdx))
+	vmovupd(xmm1, mem(rdx, rsi, 1))
+	vmovupd(xmm2, mem(rdx, rsi, 2))
+	vmovupd(xmm4, mem(rdx, rax, 1))
+
+	lea(mem(rdx, rsi, 4), rdx)
+
+
+	vunpcklpd(ymm7, ymm5, ymm0)
+	vunpckhpd(ymm7, ymm5, ymm1)
+	vunpcklpd(ymm11, ymm9, ymm2)
+	vunpckhpd(ymm11, ymm9, ymm3)
+	vinsertf128(imm(0x1), xmm2, ymm0, ymm5)
+	vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
+	vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
+	vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
+
+	vmovupd(ymm5, mem(rcx))
+	vmovupd(ymm7, mem(rcx, rsi, 1))
+	vmovupd(ymm9, mem(rcx, rsi, 2))
+	vmovupd(ymm11, mem(rcx, rax, 1))
+
+	//lea(mem(rcx, rsi, 4), rcx)
+
+	vunpcklpd(ymm15, ymm13, ymm0)
+	vunpckhpd(ymm15, ymm13, ymm1)
+	vextractf128(imm(0x1), ymm0, xmm2)
+	vextractf128(imm(0x1), ymm1, xmm4)
+
+	vmovupd(xmm0, mem(rdx))
+	vmovupd(xmm1, mem(rdx, rsi, 1))
+	vmovupd(xmm2, mem(rdx, rsi, 2))
+	vmovupd(xmm4, mem(rdx, rax, 1))
+
+	//lea(mem(rdx, rsi, 4), rdx)
+
+
+
+
+	label(.DDONE)
+
+
+
+
+	lea(mem(r12, rsi, 8), r12)         // c_jj = r12 += 8*cs_c
+
+	add(imm(8*8), r14)                 // b_jj = r14 += 8*cs_b
+
+	dec(r11)                           // jj -= 1;
+	jne(.DLOOP6X8J)                    // iterate again if jj != 0.
+
+
+
+
+	label(.DRETURN)
+
+
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	  [n_iter] "m" (n_iter),
+	  [k_iter] "m" (k_iter),
+	  [k_left] "m" (k_left),
+	  [a]      "m" (a),
+	  [rs_a]   "m" (rs_a),
+	  [cs_a]   "m" (cs_a),
+	  [b]      "m" (b),
+	  [rs_b]   "m" (rs_b),
+	  [cs_b]   "m" (cs_b),
+	  [alpha]  "m" (alpha),
+	  [beta]   "m" (beta),
+	  [c]      "m" (c),
+	  [rs_c]   "m" (rs_c),
+	  [cs_c]   "m" (cs_c)/*,
+	  [a_next] "m" (a_next),
+	  [b_next] "m" (b_next)*/
+	: // register clobber list
+	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	  "xmm0", "xmm1", "xmm2", "xmm3",
+	  "xmm4", "xmm5", "xmm6", "xmm7",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
+	  "xmm12", "xmm13", "xmm14", "xmm15",
+	  "memory"
+	)
+
+	consider_edge_cases:
+
+	// Handle edge cases in the n dimension, if they exist.
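+	// Any leftover columns (n0 % 8) are swept up with progressively
+	// narrower fixed-width kernels: 6x4, then 6x2, then a 6x1 reference
+	// kernel.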
+ if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_6x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_6x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_6x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rv_haswell_asm_5x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP6X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
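+	// (Only ymm4-ymm13 are cleared here; this 5x8 kernel accumulates into
+	// ten registers, leaving ymm14/ymm15 unused.)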
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + //vxorpd(ymm14, ymm14, ymm14) + //vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + + // use byte offsets from rbx to + // prefetch lines from next upanel + // of b. +#else + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + lea(mem(rbx, r10, 8), rdx) // use rdx for prefetching b. + lea(mem(rdx, r10, 8), rdx) // rdx = b + 16*rs_b; + + #if 0 + mov(r9, rsi) // rsi = rs_b; + sal(imm(5), rsi) // rsi = 16*rs_b; + lea(mem(rax, rsi, 1), rdx) // rdx = b + 16*rs_b; + #endif +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
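+	// As in the 6x8 kernel, the main loop is unrolled by four; each
+	// iteration broadcasts five elements of a and issues ten FMAs into
+	// ymm4-ymm13.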
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 1 + //prefetch(0, mem(rdx, r10, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + //prefetch(0, mem(rdx, r10, 2, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + //prefetch(0, mem(rdx, rcx, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) + //prefetch(0, mem(rdx, r9, 1, 7*8)) + //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // 
iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
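+	// In the column-stored case below, rows 0-3 are transposed in
+	// registers just as in the 6x8 kernel, while the lone fifth row
+	// (ymm12/ymm13) is updated with scalar vmovlpd/vmovhpd gathers and
+	// scatters of the corresponding elements of c.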
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovlpd(mem(rdx), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm13, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + +#if 0 + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) +#else + vmovupd(ymm13, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx)) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) +#endif + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + + dec(r11) // jj -= 1; + jne(.DLOOP6X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( n_left ) + { + const dim_t mr_cur = 5; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_5x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_5x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_5x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rv_haswell_asm_4x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP4X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
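+	// (Only ymm4-ymm11 are cleared here; this 4x8 kernel accumulates into
+	// eight registers.)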
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + + // use byte offsets from rbx to + // prefetch lines from next upanel + // of b. +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
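+	// Main loop for the 4x8 case: again unrolled by four, with each
+	// iteration broadcasting four elements of a and issuing eight FMAs
+	// into ymm4-ymm11.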
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 1 + //prefetch(0, mem(rdx, r10, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + //prefetch(0, mem(rdx, r10, 2, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + //prefetch(0, mem(rdx, rcx, 1, 11*8)) + //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
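+	// The edge loop below processes the k0 % 4 leftover rank-1 updates
+	// one at a time, using the same load/broadcast/FMA pattern as the
+	// main loop, but without unrolling.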
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
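+	// With exactly four rows, the column-stored case below reduces to two
+	// clean in-register 4x4 transposes (one per four-column half of the
+	// microtile), with no leftover rows to handle separately. The disabled
+	// sketch that follows illustrates the same transpose idiom in AVX
+	// intrinsics; it is only an illustration (it assumes immintrin.h and
+	// a hypothetical helper name) and is never compiled or called.
+#if 0
+	// Transpose a 4x4 block of doubles: rows r0-r3 in, columns c0-c3 out.
+	static inline void dtranspose_4x4
+	     (
+	       __m256d  r0, __m256d  r1, __m256d  r2, __m256d  r3,
+	       __m256d* c0, __m256d* c1, __m256d* c2, __m256d* c3
+	     )
+	{
+		__m256d t0 = _mm256_unpacklo_pd( r0, r1 );    // r00 r10 r02 r12
+		__m256d t1 = _mm256_unpackhi_pd( r0, r1 );    // r01 r11 r03 r13
+		__m256d t2 = _mm256_unpacklo_pd( r2, r3 );    // r20 r30 r22 r32
+		__m256d t3 = _mm256_unpackhi_pd( r2, r3 );    // r21 r31 r23 r33
+
+		*c0 = _mm256_permute2f128_pd( t0, t2, 0x20 ); // r00 r10 r20 r30
+		*c1 = _mm256_permute2f128_pd( t1, t3, 0x20 ); // r01 r11 r21 r31
+		*c2 = _mm256_permute2f128_pd( t0, t2, 0x31 ); // r02 r12 r22 r32
+		*c3 = _mm256_permute2f128_pd( t1, t3, 0x31 ); // r03 r13 r23 r33
+	}
+#endif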
+ + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + + dec(r11) // jj -= 1; + jne(.DLOOP4X8J) // iterate again if jj != 0. 
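+	                                   // Note that the immediate advance of
+	                                   // r14 by 8*8 bytes above is consistent
+	                                   // with cs_b == 1 (contiguous elements
+	                                   // within each row of b), i.e. 8*cs_b
+	                                   // doubles; the contiguous vector loads
+	                                   // from b in the k loops presume exactly
+	                                   // that layout.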
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 4; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_4x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_4x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_4x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rv_haswell_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP4X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + + // use byte offsets from rbx to + // prefetch lines from next upanel + // of b. +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
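+	                                   // The prefetch(0, mem(rbx, 11*8))
+	                                   // instances in the loops below touch
+	                                   // memory 88 bytes past the current row
+	                                   // of b -- presumably far enough ahead
+	                                   // to pull in upcoming rows without
+	                                   // maintaining a separate prefetch
+	                                   // pointer (rdx), as the commented-out
+	                                   // variant did.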
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 1 + //prefetch(0, mem(rdx, r10, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + //prefetch(0, mem(rdx, r10, 2, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + //prefetch(0, mem(rdx, rcx, 1, 11*8)) + //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. 
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + 
vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx)) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx)) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx)) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + + dec(r11) // jj -= 1; + jne(.DLOOP4X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_3x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_3x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rv_haswell_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP2X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. 
+ //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + + // use byte offsets from rbx to + // prefetch lines from next upanel + // of b. +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 1 + //prefetch(0, mem(rdx, r10, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + //prefetch(0, mem(rdx, r10, 2, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + //prefetch(0, mem(rdx, rcx, 1, 11*8)) + //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + 
vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx)) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + + dec(r11) // jj -= 1; + jne(.DLOOP2X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_2x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_2x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rv_haswell_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t n_iter = n0 / 8; + uint64_t n_left = n0 % 8; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rbx = b + // read rax from var(a) near beginning of loop + // r11 = m dim index ii + + mov(var(n_iter), r11) // jj = n_iter; + + label(.DLOOP1X8J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rbx) // reset rbx to current upanel of b. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(r12, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + +#if 1 + //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + + // use byte offsets from rbx to + // prefetch lines from next upanel + // of b. +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
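+	                                   // With a single row of a, each k step
+	                                   // in the loops below is effectively
+	                                   //   for ( j = 0; j < 8; ++j )
+	                                   //     ab( 0, j ) += a( 0, l ) * b( l, j );
+	                                   // with ymm4/ymm5 holding columns 0:3
+	                                   // and 4:7 of ab (illustrative only).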
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 1 + //prefetch(0, mem(rdx, r10, 1, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + //prefetch(0, mem(rdx, r10, 2, 11*8)) + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + //prefetch(0, mem(rdx, rcx, 1, 11*8)) + //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rbx, 11*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + + vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
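+	                                   // The column-stored path below gathers
+	                                   // c( 0, 0:7 ) one element at a time
+	                                   // (vmovlpd/vmovhpd), applies
+	                                   //   c = ab + beta * c
+	                                   // via vfmadd213pd, and scatters the
+	                                   // result back at stride cs_c.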
+ + + + label(.DCOLSTORED) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm5, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 4)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + vmovupd(ymm5, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx)) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rsi, 8), r12) // c_jj = r12 += 8*cs_c + + add(imm(8*8), r14) // b_jj = r14 += 8*cs_b + + dec(r11) // jj -= 1; + jne(.DLOOP1X8J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rv_haswell_asm_1x4 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rv_haswell_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref_1x1 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_ddotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); +#endif + } + } +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c new file mode 100644 index 000000000..c5addd9cf --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -0,0 +1,5249 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. 
+ Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8, bli_dgemmsup_rd_haswell_asm_6x4, bli_dgemmsup_rd_haswell_asm_6x2, bli_dgemmsup_r_haswell_ref }, +/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8, bli_dgemmsup_rd_haswell_asm_3x4, bli_dgemmsup_rd_haswell_asm_3x2, bli_dgemmsup_r_haswell_ref }, +/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8, bli_dgemmsup_rd_haswell_asm_2x4, bli_dgemmsup_rd_haswell_asm_2x2, bli_dgemmsup_r_haswell_ref }, +/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8, bli_dgemmsup_rd_haswell_asm_1x4, bli_dgemmsup_rd_haswell_asm_1x2, bli_dgemmsup_r_haswell_ref } +}; + + +void bli_dgemmsup_rd_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Use a reference kernel if this is an edge case in the m or n + // dimensions. + if ( m0 < 6 || n0 < 8 ) + { + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. + if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + +#if 1 + mov(imm(0), r9) // ii = 0; + + label(.DLOOP3X4I) // LOOP OVER ii = [ 0 1 ... ] + + + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(rdi, rsi) // rsi *= rs_c; + lea(mem(r10, rsi, 1), rdx) // rdx = c_jj + 3*ii*rs_c; + + lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; + imul(r8, rsi) // rsi *= rs_a; + lea(mem(r12, rsi, 1), r12) // rax = a + 3*ii*rs_a; + + + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + vzeroall() // zero all xmm/ymm registers. + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(rdx, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a_ii; +#endif + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
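+	                                   // The k loops below maintain a 3x4 grid
+	                                   // of vector accumulators,
+	                                   //   ymm4 ymm7 ymm10 ymm13   (row 0 of a)
+	                                   //   ymm5 ymm8 ymm11 ymm14   (row 1)
+	                                   //   ymm6 ymm9 ymm12 ymm15   (row 2)
+	                                   // each holding four lanes of a partial
+	                                   //   dot( a( i, : ), b( :, j ) )
+	                                   // that are horizontally reduced after
+	                                   // the k loops.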
+
+
+	label(.DLOOPKITER16)               // MAIN LOOP
+
+
+	// ---------------------------------- iteration 0
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 1
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 2
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 3
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+	label(.DCONSIDKITER4)
+
+	mov(var(k_iter4), rsi)             // i = k_iter4;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+	                                   // considers k_left1 loop.
+	                                   // else, we prepare to enter k_iter4 loop.
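+	                                   // k0 is consumed in three tiers: 16
+	                                   // elements per pass in the main loop
+	                                   // above, 4 per pass in the ymm edge
+	                                   // loop below, and 1 per pass in the
+	                                   // scalar edge loop, so that
+	                                   //   k0 == 16*k_iter16 + 4*k_iter4 + k_left1.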
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here
+	                                   // because using the xmm registers would
+	                                   // zero out the high bits of the
+	                                   // destination registers, which would
+	                                   // destroy intermediate results.
+
+	vmovsd(mem(rax       ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	vmovsd(mem(rax, r8, 2), xmm2)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
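+	                                   // Each vhaddpd/vextractf128/vaddpd
+	                                   // triple below reduces two vector
+	                                   // accumulators to two scalars, e.g.
+	                                   //   xmm0 = { sum(ymm4), sum(ymm7) },
+	                                   // and vperm2f128 then packs two such
+	                                   // pairs into one ymm row of the 3x4
+	                                   // result tile.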
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + +#if 1 + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
+ + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning +#endif + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r10, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+    label(.DLOOPKITER16)               // MAIN LOOP
+
+
+    // ---------------------------------- iteration 0
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 1
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 2
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 3
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+    label(.DCONSIDKITER4)
+
+    mov(var(k_iter4), rsi)             // i = k_iter4;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+                                       // considers k_left1 loop.
+                                       // else, we prepare to enter k_iter4 loop.
+
+
+    label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+    label(.DCONSIDKLEFT1)
+
+    mov(var(k_left1), rsi)             // i = k_left1;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+                                       // else, we prepare to enter k_left1 loop.
+
+
+
+
+    label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+                                       // NOTE: We must use ymm registers here
+                                       // because using the xmm registers would
+                                       // zero out the high bits of the
+                                       // destination registers, which would
+                                       // destroy intermediate results.
+
+    vmovsd(mem(rax        ), xmm0)
+    vmovsd(mem(rax, r8, 1), xmm1)
+    vmovsd(mem(rax, r8, 2), xmm2)
+    add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+    vmovsd(mem(rbx        ), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovsd(mem(rbx, r11, 1), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovsd(mem(rbx, r11, 2), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovsd(mem(rbx, r13, 1), xmm3)
+    add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r10, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+    label(.DLOOPKITER16)               // MAIN LOOP
+
+
+    // ---------------------------------- iteration 0
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+    // ---------------------------------- iteration 1
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+    // ---------------------------------- iteration 2
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+    // ---------------------------------- iteration 3
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+    label(.DCONSIDKITER4)
+
+    mov(var(k_iter4), rsi)             // i = k_iter4;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+                                       // considers k_left1 loop.
+                                       // else, we prepare to enter k_iter4 loop.
+
+
+    label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+    label(.DCONSIDKLEFT1)
+
+    mov(var(k_left1), rsi)             // i = k_left1;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+                                       // else, we prepare to enter k_left1 loop.
+
+
+
+
+    label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+                                       // NOTE: We must use ymm registers here
+                                       // because using the xmm registers would
+                                       // zero out the high bits of the
+                                       // destination registers, which would
+                                       // destroy intermediate results.
+
+    vmovsd(mem(rax        ), xmm0)
+    vmovsd(mem(rax, r8, 1), xmm1)
+    add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+    vmovsd(mem(rbx        ), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+
+    vmovsd(mem(rbx, r11, 1), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+
+    vmovsd(mem(rbx, r11, 2), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+
+    vmovsd(mem(rbx, r13, 1), xmm3)
+    add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+
+
+
+
+
+
+
+    label(.DPOSTACCUM)
+
+
+
+    // ymm4 ymm7 ymm10 ymm13
+    // ymm5 ymm8 ymm11 ymm14
+
+    vhaddpd( ymm7, ymm4, ymm0 )
+    vextractf128(imm(1), ymm0, xmm1 )
+    vaddpd( xmm0, xmm1, xmm0 )         // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+    vhaddpd( ymm13, ymm10, ymm2 )
+    vextractf128(imm(1), ymm2, xmm1 )
+    vaddpd( xmm2, xmm1, xmm2 )         // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+    vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+    vhaddpd( ymm8, ymm5, ymm0 )
+    vextractf128(imm(1), ymm0, xmm1 )
+    vaddpd( xmm0, xmm1, xmm0 )
+
+    vhaddpd( ymm14, ymm11, ymm2 )
+    vextractf128(imm(1), ymm2, xmm1 )
+    vaddpd( xmm2, xmm1, xmm2 )
+
+    vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+    // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+    // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+    mov(var(alpha), rax)               // load address of alpha
+    mov(var(beta), rbx)                // load address of beta
+    vbroadcastsd(mem(rax), ymm0)       // load alpha and duplicate
+    vbroadcastsd(mem(rbx), ymm3)       // load beta and duplicate
+
+    vmulpd(ymm0, ymm4, ymm4)           // scale by alpha
+    vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+    mov(var(cs_c), rsi)                // load cs_c
+    lea(mem(, rsi, 8), rsi)            // rsi = cs_c * sizeof(double)
+
+
+
+    // now avoid loading C if beta == 0
+
+    vxorpd(ymm0, ymm0, ymm0)           // set ymm0 to zero.
+    vucomisd(xmm0, xmm3)               // set ZF if beta == 0.
+    je(.DBETAZERO)                     // if ZF = 1, jump to beta == 0 case
+
+
+
+    label(.DROWSTORED)
+
+
+    vfmadd231pd(mem(rcx), ymm3, ymm4)
+    vmovupd(ymm4, mem(rcx))
+    add(rdi, rcx)
+
+    vfmadd231pd(mem(rcx), ymm3, ymm5)
+    vmovupd(ymm5, mem(rcx))
+    //add(rdi, rcx)
+
+
+
+    jmp(.DDONE)                        // jump to end.
+
+
+
+
+    label(.DBETAZERO)
+
+
+
+    label(.DROWSTORBZ)
+
+
+    vmovupd(ymm4, mem(rcx))
+    add(rdi, rcx)
+
+    vmovupd(ymm5, mem(rcx))
+    //add(rdi, rcx)
+
+
+
+
+    label(.DDONE)
+
+
+
+
+    add(imm(4), r15)                   // jj += 4;
+    cmp(imm(4), r15)                   // compare jj to 4
+    jle(.DLOOP3X4J)                    // if jj <= 4, jump to beginning
+                                       // of jj loop; otherwise, loop ends.
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r10) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r10 = rcx = c + // r12 = rax = a + // r14 = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + vzeroall() // zero all xmm/ymm registers. + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r10, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(r14, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem( , r12, 1), rax) // rax = a; + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
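+    // NOTE: With m0 = 1 only one row of A is streamed, so just one
+    // accumulator row (ymm4/ymm7/ymm10/ymm13) is live and rs_a goes
+    // unused; the remaining accumulators simply stay zero.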
+
+
+    label(.DLOOPKITER16)               // MAIN LOOP
+
+
+    // ---------------------------------- iteration 0
+
+    vmovupd(mem(rax        ), ymm0)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+    // ---------------------------------- iteration 1
+
+    vmovupd(mem(rax        ), ymm0)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+    // ---------------------------------- iteration 2
+
+    vmovupd(mem(rax        ), ymm0)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+    // ---------------------------------- iteration 3
+
+    vmovupd(mem(rax        ), ymm0)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+    label(.DCONSIDKITER4)
+
+    mov(var(k_iter4), rsi)             // i = k_iter4;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+                                       // considers k_left1 loop.
+                                       // else, we prepare to enter k_iter4 loop.
+
+
+    label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+    vmovupd(mem(rax        ), ymm0)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+    label(.DCONSIDKLEFT1)
+
+    mov(var(k_left1), rsi)             // i = k_left1;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+                                       // else, we prepare to enter k_left1 loop.
+
+
+
+
+    label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+                                       // NOTE: We must use ymm registers here
+                                       // because using the xmm registers would
+                                       // zero out the high bits of the
+                                       // destination registers, which would
+                                       // destroy intermediate results.
+
+    vmovsd(mem(rax        ), xmm0)
+    add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+    vmovsd(mem(rbx        ), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+
+    vmovsd(mem(rbx, r11, 1), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+
+    vmovsd(mem(rbx, r11, 2), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+
+    vmovsd(mem(rbx, r13, 1), xmm3)
+    add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r12) // load address of a. 
+    mov(var(rs_a), r8)                 // load rs_a
+    //mov(var(cs_a), r9)               // load cs_a
+    lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+    //lea(mem(, r9, 8), r9)            // cs_a *= sizeof(double)
+
+    //lea(mem(r8, r8, 2), r13)         // r13 = 3*rs_a
+    //lea(mem(r8, r8, 4), r15)         // r15 = 5*rs_a
+
+    mov(var(b), r14)                   // load address of b.
+    //mov(var(rs_b), r10)              // load rs_b
+    mov(var(cs_b), r11)                // load cs_b
+    //lea(mem(, r10, 8), r10)          // rs_b *= sizeof(double)
+    lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+    lea(mem(r11, r11, 2), r13)         // r13 = 3*cs_b
+
+
+    mov(var(c), r10)                   // load address of c
+    mov(var(rs_c), rdi)                // load rs_c
+    lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+    // r10 = rcx = c
+    // r12 = rax = a
+    // r14 = rbx = b
+    // r9  = m dim index ii
+    // r15 = n dim index jj
+
+    mov(imm(0), r9)                    // ii = 0;
+
+    label(.DLOOP3X4I)                  // LOOP OVER ii = [ 0 1 ... ]
+
+
+
+    vzeroall()                         // zero all xmm/ymm registers.
+
+    lea(mem( , r9, 1), rsi)            // rsi = r9 = 3*ii;
+    imul(rdi, rsi)                     // rsi *= rs_c;
+    lea(mem(r10, rsi, 1), rcx)         // rcx = c + 3*ii*rs_c;
+
+    lea(mem( , r9, 1), rsi)            // rsi = r9 = 3*ii;
+    imul(r8, rsi)                      // rsi *= rs_a;
+    lea(mem(r12, rsi, 1), rax)         // rax = a + 3*ii*rs_a;
+
+    lea(mem( , r14, 1), rbx)           // rbx = b;
+
+    prefetch(0, mem(rcx, 7*8))         // prefetch c + 0*rs_c
+    prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c
+    prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c
+
+
+
+
+    mov(var(k_iter16), rsi)            // i = k_iter16;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKITER4)                 // if i == 0, jump to code that
+                                       // contains the k_iter4 loop.
+
+
+    label(.DLOOPKITER16)               // MAIN LOOP
+
+
+    // ---------------------------------- iteration 0
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 1
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 2
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 3
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+    label(.DCONSIDKITER4)
+
+    mov(var(k_iter4), rsi)             // i = k_iter4;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+                                       // considers k_left1 loop.
+                                       // else, we prepare to enter k_iter4 loop.
+
+
+    label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+    label(.DCONSIDKLEFT1)
+
+    mov(var(k_left1), rsi)             // i = k_left1;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+                                       // else, we prepare to enter k_left1 loop.
+
+
+
+
+    label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+                                       // NOTE: We must use ymm registers here
+                                       // because using the xmm registers would
+                                       // zero out the high bits of the
+                                       // destination registers, which would
+                                       // destroy intermediate results.
+
+    vmovsd(mem(rax        ), xmm0)
+    vmovsd(mem(rax, r8, 1), xmm1)
+    vmovsd(mem(rax, r8, 2), xmm2)
+    add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+    vmovsd(mem(rbx        ), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovsd(mem(rbx, r11, 1), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovsd(mem(rbx, r11, 2), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovsd(mem(rbx, r13, 1), xmm3)
+    add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning + // of ii loop; otherwise, loop ends. 
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
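+    // NOTE: This kernel computes a single 3x4 microtile of C directly,
+    // without ii/jj blocking loops; the 6x8, 3x8, and 6x4 variants above
+    // apply the same inner-loop pattern while iterating over 3x4 panels.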
+
+
+    label(.DLOOPKITER16)               // MAIN LOOP
+
+
+    // ---------------------------------- iteration 0
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 1
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 2
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    // ---------------------------------- iteration 3
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+    label(.DCONSIDKITER4)
+
+    mov(var(k_iter4), rsi)             // i = k_iter4;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+                                       // considers k_left1 loop.
+                                       // else, we prepare to enter k_iter4 loop.
+
+
+    label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+    vmovupd(mem(rax        ), ymm0)
+    vmovupd(mem(rax, r8, 1), ymm1)
+    vmovupd(mem(rax, r8, 2), ymm2)
+    add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+    vmovupd(mem(rbx        ), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovupd(mem(rbx, r11, 1), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovupd(mem(rbx, r11, 2), ymm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovupd(mem(rbx, r13, 1), ymm3)
+    add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+    label(.DCONSIDKLEFT1)
+
+    mov(var(k_left1), rsi)             // i = k_left1;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+                                       // else, we prepare to enter k_left1 loop.
+
+
+
+
+    label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+                                       // NOTE: We must use ymm registers here
+                                       // because using the xmm registers would
+                                       // zero out the high bits of the
+                                       // destination registers, which would
+                                       // destroy intermediate results.
+
+    vmovsd(mem(rax        ), xmm0)
+    vmovsd(mem(rax, r8, 1), xmm1)
+    vmovsd(mem(rax, r8, 2), xmm2)
+    add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+    vmovsd(mem(rbx        ), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm4)
+    vfmadd231pd(ymm1, ymm3, ymm5)
+    vfmadd231pd(ymm2, ymm3, ymm6)
+
+    vmovsd(mem(rbx, r11, 1), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm7)
+    vfmadd231pd(ymm1, ymm3, ymm8)
+    vfmadd231pd(ymm2, ymm3, ymm9)
+
+    vmovsd(mem(rbx, r11, 2), xmm3)
+    vfmadd231pd(ymm0, ymm3, ymm10)
+    vfmadd231pd(ymm1, ymm3, ymm11)
+    vfmadd231pd(ymm2, ymm3, ymm12)
+
+    vmovsd(mem(rbx, r13, 1), xmm3)
+    add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+    vfmadd231pd(ymm0, ymm3, ymm13)
+    vfmadd231pd(ymm1, ymm3, ymm14)
+    vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+    dec(rsi)                           // i -= 1;
+    jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), ymm3, ymm5)
+ vmovupd(ymm5, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), 
ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovupd(mem(rax, r15, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovupd(mem(rax, r13, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rax, r8, 4), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovupd(mem(rax, r15, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovupd(mem(rax, r13, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rax, r8, 4), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovupd(mem(rax, r15, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovsd(mem(rax, r8, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovsd(mem(rax, r13, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rax, r8, 4), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovsd(mem(rax, r15, 1), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
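+
+ // Each ymm accumulator below holds four partial sums of a single
+ // dot product c(i,j). The .DPOSTACCUM sequence reduces two
+ // accumulators at a time: vhaddpd interleaves pairwise sums of the
+ // two registers, vextractf128 splits off the upper 128 bits, and
+ // vaddpd collapses the halves. In scalar terms (an illustrative
+ // sketch only, not kernel code):
+ //
+ //   double reduce4( const double acc[4] ) // one ymm accumulator
+ //   { return ( acc[0] + acc[1] ) + ( acc[2] + acc[3] ); }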
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
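+ // (vmovsd zeroes the unused upper lanes of xmm0/xmm1/xmm3, so the
+ // full-width FMA computes acc + 0*0 == acc in lanes 1..3, leaving
+ // the partial sums accumulated so far intact.)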
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
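+
+ // Addressing note: rax and rbx advance by a literal 4*8 bytes per
+ // four k values, so this rd kernel relies on contiguous storage
+ // along k -- unit column stride in A and unit row stride in B
+ // (A row-stored, B column-stored, per the file header). r8 (rs_a)
+ // selects the row of A and r11 (cs_b) the column of B being loaded.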
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. 
+ //mov(var(rs_a), r8) // load rs_a
+ //mov(var(cs_a), r9) // load cs_a
+ //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
+ //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double)
+
+ //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
+ //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
+
+ mov(var(b), rbx) // load address of b.
+ //mov(var(rs_b), r10) // load rs_b
+ mov(var(cs_b), r11) // load cs_b
+ //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+
+ // initialize loop by pre-loading
+ // a column of a.
+
+ mov(var(c), rcx) // load address of c
+ //mov(var(rs_c), rdi) // load rs_c
+ //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+ //lea(mem(rcx, rdi, 2), rdx) //
+ //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
+ prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c new file mode 100644 index 000000000..55ae6d0f9 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -0,0 +1,5543 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+ rrc:
+ -------- ------ | | | | | | | |
+ -------- ------ | | | | | | | |
+ -------- += ------ ... | | | | | | | |
+ -------- ------ | | | | | | | |
+ -------- ------ :
+ -------- ------ :
+
+ Assumptions:
+ - C is row-stored and B is column-stored;
+ - A is row-stored;
+ - m0 and n0 are at most MR and NR, respectively.
+ Therefore, this (r)ow-preferential microkernel is well-suited for
+ a dot-product-based accumulation that performs vector loads from
+ both A and B.
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref )
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
+GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref )
+GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref )
+
+// Define parameters and variables for edge case kernel map.
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+#if 0
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{ /* 8 4 2 1 */
+/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref },
+/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref },
+/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref },
+/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8m
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ uint64_t n_left = n0 % 8;
+
+ // First check whether this is an edge case in the n dimension. If so,
+ // dispatch other 6x?m kernels, as needed.
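+ // n_left is peeled off greedily: a 4-column panel, then a 2-column
+ // panel, then a final single column, which degenerates to a
+ // matrix-vector product and is handed to dgemv. For example,
+ // n_left == 7 dispatches the 6x4m kernel, then 6x2m, then
+ // bli_dgemv_ex().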
+ if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
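+ // Only ymm4 through ymm15 serve as accumulators in this kernel
+ // (ymm0-ymm3 are reloaded every iteration), so clearing just those
+ // twelve registers is sufficient.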
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), 
ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
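+	// All k iterations are now consumed. The three loops above partition
+	// k0 as k0 = 16*k_iter16 + 4*k_iter4 + k_left1; for example, k0 = 23
+	// gives k_iter16 = 1, k_iter4 = 1, k_left1 = 3. The same decomposition
+	// in C (sketch only):
+#if 0
+	{
+		uint64_t rem16 = k0 % 16;  // = k_left16
+		uint64_t total = 16*( k0 / 16 ) + 4*( rem16 / 4 ) + rem16 % 4;
+		// total == k0 for any k0
+	}
+#endif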
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
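+	// Loop structure recap for this 6x8m kernel: the jj loop just closed
+	// runs twice (jj = 0 and jj = 4), and each pass sweeps the ii loop
+	// over m_iter three-row panels, so C is covered in 3x4 tiles. In C
+	// pseudo-form (sketch only):
+#if 0
+	for ( uint64_t jj = 0; jj <= 4; jj += 4 )       // two 4-column panels
+		for ( uint64_t ii = 0; ii < m_iter; ++ii )  // 3-row panels
+		{
+			/* compute the 3x4 tile of C at ( 3*ii, jj ) */
+		}
+#endif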
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. 
+#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), 
rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
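+	// The accumulation phase that follows reduces each 4-wide accumulator
+	// to a single sum and packs four such sums into one ymm register. Per
+	// accumulator, the vhaddpd/vextractf128/vaddpd/vperm2f128 sequence
+	// computes exactly ( v[0]+v[1] ) + ( v[2]+v[3] ). A C model for the
+	// first output row (sketch only; HSUM4 and the accN names are
+	// illustrative):
+#if 0
+	#define HSUM4( v ) ( ( (v)[0] + (v)[1] ) + ( (v)[2] + (v)[3] ) )
+	// row 0 of the tile:
+	//   { HSUM4(acc4), HSUM4(acc7), HSUM4(acc10), HSUM4(acc13) }
+#endif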
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
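+	// Write-back recap: with C row-stored, each reduced ymm register holds
+	// one 1x4 row of the tile, which .DROWSTORED/.DROWSTORBZ above store a
+	// row at a time, advancing by rs_c. In C terms (sketch only; c11 and
+	// abrow are hypothetical names, and abrow already carries the alpha
+	// scaling applied earlier):
+#if 0
+	{
+		double* crow = c11;
+		for ( int i = 0; i < 3; ++i, crow += rs_c )
+			for ( int j = 0; j < 4; ++j )
+				crow[ j ] = (*beta) * crow[ j ] + abrow[ i ][ j ];
+	}
+#endif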
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
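+	// This 2-row kernel needs only eight accumulators, a 2x4 grid
+	// (ymm4/ymm5, ymm7/ymm8, ymm10/ymm11, ymm13/ymm14), so ymm6, ymm9,
+	// ymm12, and ymm15 are left untouched below. Rough C model of the
+	// cleared state (sketch only):
+#if 0
+	{
+		double acc[2][4][4] = { 0 };  // acc[i][j][lane]
+	}
+#endif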
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. 
+ + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
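+	// The compare above selects between the two store paths that follow;
+	// in C terms (sketch only; ab_ij already includes the alpha scaling
+	// applied earlier):
+#if 0
+	if ( *beta == 0.0 ) { /* .DROWSTORBZ: c_ij = ab_ij             */ }
+	else                { /* .DROWSTORED: c_ij = beta*c_ij + ab_ij */ }
+#endif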
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
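+	// With a single row of A, this kernel reduces to four dot products per
+	// 1x4 tile (essentially a vector-times-matrix product), so only one
+	// row of accumulators (ymm4, ymm7, ymm10, ymm13) is cleared below.
+	// Rough C model (sketch only):
+#if 0
+	{
+		double acc[1][4][4] = { 0 };  // acc[0][j][lane]
+	}
+#endif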
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. 
+ // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
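+	// The asm region below reads these variables with 64-bit mov
+	// instructions, hence the explicit uint64_t local copies that follow:
+	// if dim_t or inc_t were configured as narrower types, reading eight
+	// bytes directly from them would pull in adjacent garbage. A minimal
+	// sketch of the widening contract (illustrative only):
+#if 0
+	{
+		dim_t    k0_any  = k0;                  // may be 32- or 64-bit
+		uint64_t k0_wide = ( uint64_t )k0_any;  // what the asm must see
+	}
+#endif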
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
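+	// In the main loop below, the prefetch offsets r10 = 3*rs_a, 4*rs_a,
+	// and rdi = 5*rs_a point at rows 3..5 of A relative to rax, i.e. the
+	// next 3-row panel at the current k offset, so that panel is likely
+	// warm in cache when the ii loop advances. (Note these offsets are
+	// row strides, even where the adjacent comments say "cs_a".) A sketch
+	// of the first address being touched (illustrative only):
+#if 0
+	{
+		const double* next_panel = a + 3 * rs_a0;  // row 3: start of next panel
+	}
+#endif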
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // 
iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ vmovsd(mem(rax, r8, 2), xmm2)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
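
The three k-loops above implement a fixed partitioning of the k dimension: each pass of the MAIN LOOP covers 16 k-values (four unrolled sub-iterations of four-wide loads), each pass of the ymm EDGE LOOP covers four, and each pass of the scalar EDGE LOOP covers one. A minimal C sketch of the split the counters at the top of the function perform; the value of k0 here is illustrative only:

    uint64_t k0       = 37;            /* hypothetical k dimension */
    uint64_t k_iter16 = k0 / 16;       /* 2: MAIN LOOP passes, 16 k-values each */
    uint64_t k_left16 = k0 % 16;       /* 5: remainder after the main loop */
    uint64_t k_iter4  = k_left16 / 4;  /* 1: four-wide EDGE LOOP passes */
    uint64_t k_left1  = k_left16 % 4;  /* 1: scalar EDGE LOOP passes */
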
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
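
At this point each ymm accumulator still holds four partial products per element of the 2x4 micro-tile; the dot products are only collapsed to scalars in .DPOSTACCUM below. A hypothetical reference routine (not BLIS code) modeling what the accumulators contain after the vector loops, assuming row-stored A (row stride rs_a) and column-accessed B (column stride cs_b), strides in units of doubles:

    #include <stdint.h>

    /* Each acc[i][j][v] mirrors lane v of the ymm accumulator for C(i,j). */
    static void ref_accum_2x4( uint64_t k0,
                               const double* a, uint64_t rs_a,
                               const double* b, uint64_t cs_b,
                               double acc[2][4][4] )
    {
        for ( uint64_t k = 0; k < k0; k++ )
            for ( int i = 0; i < 2; i++ )      /* micro-tile rows */
                for ( int j = 0; j < 4; j++ )  /* micro-tile columns */
                    acc[i][j][k % 4] += a[ i*rs_a + k ] * b[ j*cs_b + k ];
    }
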
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), ymm3, ymm5)
+ vmovupd(ymm5, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
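
The vucomisd/je pair above is what routes execution between the two store paths. In scalar terms, for one element of C (a sketch reusing the surrounding kernel's variable names; ab stands for one already-alpha-scaled dot product):

    double* cij = &c[ i*rs_c + j ];
    if ( beta == 0.0 ) *cij = ab;                  /* .DROWSTORBZ: C never read */
    else               *cij = beta * (*cij) + ab;  /* .DROWSTORED: vfmadd231pd  */

Branching rather than multiplying by a zero beta also keeps NaNs or Infs in an uninitialized C buffer from contaminating the result.
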
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
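
The NOTE at the top of the scalar loop above refers to VEX encoding semantics: any VEX-encoded write to an xmm register zeroes bits 128..255 of the enclosing ymm register. A sketch with AVX/FMA intrinsics; aval and bval are placeholder scalars, not kernel variables:

    #include <immintrin.h>

    double aval = 5.0, bval = 6.0;                     /* placeholder operands */
    __m256d acc = _mm256_set_pd( 4.0, 3.0, 2.0, 1.0 ); /* prior partial sums   */
    /* vmovsd-style loads: lanes 1..3 are zero. */
    __m256d a0 = _mm256_zextpd128_pd256( _mm_load_sd( &aval ) );
    __m256d b0 = _mm256_zextpd128_pd256( _mm_load_sd( &bval ) );
    /* ymm-width FMA: lane 0 accumulates aval*bval, while lanes 1..3 receive
       acc += 0.0 * 0.0 and survive. An xmm-width FMA would instead zero
       lanes 2..3 of acc's ymm register when writing its result. */
    acc = _mm256_fmadd_pd( a0, b0, acc );
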
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b
+ mov(var(cs_b), r11) // load cs_b
+ //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+
+
+ mov(var(c), r12) // load address of c
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+
+
+ // r12 = rcx = c
+ // r14 = rax = a
+ // rdx = rbx = b
+ // r9 = m dim index ii
+
+ mov(var(m_iter), r9) // ii = m_iter;
+
+ label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
+
+
+#if 0
+ vzeroall() // zero all xmm/ymm registers.
+#else
+ // skylake can execute 3 vxorpd ipc with
+ // a latency of 1 cycle, while vzeroall
+ // has a latency of 12 cycles.
+ vxorpd(ymm4, ymm4, ymm4)
+ vxorpd(ymm5, ymm5, ymm5)
+ vxorpd(ymm6, ymm6, ymm6)
+ vxorpd(ymm7, ymm7, ymm7)
+ vxorpd(ymm8, ymm8, ymm8)
+ vxorpd(ymm9, ymm9, ymm9)
+ vxorpd(ymm10, ymm10, ymm10)
+ vxorpd(ymm11, ymm11, ymm11)
+ vxorpd(ymm12, ymm12, ymm12)
+ vxorpd(ymm13, ymm13, ymm13)
+ vxorpd(ymm14, ymm14, ymm14)
+ vxorpd(ymm15, ymm15, ymm15)
+#endif
+
+
+ lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c;
+ lea(mem(r14), rax) // rax = a + 6*ii*rs_a;
+ lea(mem(rdx), rbx) // rbx = b;
+
+
+ lea(mem(rcx, rdi, 2), r10) //
+ lea(mem(r10, rdi, 1), r10) // r10 = c + 3*rs_c;
+ prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c
+ prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c
+ prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c
+ prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c
+ prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
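
The prefetch block above touches one element in each C row of the upcoming 6x2 block so the cache lines are resident by the time .DPOSTACCUM stores to them. An equivalent intrinsics sketch, reusing the enclosing function's c and rs_c (rs_c here in elements, not bytes):

    #include <xmmintrin.h>

    /* _MM_HINT_T0 matches the locality hint 0 in the prefetch() macros. */
    for ( int i = 0; i < 6; i++ )
        _mm_prefetch( (const char*)( c + i*rs_c + 1 ), _MM_HINT_T0 );
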
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. 
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovupd(mem(rax, r13, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rax, r8, 4), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovupd(mem(rax, r15, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovsd(mem(rax, r8, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+ vmovsd(mem(rax, r13, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rax, r8, 4), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm12)
+ vfmadd231pd(ymm1, ymm3, ymm13)
+
+ vmovsd(mem(rax, r15, 1), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm14)
+ vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+ // ymm4 ymm5
+ // ymm6 ymm7
+ // ymm8 ymm9
+ // ymm10 ymm11
+ // ymm12 ymm13
+ // ymm14 ymm15
+
+ vhaddpd( ymm5, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm4 ) // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+ vhaddpd( ymm7, ymm6, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm6 )
+
+ vhaddpd( ymm9, ymm8, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm8 )
+
+ vhaddpd( ymm11, ymm10, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm10 )
+
+ vhaddpd( ymm13, ymm12, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm12 )
+
+ vhaddpd( ymm15, ymm14, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm14 )
+
+ // xmm4 = sum(ymm4) sum(ymm5)
+ // xmm6 = sum(ymm6) sum(ymm7)
+ // xmm8 = sum(ymm8) sum(ymm9)
+ // xmm10 = sum(ymm10) sum(ymm11)
+ // xmm12 = sum(ymm12) sum(ymm13)
+ // xmm14 = sum(ymm14) sum(ymm15)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(xmm0, xmm4, xmm4) // scale by alpha
+ vmulpd(xmm0, xmm6, xmm6)
+ vmulpd(xmm0, xmm8, xmm8)
+ vmulpd(xmm0, xmm10, xmm10)
+ vmulpd(xmm0, xmm12, xmm12)
+ vmulpd(xmm0, xmm14, xmm14)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), xmm3, xmm4)
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm6)
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm8)
+ vmovupd(xmm8, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm10)
+ vmovupd(xmm10, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm12)
+ vmovupd(xmm12, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm14)
+ vmovupd(xmm14, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm8, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm10, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm12, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm14, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+
+ lea(mem(r12, rdi, 4), r12) //
+ lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c
+
+ lea(mem(r14, r8, 4), r14) //
+ lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a
+
+ dec(r9) // ii -= 1;
+ jne(.DLOOP3X4I) // iterate again if ii != 0.
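
The lea pairs just above advance the running C and A pointers by six rows per pass (r12 += 4*rdi + 2*rdi and r14 += 4*r8 + 2*r8). A sketch of the same ii-loop control flow in C; c_ii and a_ii are illustrative names, not kernel variables:

    for ( uint64_t ii = m_iter; ii != 0; ii-- )
    {
        /* k-loops, .DPOSTACCUM reduction, stores to rows 0..5 of the block */
        c_ii += 6 * rs_c;  /* next 6-row block of C */
        a_ii += 6 * rs_a;  /* next 6-row block of A */
    }
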
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovupd(mem(rax, r8, 2), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovsd(mem(rax, r8, 2), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+ // ymm4 ymm5
+ // ymm6 ymm7
+ // ymm8 ymm9
+
+ vhaddpd( ymm5, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm4 ) // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+ vhaddpd( ymm7, ymm6, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm6 )
+
+ vhaddpd( ymm9, ymm8, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm8 )
+
+ // xmm4 = sum(ymm4) sum(ymm5)
+ // xmm6 = sum(ymm6) sum(ymm7)
+ // xmm8 = sum(ymm8) sum(ymm9)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(xmm0, xmm4, xmm4) // scale by alpha
+ vmulpd(xmm0, xmm6, xmm6)
+ vmulpd(xmm0, xmm8, xmm8)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), xmm3, xmm4)
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm6)
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm8)
+ vmovupd(xmm8, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
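
The reduction in .DPOSTACCUM above can be read as the following helper (a hypothetical reference, not BLIS code): vhaddpd interleaves pairwise sums of the two accumulators, and adding the upper 128-bit half to the lower yields { sum(va), sum(vb) }, e.g. xmm4 = { sum(ymm4), sum(ymm5) }:

    #include <immintrin.h>

    static __m128d hsum2( __m256d va, __m256d vb )
    {
        /* t = { va0+va1, vb0+vb1, va2+va3, vb2+vb3 } */
        __m256d t  = _mm256_hadd_pd( va, vb );
        __m128d lo = _mm256_castpd256_pd128( t );
        __m128d hi = _mm256_extractf128_pd( t, 1 );
        return _mm_add_pd( lo, hi );  /* { sum(va), sum(vb) } */
    }
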
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rax, r8, 1), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+ // ymm4 ymm5
+ // ymm6 ymm7
+
+ vhaddpd( ymm5, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm4 ) // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+ vhaddpd( ymm7, ymm6, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm6 )
+
+ // xmm4 = sum(ymm4) sum(ymm5)
+ // xmm6 = sum(ymm6) sum(ymm7)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(xmm0, xmm4, xmm4) // scale by alpha
+ vmulpd(xmm0, xmm6, xmm6)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), xmm3, xmm4)
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm6)
+ vmovupd(xmm6, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(xmm6, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+ end_asm(
+ : // output operands (none)
+ : // input operands
+ [k_iter16] "m" (k_iter16),
+ [k_iter4] "m" (k_iter4),
+ [k_left1] "m" (k_left1),
+ [a] "m" (a),
+ [rs_a] "m" (rs_a),
+ [cs_a] "m" (cs_a),
+ [b] "m" (b),
+ [rs_b] "m" (rs_b),
+ [cs_b] "m" (cs_b),
+ [alpha] "m" (alpha),
+ [beta] "m" (beta),
+ [c] "m" (c),
+ [rs_c] "m" (rs_c),
+ [cs_c] "m" (cs_c)/*,
+ [a_next] "m" (a_next),
+ [b_next] "m" (b_next)*/
+ : // register clobber list
+ "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+ "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+ "xmm0", "xmm1", "xmm2", "xmm3",
+ "xmm4", "xmm5", "xmm6", "xmm7",
+ "xmm8", "xmm9", "xmm10", "xmm11",
+ "xmm12", "xmm13", "xmm14", "xmm15",
+ "memory"
+ )
+}
+
+void bli_dgemmsup_rd_haswell_asm_1x2m
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ //void* a_next = bli_auxinfo_next_a( data );
+ //void* b_next = bli_auxinfo_next_b( data );
+
+ // Typecast local copies of integers in case dim_t and inc_t are a
+ // different size than is expected by load instructions.
+ uint64_t k_iter16 = k0 / 16;
+ uint64_t k_left16 = k0 % 16;
+ uint64_t k_iter4 = k_left16 / 4;
+ uint64_t k_left1 = k_left16 % 4;
+
+ uint64_t rs_a = rs_a0;
+ uint64_t cs_a = cs_a0;
+ uint64_t rs_b = rs_b0;
+ uint64_t cs_b = cs_b0;
+ uint64_t rs_c = rs_c0;
+ uint64_t cs_c = cs_c0;
+
+/*
+ rrc:
+ -------- -- -- -- | |
+ -------- -- -- -- ... | |
+ -------- += -- -- -- | |
+ -------- -- -- -- | |
+ -------- -- -- -- :
+ -------- -- -- -- :
+*/
+ // -------------------------------------------------------------------------
+
+ begin_asm()
+
+#if 0
+ vzeroall() // zero all xmm/ymm registers.
+#else
+ // skylake can execute 3 vxorpd ipc with
+ // a latency of 1 cycle, while vzeroall
+ // has a latency of 12 cycles.
+ vxorpd(ymm4, ymm4, ymm4)
+ vxorpd(ymm5, ymm5, ymm5)
+#endif
+
+ mov(var(a), rax) // load address of a.
+ //mov(var(rs_a), r8) // load rs_a
+ //mov(var(cs_a), r9) // load cs_a
+ //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
+ //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double)
+
+ //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
+ //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
+
+ mov(var(b), rbx) // load address of b.
+ //mov(var(rs_b), r10) // load rs_b
+ mov(var(cs_b), r11) // load cs_b
+ //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+
+ // initialize loop by pre-loading
+ // a column of a.
+
+ mov(var(c), rcx) // load address of c
+ //mov(var(rs_c), rdi) // load rs_c
+ //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+ //lea(mem(rcx, rdi, 2), rdx) //
+ //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
+ prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rbx ), ymm0)
+ vmovupd(mem(rbx, r11, 1), ymm1)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+
+ vmovupd(mem(rax ), ymm3)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji new file mode 100644 index 000000000..c1cb37214 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji @@ -0,0 +1,5628 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+   rrc:
+     --------        ------        | | | | | | | |
+     --------        ------        | | | | | | | |
+     --------   +=   ------ ...    | | | | | | | |
+     --------        ------        | | | | | | | |
+     --------        ------        :
+     --------        ------        :
+
+   Assumptions:
+   - C is row-stored and B is column-stored;
+   - A is row-stored;
+   - m0 and n0 are at most MR and NR, respectively.
+   Therefore, this (r)ow-preferential microkernel is well-suited for
+   a dot-product-based accumulation that performs vector loads from
+   both A and B.
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref )
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
+GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref )
+GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref )
+
+// Define parameters and variables for edge case kernel map.
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+#if 0
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{ /*                    8                                 4                                 2                              1             */
+/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref },
+/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref },
+/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref },
+/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8m
+     (
+       conj_t conja,
+       conj_t conjb,
+       dim_t m0,
+       dim_t n0,
+       dim_t k0,
+       double* restrict alpha,
+       double* restrict a, inc_t rs_a0, inc_t cs_a0,
+       double* restrict b, inc_t rs_b0, inc_t cs_b0,
+       double* restrict beta,
+       double* restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t* restrict cntx
+     )
+{
+ uint64_t n_left = n0 % 8;
+
+ // First check whether this is an edge case in the n dimension. If so,
+ // dispatch other 6x?m kernels, as needed.
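
Note, before the dispatch code below: every kernel in this family computes the same thing, with only the microtile shape differing. A hedged plain-C model mirroring the storage assumptions stated in the comment block above (illustrative, not the BLIS reference kernel; the real kernels additionally skip the read of C when *beta == 0):

    /* C := beta*C + alpha*A*B, where each C entry is a length-k0 dot product. */
    for ( dim_t i = 0; i < m0; ++i )
        for ( dim_t j = 0; j < n0; ++j )
        {
            double ab = 0.0;
            for ( dim_t k = 0; k < k0; ++k )
                ab += a[ i*rs_a0 + k*cs_a0 ] * b[ k*rs_b0 + j*cs_b0 ];
            double* cij = c + i*rs_c0 + j*cs_c0;
            *cij = (*beta) * (*cij) + (*alpha) * ab;
        }
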
+ if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 
1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ vmovsd(mem(rax, r8, 2), xmm2)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
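
Note: the NOTE in the scalar edge loop above is worth unpacking. VEX-encoded vmovsd writes lane 0 and zeroes lanes 1-3 of the destination ymm register, so the full-width FMA only adds into lane 0 of each accumulator and leaves the partial sums in the upper lanes intact; issuing the FMA with xmm operands would instead zero those upper lanes. A lane-by-lane picture (an illustrative comment, not kernel code):

    /* After vmovsd(mem(rax), xmm0):  ymm0 = { a[k], 0, 0, 0 }
       After vmovsd(mem(rbx), xmm3):  ymm3 = { b[k], 0, 0, 0 }
       vfmadd231pd(ymm0, ymm3, ymm4): ymm4 = { ymm4[0] + a[k]*b[k],
                                               ymm4[1], ymm4[2], ymm4[3] }
       A 128-bit (xmm) FMA would instead have written zeroes to lanes 2-3. */
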
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
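
Note: the .DPOSTACCUM sequence above collapses each ymm accumulator (four k-partials) into a single dot-product value, two accumulators at a time via vhaddpd, then packs four results per C row with vperm2f128. One equivalent way to express the per-register reduction with intrinsics (a sketch for orientation only):

    #include <immintrin.h>

    /* Sum the four doubles of one accumulator; same result as the
       vhaddpd/vextractf128/vaddpd sequence above, per register. */
    static double hsum4( __m256d v )
    {
        __m128d lo = _mm256_castpd256_pd128( v );    /* lanes 0,1         */
        __m128d hi = _mm256_extractf128_pd( v, 1 );  /* lanes 2,3         */
        __m128d s  = _mm_add_pd( lo, hi );           /* {v0+v2, v1+v3}    */
        s = _mm_hadd_pd( s, s );                     /* total, both lanes */
        return _mm_cvtsd_f64( s );
    }
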
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
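
Note: the prefetch(0, mem(rcx, ...)) lines above issue prefetcht0 hints for the C rows of the microtile; the live variant reaches 3*8 bytes into each row, while the disabled #if 0 variant reached 7*8 bytes in (the far end of an 8-wide row). The intrinsic equivalent, for reference (illustrative; rs_c here is the element stride):

    #include <immintrin.h>

    _mm_prefetch( (const char*)( c + 3 ),          _MM_HINT_T0 );
    _mm_prefetch( (const char*)( c + 1*rs_c + 3 ), _MM_HINT_T0 );
    _mm_prefetch( (const char*)( c + 2*rs_c + 3 ), _MM_HINT_T0 );
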
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ vmovsd(mem(rax, r8, 2), xmm2)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
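
Note: between the vector loops above and .DPOSTACCUM below, the 3x4 microtile lives in twelve ymm accumulators, one per C entry, each holding four k-partials that are only combined during post-accumulation. A hedged C model of that layout (a sketch assuming unit cs_a and rs_b and, for brevity, k0 divisible by 4):

    /* acc[i][j][l] = lane l of the accumulator for C(i,j):
         ymm4  ymm7  ymm10 ymm13   <- row 0
         ymm5  ymm8  ymm11 ymm14   <- row 1
         ymm6  ymm9  ymm12 ymm15   <- row 2                       */
    double acc[3][4][4] = { 0 };
    for ( uint64_t k = 0; k < k0; k += 4 )
        for ( int i = 0; i < 3; ++i )
            for ( int j = 0; j < 4; ++j )
                for ( int l = 0; l < 4; ++l )   /* one vfmadd231pd */
                    acc[i][j][l] += a[ i*rs_a + k + l ]
                                  * b[ (k + l) + j*cs_b ];
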
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
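
Note: the vxorpd/vucomisd/je triple above implements the usual beta special case: when beta == 0, C must not be read at all (it may hold uninitialized values), so the kernel branches to .DROWSTORBZ and stores alpha*AB directly; otherwise .DROWSTORED fuses the update into an FMA. Per element, in C terms (sketch):

    /* ab has already been scaled by alpha at this point. */
    if ( *beta == 0.0 ) *cij = ab;                     /* .DROWSTORBZ */
    else                *cij = (*beta) * (*cij) + ab;  /* .DROWSTORED */
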
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. 
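
Note: each vfmadd231pd(x, y, acc) in the loops above computes acc += x * y on four packed doubles in a single instruction. One load-and-accumulate step of iteration 0, written with intrinsics (a sketch; the pointer names are illustrative):

    #include <immintrin.h>

    __m256d a0  = _mm256_loadu_pd( a_ptr );  /* vmovupd(mem(rax), ymm0)  */
    __m256d b0  = _mm256_loadu_pd( b_ptr );  /* vmovupd(mem(rbx), ymm3)  */
    __m256d acc = _mm256_setzero_pd();       /* vxorpd(ymm4, ymm4, ymm4) */
    acc = _mm256_fmadd_pd( a0, b0, acc );    /* vfmadd231pd(..., ymm4)   */
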
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. 
+#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. 
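
Note: the 1x8m kernel above is effectively a fused group of dot products: one row of A against four columns of B per jj sweep, accumulated in ymm4, ymm7, ymm10, and ymm13. As a scalar sketch (assuming unit cs_a and rs_b, as the pointer arithmetic above does):

    /* One 1x4 sweep, before alpha/beta are applied. */
    double ab[4] = { 0.0, 0.0, 0.0, 0.0 };
    for ( uint64_t k = 0; k < k0; ++k )
        for ( int j = 0; j < 4; ++j )
            ab[j] += a[k] * b[ k + j*cs_b ];
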
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+
+ add(imm(4), r15) // jj += 4;
+ cmp(imm(4), r15) // compare jj to 4
+ jle(.DLOOP3X4J) // if jj <= 4, jump to beginning
+ // of jj loop; otherwise, loop ends.
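
Note: the loop control above (add 4 to jj, compare against 4, jle) walks the 8-column panel in exactly two 4-column sweeps, jj = 0 and jj = 4, offsetting c and b accordingly (the kernel hard-codes cs_c = 1 when it scales jj by 1*8). In plain C (sketch):

    for ( uint64_t jj = 0; jj <= 4; jj += 4 )
    {
        double* c_jj = c + jj * cs_c;  /* r12 -> rcx offset */
        double* b_jj = b + jj * cs_b;  /* rdx -> rbx offset */
        /* ... accumulate and store the 1x4 block ... */
    }
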
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; 
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+	label(.DCONSIDKITER4)
+
+	mov(var(k_iter4), rsi)             // i = k_iter4;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+	                                   // considers k_left1 loop.
+	                                   // else, we prepare to enter k_iter4 loop.
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rax       ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	vmovsd(mem(rax, r8, 2), xmm2)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
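+	// Since the blocked loop above advances through m in panels of 3 rows
+	// (m_iter = m0 / 3), m_left = m0 % 3 can only be 0, 1, or 2 here; each
+	// nonzero case is dispatched to a correspondingly narrower kernel below.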
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
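+	                                   // The k dimension is partitioned as
+	                                   // k0 = 16*k_iter16 + 4*k_iter4 + k_left1;
+	                                   // the main loop below consumes 16 k values
+	                                   // per pass (4 unrolled sub-iterations of 4
+	                                   // doubles, i.e. one ymm load, each).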
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rax       ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rax       ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	vmovsd(mem(rax, r8, 2), xmm2)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
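+	                                   // NOTE: when beta == 0, C is overwritten below
+	                                   // without first being read; per BLAS convention
+	                                   // this keeps uninitialized C contents (possibly
+	                                   // NaN or Inf) out of the beta*C term and skips
+	                                   // needless loads of C.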
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. 
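+	                                   // In the scalar cleanup below, vmovsd fills only
+	                                   // lane 0 of each ymm operand and zeroes the upper
+	                                   // lanes, so the full-width FMAs add 0*0 into
+	                                   // lanes 1-3 and leave those partial sums intact.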
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rax       ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+
+
+
+
+
+
+
+	label(.DPOSTACCUM)
+
+
+
+	                                   // ymm4  ymm7  ymm10 ymm13
+	                                   // ymm5  ymm8  ymm11 ymm14
+
+	vhaddpd( ymm7, ymm4, ymm0 )
+	vextractf128(imm(1), ymm0, xmm1 )
+	vaddpd( xmm0, xmm1, xmm0 )         // xmm0[0] = sum(ymm4);  xmm0[1] = sum(ymm7)
+
+	vhaddpd( ymm13, ymm10, ymm2 )
+	vextractf128(imm(1), ymm2, xmm1 )
+	vaddpd( xmm2, xmm1, xmm2 )         // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+	vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+	vhaddpd( ymm8, ymm5, ymm0 )
+	vextractf128(imm(1), ymm0, xmm1 )
+	vaddpd( xmm0, xmm1, xmm0 )
+
+	vhaddpd( ymm14, ymm11, ymm2 )
+	vextractf128(imm(1), ymm2, xmm1 )
+	vaddpd( xmm2, xmm1, xmm2 )
+
+	vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+	                                   // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+	                                   // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+	mov(var(alpha), rax)               // load address of alpha
+	mov(var(beta), rbx)                // load address of beta
+	vbroadcastsd(mem(rax), ymm0)       // load alpha and duplicate
+	vbroadcastsd(mem(rbx), ymm3)       // load beta and duplicate
+
+	vmulpd(ymm0, ymm4, ymm4)           // scale by alpha
+	vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+	//mov(var(cs_c), rsi)              // load cs_c
+	//lea(mem(, rsi, 8), rsi)          // rsi = cs_c * sizeof(double)
+
+
+
+	                                   // now avoid loading C if beta == 0
+
+	vxorpd(ymm0, ymm0, ymm0)           // set ymm0 to zero.
+	vucomisd(xmm0, xmm3)               // set ZF if beta == 0.
+	je(.DBETAZERO)                     // if ZF = 1, jump to beta == 0 case
+
+
+
+	label(.DROWSTORED)
+
+
+	vfmadd231pd(mem(rcx), ymm3, ymm4)
+	vmovupd(ymm4, mem(rcx))
+	add(rdi, rcx)
+
+	vfmadd231pd(mem(rcx), ymm3, ymm5)
+	vmovupd(ymm5, mem(rcx))
+	//add(rdi, rcx)
+
+
+
+	jmp(.DDONE)                        // jump to end.
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+	label(.DLOOPKITER16)               // MAIN LOOP
+
+
+	// ---------------------------------- iteration 0
+
+	vmovupd(mem(rax       ), ymm0)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+	// ---------------------------------- iteration 1
+
+	vmovupd(mem(rax       ), ymm0)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+	// ---------------------------------- iteration 2
+
+	vmovupd(mem(rax       ), ymm0)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+	// ---------------------------------- iteration 3
+
+	vmovupd(mem(rax       ), ymm0)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+	label(.DCONSIDKITER4)
+
+	mov(var(k_iter4), rsi)             // i = k_iter4;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+	                                   // considers k_left1 loop.
+	                                   // else, we prepare to enter k_iter4 loop.
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rax       ), ymm0)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rax       ), xmm0)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. 
+	mov(var(rs_a), r8)                 // load rs_a
+	//mov(var(cs_a), r9)               // load cs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	//lea(mem(, r9, 8), r9)            // cs_a *= sizeof(double)
+
+	lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+	lea(mem(r8, r8, 4), r15)           // r15 = 5*rs_a
+
+	mov(var(b), rdx)                   // load address of b.
+	//mov(var(rs_b), r10)              // load rs_b
+	mov(var(cs_b), r11)                // load cs_b
+	//lea(mem(, r10, 8), r10)          // rs_b *= sizeof(double)
+	lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+	//lea(mem(r11, r11, 2), r13)       // r13 = 3*cs_b
+
+
+	mov(var(c), r12)                   // load address of c
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+	// r12 = rcx = c
+	// r14 = rax = a
+	// rdx = rbx = b
+	// r9 = m dim index ii
+
+	mov(var(m_iter), r9)               // ii = m_iter;
+
+	label(.DLOOP3X4I)                  // LOOP OVER ii = [ m_iter ... 1 0 ]
+
+
+#if 0
+	vzeroall()                         // zero all xmm/ymm registers.
+#else
+	                                   // skylake can execute 3 vxorpd ipc with
+	                                   // a latency of 1 cycle, while vzeroall
+	                                   // has a latency of 12 cycles.
+	vxorpd(ymm4, ymm4, ymm4)
+	vxorpd(ymm5, ymm5, ymm5)
+	vxorpd(ymm6, ymm6, ymm6)
+	vxorpd(ymm7, ymm7, ymm7)
+	vxorpd(ymm8, ymm8, ymm8)
+	vxorpd(ymm9, ymm9, ymm9)
+	vxorpd(ymm10, ymm10, ymm10)
+	vxorpd(ymm11, ymm11, ymm11)
+	vxorpd(ymm12, ymm12, ymm12)
+	vxorpd(ymm13, ymm13, ymm13)
+	vxorpd(ymm14, ymm14, ymm14)
+	vxorpd(ymm15, ymm15, ymm15)
+#endif
+
+
+	lea(mem(r12), rcx)                 // rcx = c + 6*ii*rs_c;
+	lea(mem(r14), rax)                 // rax = a + 6*ii*rs_a;
+	lea(mem(rdx), rbx)                 // rbx = b;
+
+
+	lea(mem(rcx, rdi, 2), r10)         //
+	lea(mem(r10, rdi, 1), r10)         // r10 = c + 3*rs_c;
+	prefetch(0, mem(rcx,         3*8)) // prefetch c + 0*rs_c
+	prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
+	prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
+	prefetch(0, mem(r10,         3*8)) // prefetch c + 3*rs_c
+	prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 4*rs_c
+	prefetch(0, mem(r10, rdi, 2, 3*8)) // prefetch c + 5*rs_c
+
+
+
+
+	mov(var(k_iter16), rsi)            // i = k_iter16;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKITER4)                 // if i == 0, jump to code that
+	                                   // contains the k_iter4 loop.
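+	                                   // NOTE: relative to the nx4m kernels above, the
+	                                   // operand roles are swapped in the loops below:
+	                                   // two columns of b stay resident in ymm0/ymm1
+	                                   // while six rows of a stream through ymm3.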
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. 
+	                                   // else, we prepare to enter k_iter4 loop.
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rbx        ), ymm0)
+	vmovupd(mem(rbx, r11, 1), ymm1)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+
+	vmovupd(mem(rax        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovupd(mem(rax, r8, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm6)
+	vfmadd231pd(ymm1, ymm3, ymm7)
+
+	vmovupd(mem(rax, r8, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm8)
+	vfmadd231pd(ymm1, ymm3, ymm9)
+
+	vmovupd(mem(rax, r13, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+
+	vmovupd(mem(rax, r8, 4), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm12)
+	vfmadd231pd(ymm1, ymm3, ymm13)
+
+	vmovupd(mem(rax, r15, 1), ymm3)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm14)
+	vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rbx        ), xmm0)
+	vmovsd(mem(rbx, r11, 1), xmm1)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+
+	vmovsd(mem(rax        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovsd(mem(rax, r8, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm6)
+	vfmadd231pd(ymm1, ymm3, ymm7)
+
+	vmovsd(mem(rax, r8, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm8)
+	vfmadd231pd(ymm1, ymm3, ymm9)
+
+	vmovsd(mem(rax, r13, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+
+	vmovsd(mem(rax, r8, 4), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm12)
+	vfmadd231pd(ymm1, ymm3, ymm13)
+
+	vmovsd(mem(rax, r15, 1), xmm3)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm14)
+	vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
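+	                                   // The post-accumulation step below reduces each
+	                                   // pair of row accumulators roughly as follows
+	                                   // (sketch; r0/r1 are the two ymm accumulators):
+	                                   //
+	                                   //   t = vhaddpd(r1, r0)
+	                                   //     = { r0[0]+r0[1], r1[0]+r1[1],
+	                                   //         r0[2]+r0[3], r1[2]+r1[3] }
+	                                   //   x = t[0:1] + t[2:3] = { sum(r0), sum(r1) }
+	                                   //
+	                                   // leaving each xmm with one row's two dot
+	                                   // products against the two columns of b.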
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. 
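+	                                   // NOTE: x86 addressing allows scale factors of
+	                                   // only 1, 2, 4, or 8, so the += 6*rs_c and
+	                                   // += 6*rs_a advances above are each composed
+	                                   // from two leas (4x then 2x).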
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rbx        ), ymm0)
+	vmovupd(mem(rbx, r11, 1), ymm1)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+
+	vmovupd(mem(rax        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovupd(mem(rax, r8, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm6)
+	vfmadd231pd(ymm1, ymm3, ymm7)
+
+	vmovupd(mem(rax, r8, 2), ymm3)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm8)
+	vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here
+	                                   // because using the xmm registers would
+	                                   // zero out the high bits of the
+	                                   // destination registers, which would
+	                                   // destroy intermediate results.
+
+	vmovsd(mem(rbx        ), xmm0)
+	vmovsd(mem(rbx, r11, 1), xmm1)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+
+	vmovsd(mem(rax        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovsd(mem(rax, r8, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm6)
+	vfmadd231pd(ymm1, ymm3, ymm7)
+
+	vmovsd(mem(rax, r8, 2), xmm3)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm8)
+	vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+
+
+
+
+
+
+
+	label(.DPOSTACCUM)
+
+	                                   // ymm4  ymm5
+	                                   // ymm6  ymm7
+	                                   // ymm8  ymm9
+
+	vhaddpd( ymm5, ymm4, ymm0 )
+	vextractf128(imm(1), ymm0, xmm1 )
+	vaddpd( xmm0, xmm1, xmm4 )         // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+	vhaddpd( ymm7, ymm6, ymm0 )
+	vextractf128(imm(1), ymm0, xmm1 )
+	vaddpd( xmm0, xmm1, xmm6 )
+
+	vhaddpd( ymm9, ymm8, ymm0 )
+	vextractf128(imm(1), ymm0, xmm1 )
+	vaddpd( xmm0, xmm1, xmm8 )
+
+	                                   // xmm4 = sum(ymm4) sum(ymm5)
+	                                   // xmm6 = sum(ymm6) sum(ymm7)
+	                                   // xmm8 = sum(ymm8) sum(ymm9)
+
+
+
+	mov(var(alpha), rax)               // load address of alpha
+	mov(var(beta), rbx)                // load address of beta
+	vbroadcastsd(mem(rax), ymm0)       // load alpha and duplicate
+	vbroadcastsd(mem(rbx), ymm3)       // load beta and duplicate
+
+	vmulpd(xmm0, xmm4, xmm4)           // scale by alpha
+	vmulpd(xmm0, xmm6, xmm6)
+	vmulpd(xmm0, xmm8, xmm8)
+
+
+
+
+
+
+	//mov(var(cs_c), rsi)              // load cs_c
+	//lea(mem(, rsi, 8), rsi)          // rsi = cs_c * sizeof(double)
+
+
+
+	                                   // now avoid loading C if beta == 0
+
+	vxorpd(ymm0, ymm0, ymm0)           // set ymm0 to zero.
+	vucomisd(xmm0, xmm3)               // set ZF if beta == 0.
+	je(.DBETAZERO)                     // if ZF = 1, jump to beta == 0 case
+
+
+
+	label(.DROWSTORED)
+
+
+	vfmadd231pd(mem(rcx), xmm3, xmm4)
+	vmovupd(xmm4, mem(rcx))
+	add(rdi, rcx)
+
+	vfmadd231pd(mem(rcx), xmm3, xmm6)
+	vmovupd(xmm6, mem(rcx))
+	add(rdi, rcx)
+
+	vfmadd231pd(mem(rcx), xmm3, xmm8)
+	vmovupd(xmm8, mem(rcx))
+	//add(rdi, rcx)
+
+
+
+	jmp(.DDONE)                        // jump to end.
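+	                                   // (Why the beta check above: when
+	                                   // beta == 0, the .DROWSTORBZ path
+	                                   // below stores the alpha-scaled
+	                                   // sums without ever reading C. This
+	                                   // saves the loads of C and, per
+	                                   // BLAS convention, ensures NaN/Inf
+	                                   // values in uninitialized C memory
+	                                   // are never multiplied into the
+	                                   // result.)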
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
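+	                                   // (This works out because the vmovsd
+	                                   // loads below zero bits 127:64 of
+	                                   // their xmm destination, and the VEX
+	                                   // encoding zeroes bits 255:128, so
+	                                   // the upper lanes of ymm0/ymm1/ymm3
+	                                   // contribute exact zeros to the ymm
+	                                   // accumulators in the FMAs.)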
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
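+	// Note: this 1x2 kernel keeps only two live accumulators; ymm4 and
+	// ymm5 each gather four k-partial products for one of the two columns
+	// of the 1x2 block of C, and both are reduced to scalars at
+	// .DPOSTACCUM.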
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij new file mode 100644 index 000000000..fd1c2ae65 --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij @@ -0,0 +1,5634 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. 
+#define NUM_MR 4
+#define NUM_NR 4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+#if 0
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{        /*  8   4   2   1  */
+/*  6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref },
+/*  3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref },
+/*  2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref },
+/*  1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8m
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	uint64_t n_left = n0 % 8;
+
+	// First check whether this is an edge case in the n dimension. If so,
+	// dispatch other 6x?m kernels, as needed.
+	if ( n_left )
+	{
+		double* restrict cij = c;
+		double* restrict bj  = b;
+		double* restrict ai  = a;
+
+		if ( 4 <= n_left )
+		{
+			const dim_t nr_cur = 4;
+
+			bli_dgemmsup_rd_haswell_asm_6x4m
+			(
+			  conja, conjb, m0, nr_cur, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+			cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur;
+		}
+		if ( 2 <= n_left )
+		{
+			const dim_t nr_cur = 2;
+
+			bli_dgemmsup_rd_haswell_asm_6x2m
+			(
+			  conja, conjb, m0, nr_cur, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+			cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur;
+		}
+		if ( 1 == n_left )
+		{
+#if 0
+			const dim_t nr_cur = 1;
+
+			bli_dgemmsup_r_haswell_ref
+			(
+			  conja, conjb, m0, nr_cur, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+#else
+			bli_dgemv_ex
+			(
+			  BLIS_NO_TRANSPOSE, conjb, m0, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0,
+			  beta, cij, rs_c0, cntx, NULL
+			);
+#endif
+		}
+		return;
+	}
+
+	//void* a_next = bli_auxinfo_next_a( data );
+	//void* b_next = bli_auxinfo_next_b( data );
+
+	// Typecast local copies of integers in case dim_t and inc_t are a
+	// different size than is expected by load instructions.
+	uint64_t k_iter16 = k0 / 16;
+	uint64_t k_left16 = k0 % 16;
+	uint64_t k_iter4  = k_left16 / 4;
+	uint64_t k_left1  = k_left16 % 4;
+
+	uint64_t m_iter = m0 / 3;
+	uint64_t m_left = m0 % 3;
+
+	uint64_t rs_a   = rs_a0;
+	uint64_t cs_a   = cs_a0;
+	uint64_t rs_b   = rs_b0;
+	uint64_t cs_b   = cs_b0;
+	uint64_t rs_c   = rs_c0;
+	uint64_t cs_c   = cs_c0;
+
+	if ( m_iter == 0 ) goto consider_edge_cases;
+
+	// -------------------------------------------------------------------------
+
+	begin_asm()
+
+	//vzeroall()                         // zero all xmm/ymm registers.
+
+	mov(var(a), r14)                   // load address of a.
+	mov(var(rs_a), r8)                 // load rs_a
+	//mov(var(cs_a), r9)                 // load cs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	//lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+	//lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+	//lea(mem(r8, r8, 4), r15)           // r15 = 5*rs_a
+
+	mov(var(b), rdx)                   // load address of b.
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a_ii; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
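+	// Shape of this kernel (summary): the 6x8 microtile of C is computed
+	// as a grid of 3x4 tiles. The outer .DLOOP3X4I loop steps ii over
+	// m_iter = m0/3 three-row panels (r14/r12 advance by 3*rs_a/3*rs_c at
+	// the bottom), and the inner .DLOOP3X4J loop steps jj over the two
+	// four-column panels (jj = 0 and jj = 4), rebasing rax/rbx/rcx from
+	// r14/rdx/r12 each pass. In scalar terms, each 3x4 tile computes:
+	//   for i in 0..2: for j in 0..3:
+	//     c[i][j] = beta*c[i][j] + alpha * sum_k( a[i][k] * b[k][j] )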
+
+
+	label(.DLOOPKITER16)               // MAIN LOOP
+
+
+	// ---------------------------------- iteration 0
+
+#if 1
+	prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+	prefetch(0, mem(rax, r8,  4, 0*8)) // prefetch rax + 4*rs_a
+	prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*rs_a
+#endif
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 1
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 2
+
+#if 1
+	prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+	prefetch(0, mem(rax, r8,  4, 0*8)) // prefetch rax + 4*rs_a
+	prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*rs_a
+#endif
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 3
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+	label(.DCONSIDKITER4)
+
+	mov(var(k_iter4), rsi)             // i = k_iter4;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+	                                   // considers k_left1 loop.
+	                                   // else, we prepare to enter k_iter4 loop.
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here
+	                                   // because using the xmm registers would
+	                                   // zero out the high bits of the
+	                                   // destination registers, which would
+	                                   // destroy intermediate results.
+
+	vmovsd(mem(rax        ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	vmovsd(mem(rax, r8, 2), xmm2)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
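+	// What follows at .DPOSTACCUM is a horizontal reduction: each ymm
+	// accumulator holds four k-partial sums for a single entry of C. For
+	// an accumulator pair (u, v), vhaddpd(v, u, t) yields
+	//   t = { u[0]+u[1], v[0]+v[1], u[2]+u[3], v[2]+v[3] },
+	// vextractf128/vaddpd then fold the two 128-bit halves so that lane 0
+	// = sum(u) and lane 1 = sum(v), and vperm2f128(0x20, ...) packs two
+	// such xmm pairs into one ymm holding four finished dot products (one
+	// row of the 3x4 tile).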
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. 
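+	                                   // (The += 3*stride updates just
+	                                   // above are split across two lea
+	                                   // instructions because lea scale
+	                                   // factors are limited to 1, 2, 4,
+	                                   // and 8: first base + 2*stride,
+	                                   // then + 1*stride.)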
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
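+	// (Prefetch note: the enabled branch above touches c + 3*8, the last
+	// double of this 4-column panel's row, while the disabled variant
+	// targeted c + 7*8, the last double of a full 8-column row.)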
+
+
+	label(.DLOOPKITER16)               // MAIN LOOP
+
+
+	// ---------------------------------- iteration 0
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 1
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 2
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	// ---------------------------------- iteration 3
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER16)                 // iterate again if i != 0.
+
+
+
+
+
+
+	label(.DCONSIDKITER4)
+
+	mov(var(k_iter4), rsi)             // i = k_iter4;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKLEFT1)                 // if i == 0, jump to code that
+	                                   // considers k_left1 loop.
+	                                   // else, we prepare to enter k_iter4 loop.
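+	// The .DLOOPKITER4 loop below reuses the body of a single main-loop
+	// sub-iteration (4 k values per pass) for the k % 16 leftovers;
+	// whatever still remains (k % 4 values) falls through to the scalar
+	// .DLOOPKLEFT1 loop.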
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rax        ), ymm0)
+	vmovupd(mem(rax, r8, 1), ymm1)
+	vmovupd(mem(rax, r8, 2), ymm2)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+
+	vmovupd(mem(rbx        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovupd(mem(rbx, r11, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovupd(mem(rbx, r11, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovupd(mem(rbx, r13, 1), ymm3)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here
+	                                   // because using the xmm registers would
+	                                   // zero out the high bits of the
+	                                   // destination registers, which would
+	                                   // destroy intermediate results.
+
+	vmovsd(mem(rax        ), xmm0)
+	vmovsd(mem(rax, r8, 1), xmm1)
+	vmovsd(mem(rax, r8, 2), xmm2)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+
+	vmovsd(mem(rbx        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+	vfmadd231pd(ymm2, ymm3, ymm6)
+
+	vmovsd(mem(rbx, r11, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm7)
+	vfmadd231pd(ymm1, ymm3, ymm8)
+	vfmadd231pd(ymm2, ymm3, ymm9)
+
+	vmovsd(mem(rbx, r11, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+	vfmadd231pd(ymm2, ymm3, ymm12)
+
+	vmovsd(mem(rbx, r13, 1), xmm3)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm13)
+	vfmadd231pd(ymm1, ymm3, ymm14)
+	vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. 
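+	                                   // (Net effect: jj takes the values
+	                                   // 0 and 4, i.e. two passes covering
+	                                   // columns 0-3 and 4-7 of the 8-wide
+	                                   // microtile; after the jj = 4 pass
+	                                   // r15 becomes 8 and the jle falls
+	                                   // through to .DRETURN.)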
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = unused + // r15 = n dim index jj + // r10 = unused + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. 
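+	// Register note: the accumulators zeroed at the top of this jj pass
+	// keep the 3x8 kernel's grid assignment with the third row dropped --
+	// row 0 of the tile uses ymm4/ymm7/ymm10/ymm13 and row 1 uses
+	// ymm5/ymm8/ymm11/ymm14 -- which is why ymm6, ymm9, ymm12, and ymm15
+	// go unused here.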
+ + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(4), r15) // compare jj to 4 + jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. 
+#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 0 + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c +#else + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. 
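+
+ // The k dimension is consumed in three stages that mirror the integer
+ // splits computed at the top of the function:
+ //
+ // k0 == 16*k_iter16 + 4*k_iter4 + k_left1
+ //
+ // The main loop retires 16 k-values per pass, the .DLOOPKITER4 loop just
+ // above retires 4 per pass, and .DLOOPKLEFT1 below retires the final
+ // 0..3 values one at a time.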
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
+
+
+
+
+ label(.DBETAZERO)
+
+
+
+ label(.DROWSTORBZ)
+
+
+ vmovupd(ymm4, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+
+ label(.DDONE)
+
+
+
+
+ add(imm(4), r15) // jj += 4;
+ cmp(imm(4), r15) // compare jj to 4
+ jle(.DLOOP3X4J) // if jj <= 4, jump to beginning
+ // of jj loop; otherwise, loop ends.
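+
+ // The jj loop runs with r15 = 0 and then r15 = 4, each pass updating a
+ // 1x4 panel of c at column offset r15, so the two passes together cover
+ // the full 1x8 micro-tile. Loosely, in C (a sketch only;
+ // update_1x4_panel is a hypothetical stand-in for the loop body above):
+ //
+ // for ( jj = 0; jj <= 4; jj += 4 )
+ // update_1x4_panel( c + jj*cs_c, a, b + jj*cs_b );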
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + 
vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ vmovsd(mem(rax, r8, 2), xmm2)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
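+
+ // Note on the scalar tail above: vmovsd writes only lane 0 of its xmm
+ // destination and zeroes the remaining bits, so the ymm-width FMAs add
+ // zero into the upper three lanes of each accumulator. Writing the FMA
+ // destinations as xmm instead would zero bits 128..255 of those
+ // registers and wipe out the partial sums already held there.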
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ vmovupd(mem(rax, r8, 2), ymm2)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ vmovsd(mem(rax, r8, 2), xmm2)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+ vfmadd231pd(ymm2, ymm3, ymm6)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+ vfmadd231pd(ymm2, ymm3, ymm9)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+ vfmadd231pd(ymm2, ymm3, ymm12)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+ vfmadd231pd(ymm2, ymm3, ymm15)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
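+
+ // The .DPOSTACCUM sequence below folds each group of four ymm
+ // accumulators into one register per row of C. Lane by lane, with
+ // s4 = sum of ymm4's four doubles, etc., the first triple is:
+ //
+ // vhaddpd: ymm0 = { y4[0]+y4[1], y7[0]+y7[1], y4[2]+y4[3], y7[2]+y7[3] }
+ // vextractf128 + vaddpd: xmm0 = { s4, s7 }
+ // vperm2f128 (0x20): ymm4 = { s4, s7, s10, s13 }
+ //
+ // i.e. one full row of the A*B product, still unscaled by alpha.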
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
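+
+ // The .DROWSTORED path above and the .DBETAZERO path below differ only
+ // in whether C is read. In outline (illustrative C, one row at a time,
+ // with t[] holding the alpha-scaled sums):
+ //
+ // if ( beta != 0.0 ) c[ j ] = beta*c[ j ] + t[ j ]; // .DROWSTORED
+ // else c[ j ] = t[ j ]; // .DROWSTORBZ
+ //
+ // Skipping the load when beta == 0 saves bandwidth and also ensures that
+ // an uninitialized C buffer (legal under BLAS semantics when beta is
+ // zero) cannot contaminate the result with NaNs or Infs.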
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. 
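+
+ // The EDGE LOOP above is simply one unrolled step of .DLOOPKITER16
+ // reused on its own; since k_left16 < 16, it executes at most three
+ // passes before the scalar k_left1 cleanup that follows.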
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), ymm3, ymm5)
+ vmovupd(ymm5, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
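+
+ // In the .DROWSTORED path above, vfmadd231pd with a memory source folds
+ // the load of C and the beta scaling into a single instruction:
+ // ymm4 = ymm4 + beta*c(i,0:3), where ymm4 already holds alpha*A*B, so no
+ // separate load/multiply/add sequence is needed per row of C.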
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a
+ //mov(var(cs_a), r9) // load cs_a
+ lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
+ //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double)
+
+ lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
+ lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
+
+ mov(var(b), rdx) // load address of b.
+ //mov(var(rs_b), r10) // load rs_b
+ mov(var(cs_b), r11) // load cs_b
+ //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+
+
+ mov(var(c), r12) // load address of c
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+
+
+ // r12 = rcx = c
+ // r14 = rax = a
+ // rdx = rbx = b
+ // r9 = m dim index ii
+
+ mov(var(m_iter), r9) // ii = m_iter;
+
+ label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
+
+
+#if 0
+ vzeroall() // zero all xmm/ymm registers.
+#else
+ // skylake can execute 3 vxorpd ipc with
+ // a latency of 1 cycle, while vzeroall
+ // has a latency of 12 cycles.
+ vxorpd(ymm4, ymm4, ymm4)
+ vxorpd(ymm5, ymm5, ymm5)
+ vxorpd(ymm6, ymm6, ymm6)
+ vxorpd(ymm7, ymm7, ymm7)
+ vxorpd(ymm8, ymm8, ymm8)
+ vxorpd(ymm9, ymm9, ymm9)
+ vxorpd(ymm10, ymm10, ymm10)
+ vxorpd(ymm11, ymm11, ymm11)
+ vxorpd(ymm12, ymm12, ymm12)
+ vxorpd(ymm13, ymm13, ymm13)
+ vxorpd(ymm14, ymm14, ymm14)
+ vxorpd(ymm15, ymm15, ymm15)
+#endif
+
+
+ lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c;
+ lea(mem(r14), rax) // rax = a + 6*ii*rs_a;
+ lea(mem(rdx), rbx) // rbx = b;
+
+
+ lea(mem(rcx, rdi, 2), r10) //
+ lea(mem(r10, rdi, 1), r10) // r10 = c + 3*rs_c;
+ prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
+ prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
+ prefetch(0, mem(r10, 3*8)) // prefetch c + 3*rs_c
+ prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 4*rs_c
+ prefetch(0, mem(r10, rdi, 2, 3*8)) // prefetch c + 5*rs_c
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. 
+	                                   // else, we prepare to enter k_iter4 loop.
+
+
+	label(.DLOOPKITER4)                // EDGE LOOP (ymm)
+
+	vmovupd(mem(rbx        ), ymm0)
+	vmovupd(mem(rbx, r11, 1), ymm1)
+	add(imm(4*8), rbx)                 // b += 4*rs_b = 4*8;
+
+	vmovupd(mem(rax        ), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovupd(mem(rax, r8, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm6)
+	vfmadd231pd(ymm1, ymm3, ymm7)
+
+	vmovupd(mem(rax, r8, 2), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm8)
+	vfmadd231pd(ymm1, ymm3, ymm9)
+
+	vmovupd(mem(rax, r13, 1), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+
+	vmovupd(mem(rax, r8, 4), ymm3)
+	vfmadd231pd(ymm0, ymm3, ymm12)
+	vfmadd231pd(ymm1, ymm3, ymm13)
+
+	vmovupd(mem(rax, r15, 1), ymm3)
+	add(imm(4*8), rax)                 // a += 4*cs_a = 4*8;
+	vfmadd231pd(ymm0, ymm3, ymm14)
+	vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKITER4)                  // iterate again if i != 0.
+
+
+
+
+	label(.DCONSIDKLEFT1)
+
+	mov(var(k_left1), rsi)             // i = k_left1;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DPOSTACCUM)                    // if i == 0, we're done; jump to end.
+	                                   // else, we prepare to enter k_left1 loop.
+
+
+
+
+	label(.DLOOPKLEFT1)                // EDGE LOOP (scalar)
+	                                   // NOTE: We must use ymm registers here because
+	                                   // using the xmm registers would zero out the
+	                                   // high bits of the destination registers,
+	                                   // which would destroy intermediate results.
+
+	vmovsd(mem(rbx        ), xmm0)
+	vmovsd(mem(rbx, r11, 1), xmm1)
+	add(imm(1*8), rbx)                 // b += 1*rs_b = 1*8;
+
+	vmovsd(mem(rax        ), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm4)
+	vfmadd231pd(ymm1, ymm3, ymm5)
+
+	vmovsd(mem(rax, r8, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm6)
+	vfmadd231pd(ymm1, ymm3, ymm7)
+
+	vmovsd(mem(rax, r8, 2), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm8)
+	vfmadd231pd(ymm1, ymm3, ymm9)
+
+	vmovsd(mem(rax, r13, 1), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm10)
+	vfmadd231pd(ymm1, ymm3, ymm11)
+
+	vmovsd(mem(rax, r8, 4), xmm3)
+	vfmadd231pd(ymm0, ymm3, ymm12)
+	vfmadd231pd(ymm1, ymm3, ymm13)
+
+	vmovsd(mem(rax, r15, 1), xmm3)
+	add(imm(1*8), rax)                 // a += 1*cs_a = 1*8;
+	vfmadd231pd(ymm0, ymm3, ymm14)
+	vfmadd231pd(ymm1, ymm3, ymm15)
+
+
+	dec(rsi)                           // i -= 1;
+	jne(.DLOOPKLEFT1)                  // iterate again if i != 0.
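+	/* A note on the scalar edge loop above (exposition only, not kernel
+	   logic): vmovsd into xmm0/xmm1/xmm3 zeroes lanes 1-3 of the
+	   corresponding ymm registers, so the 256-bit FMAs add 0.0 into
+	   lanes 1-3 of the accumulators, and the partial sums produced by
+	   the vector loops survive. If the FMA destinations were instead
+	   written as xmm registers, the VEX encoding would clear the upper
+	   128 bits of those accumulators and discard the partial sums --
+	   which is exactly what the NOTE above warns about. */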
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. 
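+	/* The .DPOSTACCUM sequence above is the usual AVX horizontal
+	   reduction for dot-product microkernels. A minimal intrinsics
+	   sketch of the same pattern (hypothetical helper, for exposition
+	   only; the kernel itself stays in assembly):
+
+	     #include <immintrin.h>
+
+	     // Reduce two 4-wide accumulators to { sum(acc0), sum(acc1) },
+	     // mirroring the vhaddpd + vextractf128 + vaddpd sequence.
+	     static inline __m128d hsum2_pd( __m256d acc0, __m256d acc1 )
+	     {
+	         __m256d h  = _mm256_hadd_pd( acc0, acc1 );  // { a0+a1, b0+b1, a2+a3, b2+b3 }
+	         __m128d lo = _mm256_castpd256_pd128( h );   // low halves
+	         __m128d hi = _mm256_extractf128_pd( h, 1 ); // high halves
+	         return _mm_add_pd( lo, hi );                // { sum(acc0), sum(acc1) }
+	     }
+	*/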
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 3 <= m_left ) + { + const dim_t mr_cur = 3; + + bli_dgemmsup_rd_haswell_asm_3x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 2 <= m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x2m + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
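+	/* What the two store branches implement, as a scalar model
+	   (hypothetical reference, for exposition only): .DROWSTORED reads
+	   C and applies C := beta*C + alpha*A*B, while .DROWSTORBZ below
+	   skips the read entirely when beta == 0 -- both to save bandwidth
+	   and because uninitialized C may hold NaNs or Infs, and
+	   0.0 * NaN is NaN rather than 0.
+
+	     // ab[] holds the alpha-scaled dot products, row-major.
+	     static void store_block_ref( dim_t m, dim_t n, double beta,
+	                                  const double* ab,
+	                                  double* c, inc_t rs_c, inc_t cs_c )
+	     {
+	         for ( dim_t i = 0; i < m; ++i )
+	         for ( dim_t j = 0; j < n; ++j )
+	         {
+	             double* cij = c + i*rs_c + j*cs_c;
+	             if ( beta == 0.0 ) *cij = ab[ i*n + j ];                // .DROWSTORBZ
+	             else               *cij = beta*(*cij) + ab[ i*n + j ];  // .DROWSTORED
+	         }
+	     }
+	*/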
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c new file mode 100644 index 000000000..a23764f8d --- /dev/null +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -0,0 +1,5836 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( float, s, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( scomplex, c, gemmsup_r_haswell_ref ) +GEMMSUP_KER_PROT( dcomplex, z, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. 
+#define NUM_MR  4
+#define NUM_NR  4
+#define FUNCPTR_T dgemmsup_ker_ft
+
+#if 0
+static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
+static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
+static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
+{          /*  8                                 4                                 2                                 1  */
+/* 6 */    { bli_dgemmsup_rd_haswell_asm_6x8n, bli_dgemmsup_rd_haswell_asm_6x4n, bli_dgemmsup_rd_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref },
+/* 3 */    { bli_dgemmsup_rd_haswell_asm_3x8n, bli_dgemmsup_rd_haswell_asm_3x4n, bli_dgemmsup_rd_haswell_asm_3x2n, bli_dgemmsup_r_haswell_ref },
+/* 2 */    { bli_dgemmsup_rd_haswell_asm_2x8n, bli_dgemmsup_rd_haswell_asm_2x4n, bli_dgemmsup_rd_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref },
+/* 1 */    { bli_dgemmsup_rd_haswell_asm_1x8n, bli_dgemmsup_rd_haswell_asm_1x4n, bli_dgemmsup_rd_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref }
+};
+#endif
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8n
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+#if 0
+	bli_dgemmsup_r_haswell_ref
+	(
+	  conja, conjb, m0, n0, k0,
+	  alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0,
+	  beta, c, rs_c0, cs_c0, data, cntx
+	); return;
+#endif
+	uint64_t m_left = m0 % 6;
+
+	// First check whether this is an edge case in the m dimension. If so,
+	// dispatch other ?x8n kernels, as needed.
+	if ( m_left )
+	{
+		double* restrict cij = c;
+		double* restrict bj  = b;
+		double* restrict ai  = a;
+
+		if ( 3 <= m_left )
+		{
+			const dim_t mr_cur = 3;
+
+			bli_dgemmsup_rd_haswell_asm_3x8n
+			(
+			  conja, conjb, mr_cur, n0, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+			cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur;
+		}
+		if ( 2 <= m_left )
+		{
+			const dim_t mr_cur = 2;
+
+			bli_dgemmsup_rd_haswell_asm_2x8n
+			(
+			  conja, conjb, mr_cur, n0, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+			cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur;
+		}
+		if ( 1 == m_left )
+		{
+#if 0
+			const dim_t mr_cur = 1;
+
+			//bli_dgemmsup_r_haswell_ref
+			bli_dgemmsup_rd_haswell_asm_1x8n
+			(
+			  conja, conjb, mr_cur, n0, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+#else
+			bli_dgemv_ex
+			(
+			  BLIS_TRANSPOSE, conja, k0, n0,
+			  alpha, bj, rs_b0, cs_b0, ai, cs_a0,
+			  beta, cij, cs_c0, cntx, NULL
+			);
+#endif
+		}
+		return;
+	}
+
+	//void* a_next = bli_auxinfo_next_a( data );
+	//void* b_next = bli_auxinfo_next_b( data );
+
+	// Typecast local copies of integers in case dim_t and inc_t are a
+	// different size than is expected by load instructions.
+	uint64_t k_iter16 = k0 / 16;
+	uint64_t k_left16 = k0 % 16;
+	uint64_t k_iter4  = k_left16 / 4;
+	uint64_t k_left1  = k_left16 % 4;
+
+	uint64_t n_iter = n0 / 4;
+	uint64_t n_left = n0 % 4;
+
+	uint64_t rs_a   = rs_a0;
+	uint64_t cs_a   = cs_a0;
+	uint64_t rs_b   = rs_b0;
+	uint64_t cs_b   = cs_b0;
+	uint64_t rs_c   = rs_c0;
+	uint64_t cs_c   = cs_c0;
+
+	if ( n_iter == 0 ) goto consider_edge_cases;
+
+	// -------------------------------------------------------------------------
+
+	begin_asm()
+
+	//vzeroall()                         // zero all xmm/ymm registers.
+
+	mov(var(a), rdx)                   // load address of a.
+	mov(var(rs_a), r8)                 // load rs_a
+	//mov(var(cs_a), r9)               // load cs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	//lea(mem(, r9, 8), r9)            // cs_a *= sizeof(double)
+
+	//lea(mem(r8, r8, 2), r13)         // r13 = 3*rs_a
+	//lea(mem(r8, r8, 4), r15)         // r15 = 5*rs_a
+
+	//mov(var(b), r14)                 // load address of b.
+	//mov(var(rs_b), r10)              // load rs_b
+	mov(var(cs_b), r11)                // load cs_b
+	//lea(mem(, r10, 8), r10)          // rs_b *= sizeof(double)
+	lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+	lea(mem(r11, r11, 2), r13)         // r13 = 3*cs_b
+
+
+	//mov(var(c), r12)                 // load address of c
+	//mov(var(rs_c), rdi)              // load rs_c
+	//lea(mem(, rdi, 8), rdi)          // rs_c *= sizeof(double)
+
+
+
+	// r12 = rcx = c
+	// rdx = rax = a
+	// r14 = rbx = b
+	// r9 = m dim index ii
+	// r15 = n dim index jj
+
+	mov(imm(0), r9)                    // ii = 0;
+
+	label(.DLOOP3X4I)                  // LOOP OVER ii = [ 0 1 ... ]
+
+
+
+	mov(var(b), r14)                   // load address of b
+	mov(var(c), r12)                   // load address of c
+
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+	lea(mem( , r9, 1), rsi)            // rsi = r9 = 3*ii;
+	imul(rdi, rsi)                     // rsi *= rs_c
+	lea(mem(r12, rsi, 1), r12)         // r12 = c + 3*ii*rs_c;
+
+	lea(mem( , r9, 1), rsi)            // rsi = r9 = 3*ii;
+	imul(r8, rsi)                      // rsi *= rs_a;
+	lea(mem(rdx, rsi, 1), rdx)         // rdx = a + 3*ii*rs_a;
+
+
+
+	mov(var(n_iter), r15)              // jj = n_iter;
+
+	label(.DLOOP3X4J)                  // LOOP OVER jj = [ n_iter ... 1 0 ]
+
+
+
+#if 0
+	vzeroall()                         // zero all xmm/ymm registers.
+#else
+	                                   // skylake can execute 3 vxorpd ipc with
+	                                   // a latency of 1 cycle, while vzeroall
+	                                   // has a latency of 12 cycles.
+	vxorpd(ymm4, ymm4, ymm4)
+	vxorpd(ymm5, ymm5, ymm5)
+	vxorpd(ymm6, ymm6, ymm6)
+	vxorpd(ymm7, ymm7, ymm7)
+	vxorpd(ymm8, ymm8, ymm8)
+	vxorpd(ymm9, ymm9, ymm9)
+	vxorpd(ymm10, ymm10, ymm10)
+	vxorpd(ymm11, ymm11, ymm11)
+	vxorpd(ymm12, ymm12, ymm12)
+	vxorpd(ymm13, ymm13, ymm13)
+	vxorpd(ymm14, ymm14, ymm14)
+	vxorpd(ymm15, ymm15, ymm15)
+#endif
+
+	lea(mem(r12), rcx)                 // rcx = c_iijj;
+	lea(mem(rdx), rax)                 // rax = a_ii;
+	lea(mem(r14), rbx)                 // rbx = b_jj;
+
+
+#if 1
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+	prefetch(0, mem(rcx, 3*8))         // prefetch c + 0*rs_c
+	prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
+	prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
+#endif
+	lea(mem(r11, r11, 2), rdi)         // rdi = 3*cs_b
+	lea(mem(rbx, r11, 4), r10)         // r10 = rbx + 4*cs_b
+
+
+
+
+	mov(var(k_iter16), rsi)            // i = k_iter16;
+	test(rsi, rsi)                     // check i via logical AND.
+	je(.DCONSIDKITER4)                 // if i == 0, jump to code that
+	                                   // contains the k_iter4 loop.
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + 
vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
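+	/* How the three loop levels above partition the k dimension (a
+	   restatement of the host code's setup, for exposition only):
+	   .DLOOPKITER16 retires 16 k iterations per trip (four unrolled
+	   4-double vector steps), .DLOOPKITER4 retires 4, and .DLOOPKLEFT1
+	   retires 1.
+
+	     uint64_t k_iter16 = k0 / 16;       // trips of .DLOOPKITER16
+	     uint64_t k_left16 = k0 % 16;
+	     uint64_t k_iter4  = k_left16 / 4;  // trips of .DLOOPKITER4
+	     uint64_t k_left1  = k_left16 % 4;  // trips of .DLOOPKLEFT1
+
+	     // invariant: 16*k_iter16 + 4*k_iter4 + k_left1 == k0
+	*/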
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + add(imm(3), r9) // ii += 3; + cmp(imm(3), r9) // compare ii to 3 + jle(.DLOOP3X4I) // if ii <= 3, jump to beginning + // of ii loop; otherwise, loop ends. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( n_left ) + { + const dim_t mr_cur = 6; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_6x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, 
ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. 
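+	// (Note: the 4*cs_c advance just above is encoded as an immediate
+	// of 4*8 bytes; like the contiguous vmovupd stores to c, this
+	// presumes the unit column stride of the row-stored C that this
+	// 'rd' kernel targets.)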
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_3x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_3x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, 
r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
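+	// (The .DPOSTACCUM reduction below can be read as the following
+	// illustrative sketch, with t a hypothetical temporary:
+	//   t    = hadd( ymm4, ymm7 );  // { y4[0]+y4[1], y7[0]+y7[1],
+	//                               //   y4[2]+y4[3], y7[2]+y7[3] }
+	//   xmm0 = t.lo128 + t.hi128;   // { sum(ymm4), sum(ymm7) }
+	// and likewise for (ymm10,ymm13) into xmm2, after which
+	// vperm2f128 packs { sum(ymm4), sum(ymm7), sum(ymm10), sum(ymm13) }
+	// into a single ymm register.)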
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
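+	// (The leftover handled here is in the n dimension: e.g. n0 = 10
+	// leaves n_left = 2 columns for the 2x2n kernel below, while a
+	// single leftover column is computed via bli_dgemv_ex, since a
+	// one-column B reduces the block to a matrix-vector product.)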
+ if ( n_left ) + { + const dim_t mr_cur = 2; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_2x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_1x8n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rdx) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), r14) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // rdx = rax = a + // r14 = rbx = b + // r9 = unused + // r15 = n dim index jj + + mov(var(n_iter), r15) // jj = n_iter; + + label(.DLOOP3X4J) // LOOP OVER jj = [ n_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(rdx), rax) // rax = a_ii; + lea(mem(r14), rbx) // rbx = b_jj; + + +#if 1 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b + add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; +#else + prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b + prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + +#if 1 + prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b + prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(r10, 8*8)) // prefetch rbx + 4*cs_b + 8*rs_b + prefetch(0, mem(r10, r11, 1, 8*8)) // prefetch rbx + 5*cs_b + 8*rs_b +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(r10, r11, 2, 8*8)) // prefetch rbx + 6*cs_b + 8*rs_b + prefetch(0, mem(r10, r13, 1, 8*8)) // prefetch rbx + 7*cs_b + 8*rs_b + add(imm(16*8), r10) // r10 += 8*rs_b = 8*8; +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. 
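+	// (Bookkeeping recap: each pass of the main loop above consumes 16
+	// k-elements -- 4 unrolled steps of 4 doubles per ymm load -- so,
+	// e.g., k0 = 23 yields k_iter16 = 1, k_iter4 = 1, k_left1 = 3.)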
+ + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4*8), r12) // c_jj = r12 += 4*cs_c + + lea(mem(r14, r11, 4), r14) // b_jj = r14 += 4*cs_b + + dec(r15) // jj -= 1; + jne(.DLOOP3X4J) // iterate again if jj != 0. 
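+	// (Similarly for the n dimension: n_iter counts 4-column panels of
+	// B and C, so n0 = 10 gives two passes of the jj loop above, with
+	// n_left = 2 columns deferred to consider_edge_cases after end_asm.)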
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [n_iter] "m" (n_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 1; + const dim_t j_edge = n0 - ( dim_t )n_left; + + double* restrict cij = c + j_edge*cs_c; + double* restrict ai = a; + double* restrict bj = b + j_edge*cs_b; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_1x2n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { +#if 0 + const dim_t nr_cur = 1; + + //bli_dgemmsup_rd_haswell_asm_1x1n + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); +#else + bli_ddotxv_ex + ( + conja, conjb, k0, + alpha, ai, cs_a0, bj, rs_b0, + beta, cij, cntx, NULL + ); +#endif + } + } +} + +void bli_dgemmsup_rd_haswell_asm_6x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + // r10 = unused + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; + lea(mem(r14), rax) // rax = a + 3*ii*rs_a; + lea(mem(rdx), rbx) // rbx = b; + + +#if 0 + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
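+	// (In this kernel the main-loop prefetches touch A rather than B:
+	// with r10 = 3*rs_a and rdi = 5*rs_a, they warm rows ii+3..ii+5 of
+	// A, i.e. the rows consumed by the next pass of the ii loop.)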
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // 
iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
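+	// (By construction k_left1 = (k0 % 16) % 4 < 4, so this scalar
+	// loop executes at most three times per micro-panel.)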
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
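+	// (For example, m0 = 8 gives m_iter = 2 -- two three-row passes of
+	// the ii loop above -- and m_left = 2, which is dispatched to the
+	// 2x4n kernel below; a final single row would go to 1x4n.)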
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x4n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x4n + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_3x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
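+	// (Unlike the Nx8n and 6x4n kernels above, this variant computes a
+	// single 3x4 block: there is no ii/jj loop, so c, a, and b are
+	// addressed directly through rcx, rax, and rbx.)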
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
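+	// (Each pass of the edge loop below is one 4-wide step of the main
+	// loop body, i.e. the same load/FMA pattern without the 4x
+	// unrolling.)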
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_b = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
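+	// (When beta == 0 the vucomisd/je pair above branches to
+	// .DROWSTORBZ below, which overwrites C without ever loading it --
+	// so an uninitialized C buffer is never read.)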
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
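+	// (Register layout note: this two-row variant keeps the three-row
+	// kernel's accumulator naming -- ymm4/ymm5, ymm7/ymm8, ymm10/ymm11,
+	// ymm13/ymm14 -- and simply leaves the third-row registers
+	// ymm6/ymm9/ymm12/ymm15 unused.)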
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. 
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+
+
+ // ymm4 ymm7 ymm10 ymm13
+ // ymm5 ymm8 ymm11 ymm14
+
+ vhaddpd( ymm7, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
+
+ vhaddpd( ymm13, ymm10, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
+
+
+ vhaddpd( ymm8, ymm5, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm0 )
+
+ vhaddpd( ymm14, ymm11, ymm2 )
+ vextractf128(imm(1), ymm2, xmm1 )
+ vaddpd( xmm2, xmm1, xmm2 )
+
+ vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
+
+ // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
+ // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(ymm0, ymm4, ymm4) // scale by alpha
+ vmulpd(ymm0, ymm5, ymm5)
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), ymm3, ymm4)
+ vmovupd(ymm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), ymm3, ymm5)
+ vmovupd(ymm5, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
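The post-accumulation sequence above (vhaddpd, vextractf128, vaddpd, vperm2f128) collapses four ymm accumulators into one vector of horizontal sums. An intrinsics model, a sketch only (assumes AVX and a compiler that provides _mm256_set_m128d; function name hypothetical):

#include <immintrin.h>

/* Return [ sum(y4), sum(y7), sum(y10), sum(y13) ], mirroring the
   reduction tree in the .DPOSTACCUM block. */
static __m256d reduce_4sums( __m256d y4, __m256d y7, __m256d y10, __m256d y13 )
{
    __m256d t   = _mm256_hadd_pd( y4, y7 );                     // vhaddpd
    __m128d s04 = _mm_add_pd( _mm256_castpd256_pd128( t ),
                              _mm256_extractf128_pd( t, 1 ) );  // [sum(y4), sum(y7)]

    t           = _mm256_hadd_pd( y10, y13 );
    __m128d s26 = _mm_add_pd( _mm256_castpd256_pd128( t ),
                              _mm256_extractf128_pd( t, 1 ) );  // [sum(y10), sum(y13)]

    return _mm256_set_m128d( s26, s04 );                        // vperm2f128 imm 0x20
}

Deferring the horizontal reduction until after the k loops is the key design choice here: the hot loops issue only full-width FMAs, and the relatively expensive hadd/extract/permute work is paid once per microtile.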
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | | | + -------- -- -- -- ... | | | | + -------- += -- -- -- | | | | + -------- | | | | + -------- : + -------- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
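The accumulator zeroing earlier in this function clears only the registers actually used rather than issuing vzeroall. In intrinsics form the same idea is simply one setzero per live accumulator, which typically compiles to a vxor of the register with itself (a sketch, with hypothetical variable names):

#include <immintrin.h>

/* 1x4n needs only four accumulators, so zero exactly those. */
__m256d y4  = _mm256_setzero_pd();
__m256d y7  = _mm256_setzero_pd();
__m256d y10 = _mm256_setzero_pd();
__m256d y13 = _mm256_setzero_pd();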
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 2
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+ vmovupd(mem(rax ), ymm0)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
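The NOTE in the scalar loop above deserves a model: a 128-bit VEX operation zeroes bits 128..255 of its destination, so an xmm FMA would wipe the partial sums held in the accumulator's upper lanes. Loading the single double into the low lane and issuing the FMA at full ymm width leaves those lanes intact, since they only ever receive acc + 0*0. A hedged intrinsics sketch (assumes FMA3, as on Haswell, and a compiler recent enough to provide the zero-extending cast intrinsics; function name hypothetical):

#include <immintrin.h>

/* One scalar k step: vmovsd-style load into the low lane, FMA on
   the full ymm so the accumulator's upper lanes survive. */
static __m256d scalar_step_model( __m256d acc, const double* a, const double* b )
{
    __m256d a0 = _mm256_zextpd128_pd256( _mm_load_sd( a ) ); // upper lanes zero
    __m256d b0 = _mm256_zextpd128_pd256( _mm_load_sd( b ) );
    return _mm256_fmadd_pd( a0, b0, acc );                   // upper lanes: acc + 0*0
}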
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_6x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, 
ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + // xmm8 = sum(ymm8) sum(ymm9) + // xmm10 = sum(ymm10) sum(ymm11) + // xmm12 = sum(ymm12) sum(ymm13) + // xmm14 = sum(ymm14) sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
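The store phase just above first tests beta so that C is never read when beta == 0 (which both saves the loads and keeps the kernel safe on uninitialized C). One row of the beta != 0 path fuses beta*C with the alpha-scaled result in a single FMA. A sketch of that logic for the two-wide rows stored here (assumes FMA3; helper name hypothetical):

#include <immintrin.h>

/* Store one 1x2 row of the microtile: ab already holds alpha*(A.B). */
static void store_row2( double* c, __m128d ab, double beta )
{
    if ( beta != 0.0 )  // mirrors the vucomisd / je(.DBETAZERO) test
        ab = _mm_fmadd_pd( _mm_loadu_pd( c ), _mm_set1_pd( beta ), ab );
    _mm_storeu_pd( c, ab );
}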
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rbx ), xmm0)
+ vmovsd(mem(rbx, r11, 1), xmm1)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+
+ vmovsd(mem(rax ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rax, r8, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm6)
+ vfmadd231pd(ymm1, ymm3, ymm7)
+
+ vmovsd(mem(rax, r8, 2), xmm3)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm8)
+ vfmadd231pd(ymm1, ymm3, ymm9)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
+
+
+
+
+
+ label(.DPOSTACCUM)
+
+ // ymm4 ymm5
+ // ymm6 ymm7
+ // ymm8 ymm9
+
+ vhaddpd( ymm5, ymm4, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm4 ) // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm5)
+
+ vhaddpd( ymm7, ymm6, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm6 )
+
+ vhaddpd( ymm9, ymm8, ymm0 )
+ vextractf128(imm(1), ymm0, xmm1 )
+ vaddpd( xmm0, xmm1, xmm8 )
+
+ // xmm4 = sum(ymm4) sum(ymm5)
+ // xmm6 = sum(ymm6) sum(ymm7)
+ // xmm8 = sum(ymm8) sum(ymm9)
+
+
+
+ mov(var(alpha), rax) // load address of alpha
+ mov(var(beta), rbx) // load address of beta
+ vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
+ vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
+
+ vmulpd(xmm0, xmm4, xmm4) // scale by alpha
+ vmulpd(xmm0, xmm6, xmm6)
+ vmulpd(xmm0, xmm8, xmm8)
+
+
+
+ //mov(var(cs_c), rsi) // load cs_c
+ //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
+
+
+
+ // now avoid loading C if beta == 0
+
+ vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0.
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
+
+
+
+ label(.DROWSTORED)
+
+
+ vfmadd231pd(mem(rcx), xmm3, xmm4)
+ vmovupd(xmm4, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm6)
+ vmovupd(xmm6, mem(rcx))
+ add(rdi, rcx)
+
+ vfmadd231pd(mem(rcx), xmm3, xmm8)
+ vmovupd(xmm8, mem(rcx))
+ //add(rdi, rcx)
+
+
+
+ jmp(.DDONE) // jump to end.
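Putting the pieces together, the net effect of one of these rd microkernel calls (here the 3x2 case) is the usual sup update, with beta == 0 treated as an overwrite. A plain-C model of the whole kernel, a sketch only (function name hypothetical):

#include <stdint.h>

/* C(i,j) = beta*C(i,j) + alpha * (row i of A . column j of B);
   when beta == 0, C is overwritten and never read. */
static void gemmsup_rd_3x2_model( int64_t k, double alpha,
                                  const double* a, int64_t rs_a,
                                  const double* b, int64_t cs_b,
                                  double beta, double* c, int64_t rs_c )
{
    for ( int64_t i = 0; i < 3; ++i )
        for ( int64_t j = 0; j < 2; ++j )
        {
            double ab = 0.0;
            for ( int64_t l = 0; l < k; ++l )
                ab += a[ i*rs_a + l ] * b[ j*cs_b + l ];

            double* cij = &c[ i*rs_c + j ];
            *cij = ( beta == 0.0 ? alpha*ab : beta*(*cij) + alpha*ab );
        }
}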
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4 = sum(ymm4) sum(ymm5) + // xmm6 = sum(ymm6) sum(ymm7) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2n + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +/* + rrc: + -------- -- -- -- | | + -------- -- -- -- ... | | + -------- += -- -- -- | | + -------- -- -- -- | | + -------- -- -- -- : + -------- -- -- -- : +*/ + // ------------------------------------------------------------------------- + + begin_asm() + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + mov(var(a), rax) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + + // initialize loop by pre-loading + // a column of a. + + mov(var(c), rcx) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
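The "rrc" diagram above encodes the storage combination these rd kernels target: C and A row-stored (cs_c == cs_a == 1) and B column-stored (rs_b == 1), so every C(i,j) reduces to a unit-stride dot product. That is why the loops can advance rax and rbx by a plain 8 bytes per k step, without the stride registers the rv kernels need. As a plain-C sketch (function name hypothetical):

#include <stdint.h>

/* The unit-stride dot product at the heart of the rrc case. */
static double ddot_unit( int64_t k, const double* a, const double* b )
{
    double ab = 0.0;
    for ( int64_t l = 0; l < k; ++l )
        ab += a[l] * b[l];
    return ab;
}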
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm5) + + // xmm4 = sum(ymm4) sum(ymm5) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 53d434dff..df49a77dd 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +33,7 @@ */ -// -- level-3 -- +// -- level-3 ------------------------------------------------------------------ // gemm (asm d6x8) GEMM_UKR_PROT( float, s, gemm_haswell_asm_6x16 ) @@ -61,3 +62,87 @@ GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_haswell_asm_6x8 ) //GEMM_UKR_PROT( scomplex, c, gemm_haswell_asm_8x3 ) //GEMM_UKR_PROT( dcomplex, z, gemm_haswell_asm_4x3 ) + +// -- level-3 sup -------------------------------------------------------------- + +// gemmsup_rv + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x8 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x6 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x4 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x2 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 ) + +// gemmsup_rv (mkernel in m dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x6m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x4m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x2m ) + +// gemmsup_rv (mkernel in n dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x8n ) + +// gemmsup_rd + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x8 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x4 ) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x2 
) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x2 ) + +// gemmsup_rd (mkernel in m dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m ) + +// gemmsup_rd (mkernel in n dim) + +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x8n ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x8n ) + diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index 3d424e896..b958600ce 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index ef75eeda1..e40c785d8 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/mpi_test/Makefile b/mpi_test/Makefile index 8bf871b99..00ca01e47 100644 --- a/mpi_test/Makefile +++ b/mpi_test/Makefile @@ -134,7 +134,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) diff --git a/ref_kernels/3/bli_gemmsup_ref.c b/ref_kernels/3/bli_gemmsup_ref.c new file mode 100644 index 000000000..dc09267fa --- /dev/null +++ b/ref_kernels/3/bli_gemmsup_ref.c @@ -0,0 +1,562 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. 
If beta is zero, overwrite c
+	   with the result in ab. Otherwise, scale by beta and accumulate
+	   ab to c. */ \
+			if ( PASTEMAC(ch,eq1)( *beta ) ) \
+			{ \
+				PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \
+			} \
+			else if ( PASTEMAC(ch,eq0)( *beta ) ) \
+			{ \
+				PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \
+			} \
+			else \
+			{ \
+				PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \
+			} \
+		} \
+	} \
+}
+
+INSERT_GENTFUNC_BASIC2( gemmsup_c, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
+
+//
+// -- General storage case -----------------------------------------------------
+//
+
+INSERT_GENTFUNC_BASIC2( gemmsup_g, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
+
+
+
+
+
+
+
+
+#if 0
+
+//
+// -- Row storage case ---------------------------------------------------------
+//
+
+#undef GENTFUNC
+#define GENTFUNC( ctype, ch, opname, arch, suf ) \
+\
+void PASTEMAC3(ch,opname,arch,suf) \
+     ( \
+       conj_t           conja, \
+       conj_t           conjb, \
+       dim_t            m, \
+       dim_t            n, \
+       dim_t            k, \
+       ctype* restrict alpha, \
+       ctype* restrict a, inc_t rs_a, inc_t cs_a, \
+       ctype* restrict b, inc_t rs_b, inc_t cs_b, \
+       ctype* restrict beta, \
+       ctype* restrict c, inc_t rs_c, inc_t cs_c, \
+       auxinfo_t* restrict data, \
+       cntx_t* restrict cntx \
+     ) \
+{ \
+	const dim_t mn = m * n; \
+\
+	ctype ab[ BLIS_STACK_BUF_MAX_SIZE \
+	          / sizeof( ctype ) ] \
+	          __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
+	const inc_t rs_ab = n; \
+	const inc_t cs_ab = 1; \
+\
+\
+	/* Assumptions: m <= mr, n <= nr so that the temporary array ab is
+	   large enough to hold the m x n microtile.
+
+	   The ability to handle m < mr and n < nr is being provided so that
+	   optimized ukernels can call one of these reference implementations
+	   for their edge cases, if they choose. When they do so, they will
+	   need to call the function directly, by its configuration-mangled
+	   name, since it will have been overwritten in the context when
+	   the optimized ukernel functions are registered. */ \
+\
+\
+	/* Initialize the accumulator elements in ab to zero. */ \
+	for ( dim_t i = 0; i < mn; ++i ) \
+	{ \
+		PASTEMAC(ch,set0s)( ab[i] ); \
+	} \
+\
+	/* Perform a series of k rank-1 updates into ab. */ \
+	for ( dim_t l = 0; l < k; ++l ) \
+	{ \
+		/* Traverse ab by rows; assume cs_ab = 1. */ \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			for ( dim_t j = 0; j < n; ++j ) \
+			{ \
+				PASTEMAC(ch,dots) \
+				( \
+				  a[ i*rs_a ], \
+				  b[ j*cs_b ], \
+				  ab[ i*rs_ab + j*cs_ab ] \
+				); \
+			} \
+		} \
+\
+		a += cs_a; \
+		b += rs_b; \
+	} \
+\
+	/* Scale the result in ab by alpha. */ \
+	for ( dim_t i = 0; i < mn; ++i ) \
+	{ \
+		PASTEMAC(ch,scals)( *alpha, ab[i] ); \
+	} \
+\
+\
+	/* If beta is one, add ab into c. If beta is zero, overwrite c with the
+	   result in ab. Otherwise, scale by beta and accumulate ab to c. */ \
+	if ( PASTEMAC(ch,eq1)( *beta ) ) \
+	{ \
+		/* Traverse ab and c by rows; assume cs_ab = cs_c = 1. */ \
+		for ( dim_t i = 0; i < m; ++i ) \
+		for ( dim_t j = 0; j < n; ++j ) \
+		{ \
+			PASTEMAC(ch,adds) \
+			( \
+			  ab[ i*rs_ab + j*1 ], \
+			  c[ i*rs_c + j*1 ] \
+			) \
+		} \
+	} \
+	else if ( PASTEMAC(ch,eq0)( *beta ) ) \
+	{ \
+\
+		/* Traverse ab and c by rows; assume cs_ab = cs_c = 1. */ \
+		for ( dim_t i = 0; i < m; ++i ) \
+		for ( dim_t j = 0; j < n; ++j ) \
+		{ \
+			PASTEMAC(ch,copys) \
+			( \
+			  ab[ i*rs_ab + j*1 ], \
+			  c[ i*rs_c + j*1 ] \
+			) \
+		} \
+	} \
+	else /* beta != 0 && beta != 1 */ \
+	{ \
+		/* Traverse ab and c by rows; assume cs_ab = cs_c = 1.
*/ \
+		for ( dim_t i = 0; i < m; ++i ) \
+		for ( dim_t j = 0; j < n; ++j ) \
+		{ \
+			PASTEMAC(ch,xpbys) \
+			( \
+			  ab[ i*rs_ab + j*1 ], \
+			  *beta, \
+			  c[ i*rs_c + j*1 ] \
+			) \
+		} \
+	} \
+}
+
+INSERT_GENTFUNC_BASIC2( gemmsup_r, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
+
+//
+// -- Column storage case ------------------------------------------------------
+//
+
+#undef GENTFUNC
+#define GENTFUNC( ctype, ch, opname, arch, suf ) \
+\
+void PASTEMAC3(ch,opname,arch,suf) \
+     ( \
+       conj_t           conja, \
+       conj_t           conjb, \
+       dim_t            m, \
+       dim_t            n, \
+       dim_t            k, \
+       ctype* restrict alpha, \
+       ctype* restrict a, inc_t rs_a, inc_t cs_a, \
+       ctype* restrict b, inc_t rs_b, inc_t cs_b, \
+       ctype* restrict beta, \
+       ctype* restrict c, inc_t rs_c, inc_t cs_c, \
+       auxinfo_t* restrict data, \
+       cntx_t* restrict cntx \
+     ) \
+{ \
+	const dim_t mn = m * n; \
+\
+	ctype ab[ BLIS_STACK_BUF_MAX_SIZE \
+	          / sizeof( ctype ) ] \
+	          __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
+	const inc_t rs_ab = 1; \
+	const inc_t cs_ab = m; \
+\
+\
+	/* Assumptions: m <= mr, n <= nr so that the temporary array ab is
+	   large enough to hold the m x n microtile.
+
+	   The ability to handle m < mr and n < nr is being provided so that
+	   optimized ukernels can call one of these reference implementations
+	   for their edge cases, if they choose. When they do so, they will
+	   need to call the function directly, by its configuration-mangled
+	   name, since it will have been overwritten in the context when
+	   the optimized ukernel functions are registered. */ \
+\
+\
+	/* Initialize the accumulator elements in ab to zero. */ \
+	for ( dim_t i = 0; i < mn; ++i ) \
+	{ \
+		PASTEMAC(ch,set0s)( ab[i] ); \
+	} \
+\
+	/* Perform a series of k rank-1 updates into ab. */ \
+	for ( dim_t l = 0; l < k; ++l ) \
+	{ \
+		/* Traverse ab by columns; assume rs_ab = 1. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		{ \
+			for ( dim_t i = 0; i < m; ++i ) \
+			{ \
+				PASTEMAC(ch,dots) \
+				( \
+				  a[ i*rs_a ], \
+				  b[ j*cs_b ], \
+				  ab[ i*rs_ab + j*cs_ab ] \
+				); \
+			} \
+		} \
+\
+		a += cs_a; \
+		b += rs_b; \
+	} \
+\
+	/* Scale the result in ab by alpha. */ \
+	for ( dim_t i = 0; i < mn; ++i ) \
+	{ \
+		PASTEMAC(ch,scals)( *alpha, ab[i] ); \
+	} \
+\
+\
+	/* If beta is one, add ab into c. If beta is zero, overwrite c with the
+	   result in ab. Otherwise, scale by beta and accumulate ab to c. */ \
+	if ( PASTEMAC(ch,eq1)( *beta ) ) \
+	{ \
+		/* Traverse ab and c by columns; assume rs_ab = rs_c = 1. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			PASTEMAC(ch,adds) \
+			( \
+			  ab[ i*1 + j*cs_ab ], \
+			  c[ i*1 + j*cs_c ] \
+			) \
+		} \
+	} \
+	else if ( PASTEMAC(ch,eq0)( *beta ) ) \
+	{ \
+		/* Traverse ab and c by columns; assume rs_ab = rs_c = 1. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			PASTEMAC(ch,copys) \
+			( \
+			  ab[ i*1 + j*cs_ab ], \
+			  c[ i*1 + j*cs_c ] \
+			) \
+		} \
+	} \
+	else /* beta != 0 && beta != 1 */ \
+	{ \
+		/* Traverse ab and c by columns; assume rs_ab = rs_c = 1.
*/ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			PASTEMAC(ch,xpbys) \
+			( \
+			  ab[ i*1 + j*cs_ab ], \
+			  *beta, \
+			  c[ i*1 + j*cs_c ] \
+			) \
+		} \
+	} \
+}
+
+INSERT_GENTFUNC_BASIC2( gemmsup_c, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
+
+//
+// -- General storage case -----------------------------------------------------
+//
+
+#undef GENTFUNC
+#define GENTFUNC( ctype, ch, opname, arch, suf ) \
+\
+void PASTEMAC3(ch,opname,arch,suf) \
+     ( \
+       conj_t           conja, \
+       conj_t           conjb, \
+       dim_t            m, \
+       dim_t            n, \
+       dim_t            k, \
+       ctype* restrict alpha, \
+       ctype* restrict a, inc_t rs_a, inc_t cs_a, \
+       ctype* restrict b, inc_t rs_b, inc_t cs_b, \
+       ctype* restrict beta, \
+       ctype* restrict c, inc_t rs_c, inc_t cs_c, \
+       auxinfo_t* restrict data, \
+       cntx_t* restrict cntx \
+     ) \
+{ \
+	const dim_t mn = m * n; \
+\
+	ctype ab[ BLIS_STACK_BUF_MAX_SIZE \
+	          / sizeof( ctype ) ] \
+	          __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
+	const inc_t rs_ab = 1; \
+	const inc_t cs_ab = m; \
+\
+\
+	/* Assumptions: m <= mr, n <= nr so that the temporary array ab is
+	   large enough to hold the m x n microtile.
+
+	   The ability to handle m < mr and n < nr is being provided so that
+	   optimized ukernels can call one of these reference implementations
+	   for their edge cases, if they choose. When they do so, they will
+	   need to call the function directly, by its configuration-mangled
+	   name, since it will have been overwritten in the context when
+	   the optimized ukernel functions are registered. */ \
+\
+\
+	/* Initialize the accumulator elements in ab to zero. */ \
+	for ( dim_t i = 0; i < mn; ++i ) \
+	{ \
+		PASTEMAC(ch,set0s)( ab[i] ); \
+	} \
+\
+	/* Perform a series of k rank-1 updates into ab. */ \
+	for ( dim_t l = 0; l < k; ++l ) \
+	{ \
+		/* General storage: doesn't matter how we traverse ab. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		{ \
+			for ( dim_t i = 0; i < m; ++i ) \
+			{ \
+				PASTEMAC(ch,dots) \
+				( \
+				  a[ i*rs_a ], \
+				  b[ j*cs_b ], \
+				  ab[ i*rs_ab + j*cs_ab ] \
+				); \
+			} \
+		} \
+\
+		a += cs_a; \
+		b += rs_b; \
+	} \
+\
+	/* Scale the result in ab by alpha. */ \
+	for ( dim_t i = 0; i < mn; ++i ) \
+	{ \
+		PASTEMAC(ch,scals)( *alpha, ab[i] ); \
+	} \
+\
+\
+	/* If beta is one, add ab into c. If beta is zero, overwrite c with the
+	   result in ab. Otherwise, scale by beta and accumulate ab to c. */ \
+	if ( PASTEMAC(ch,eq1)( *beta ) ) \
+	{ \
+		/* General storage: doesn't matter how we traverse ab and c. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			PASTEMAC(ch,adds) \
+			( \
+			  ab[ i*rs_ab + j*cs_ab ], \
+			  c[ i*rs_c + j*cs_c ] \
+			) \
+		} \
+	} \
+	else if ( PASTEMAC(ch,eq0)( *beta ) ) \
+	{ \
+		/* General storage: doesn't matter how we traverse ab and c. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			PASTEMAC(ch,copys) \
+			( \
+			  ab[ i*rs_ab + j*cs_ab ], \
+			  c[ i*rs_c + j*cs_c ] \
+			) \
+		} \
+	} \
+	else /* beta != 0 && beta != 1 */ \
+	{ \
+		/* General storage: doesn't matter how we traverse ab and c. */ \
+		for ( dim_t j = 0; j < n; ++j ) \
+		for ( dim_t i = 0; i < m; ++i ) \
+		{ \
+			PASTEMAC(ch,xpbys) \
+			( \
+			  ab[ i*rs_ab + j*cs_ab ], \
+			  *beta, \
+			  c[ i*rs_c + j*cs_c ] \
+			) \
+		} \
+	} \
+}
+
+INSERT_GENTFUNC_BASIC2( gemmsup_g, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
+
+#endif
diff --git a/ref_kernels/bli_cntx_ref.c b/ref_kernels/bli_cntx_ref.c
index d97c89bea..36cdd52dc 100644
--- a/ref_kernels/bli_cntx_ref.c
+++ b/ref_kernels/bli_cntx_ref.c
@@ -124,6 +124,26 @@
 // template.
#include "bli_l3_ind_ukr.h" +// -- Level-3 small/unpacked micro-kernel prototype definitions ---------------- + +// NOTE: This results in redundant prototypes for gemmsup_r and gemmsup_c +// kernels, but since they will be identical the compiler won't complain. + +#undef gemmsup_rv_ukr_name +#define gemmsup_rv_ukr_name GENARNAME(gemmsup_r) +#undef gemmsup_rg_ukr_name +#define gemmsup_rg_ukr_name GENARNAME(gemmsup_r) +#undef gemmsup_cv_ukr_name +#define gemmsup_cv_ukr_name GENARNAME(gemmsup_c) +#undef gemmsup_cg_ukr_name +#define gemmsup_cg_ukr_name GENARNAME(gemmsup_c) + +#undef gemmsup_gx_ukr_name +#define gemmsup_gx_ukr_name GENARNAME(gemmsup_g) + +// Include the small/unpacked kernel API template. +#include "bli_l3_sup_ker.h" + // -- Level-1m (packm/unpackm) kernel prototype redefinitions ------------------ #undef packm_2xk_ker_name @@ -295,14 +315,24 @@ // -- Macros to help concisely instantiate bli_func_init() --------------------- #define gen_func_init_co( func_p, opname ) \ -\ +{ \ bli_func_init( func_p, NULL, NULL, \ - PASTEMAC(c,opname), PASTEMAC(z,opname) ) + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ +} #define gen_func_init( func_p, opname ) \ -\ +{ \ bli_func_init( func_p, PASTEMAC(s,opname), PASTEMAC(d,opname), \ - PASTEMAC(c,opname), PASTEMAC(z,opname) ) + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ +} + +#define gen_sup_func_init( func0_p, func1_p, opname ) \ +{ \ + bli_func_init( func0_p, PASTEMAC(s,opname), PASTEMAC(d,opname), \ + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ + bli_func_init( func1_p, PASTEMAC(s,opname), PASTEMAC(d,opname), \ + PASTEMAC(c,opname), PASTEMAC(z,opname) ); \ +} @@ -314,9 +344,11 @@ void GENBARNAME(cntx_init) ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; func_t* funcs; mbool_t* mbools; dim_t i; + void** vfuncs; // -- Clear the context ---------------------------------------------------- @@ -363,6 +395,11 @@ void GENBARNAME(cntx_init) funcs = bli_cntx_l3_vir_ukrs_buf( cntx ); + // NOTE: We set the virtual micro-kernel slots to contain the addresses + // of the native micro-kernels. In general, the ukernels in the virtual + // ukernel slots are always called, and if the function called happens to + // be a virtual micro-kernel, it will then know to find its native + // ukernel in the native ukernel slots. gen_func_init( &funcs[ BLIS_GEMM_UKR ], gemm_ukr_name ); gen_func_init( &funcs[ BLIS_GEMMTRSM_L_UKR ], gemmtrsm_l_ukr_name ); gen_func_init( &funcs[ BLIS_GEMMTRSM_U_UKR ], gemmtrsm_u_ukr_name ); @@ -388,6 +425,91 @@ void GENBARNAME(cntx_init) bli_mbool_init( &mbools[ BLIS_TRSM_U_UKR ], FALSE, FALSE, FALSE, FALSE ); + // -- Set level-3 small/unpacked thresholds -------------------------------- + + // NOTE: The default thresholds are set very low so that the sup framework + // only actives for exceedingly small dimensions. If a sub-configuration + // registers optimized sup kernels, then that sub-configuration should also + // register new (probably larger) thresholds that are almost surely more + // appropriate that these default values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 0, 0, 0, 0 ); + + // Initialize the context with the default thresholds. 
+ bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + + // -- Set level-3 small/unpacked handlers ---------------------------------- + + vfuncs = bli_cntx_l3_sup_handlers_buf( cntx ); + + // Initialize all of the function pointers to NULL; + for ( i = 0; i < BLIS_NUM_LEVEL3_OPS; ++i ) vfuncs[ i ] = NULL; + + // The level-3 sup handlers are oapi-based, so we only set one slot per + // operation. + + // Set the gemm slot to the default gemm sup handler. + vfuncs[ BLIS_GEMM ] = bli_gemmsup_ref; + + + // -- Set level-3 small/unpacked micro-kernels and preferences ------------- + + funcs = bli_cntx_l3_sup_kers_buf( cntx ); + mbools = bli_cntx_l3_sup_kers_prefs_buf( cntx ); + +#if 0 + // Adhere to the small/unpacked ukernel mappings: + // - rv -> rrr, rcr + // - rg -> rrc, rcc + // - cv -> ccr, ccc + // - cg -> crr, crc + gen_sup_func_init( &funcs[ BLIS_RRR ], + &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_RRC ], + &funcs[ BLIS_RCC ], gemmsup_rg_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_CCR ], + &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_CRR ], + &funcs[ BLIS_CRC ], gemmsup_cg_ukr_name ); +#endif + gen_func_init( &funcs[ BLIS_RRR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RRC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RCC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CRR ], gemmsup_cv_ukr_name ); + gen_func_init( &funcs[ BLIS_CRC ], gemmsup_cv_ukr_name ); + gen_func_init( &funcs[ BLIS_CCR ], gemmsup_cv_ukr_name ); + gen_func_init( &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name ); + + // Register the general-stride/generic ukernel to the "catch-all" slot + // associated with the BLIS_XXX enum value. This slot will be queried if + // *any* operand is stored with general stride. + gen_func_init( &funcs[ BLIS_XXX ], gemmsup_gx_ukr_name ); + + + // Set the l3 sup ukernel storage preferences. 
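+	// (Each mbool_t below holds one boolean per datatype, in the order
+	// s, d, c, z, mirroring the four ukernel pointers registered by
+	// gen_func_init() above. TRUE records a preference for row storage
+	// of C, which matches the row-oriented rv ukernels registered to the
+	// row-storage slots; FALSE records a column preference.)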
+ bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CRR ], FALSE, FALSE, FALSE, FALSE ); + bli_mbool_init( &mbools[ BLIS_CRC ], FALSE, FALSE, FALSE, FALSE ); + bli_mbool_init( &mbools[ BLIS_CCR ], FALSE, FALSE, FALSE, FALSE ); + bli_mbool_init( &mbools[ BLIS_CCC ], FALSE, FALSE, FALSE, FALSE ); + + bli_mbool_init( &mbools[ BLIS_XXX ], FALSE, FALSE, FALSE, FALSE ); + + // -- Set level-1f kernels ------------------------------------------------- funcs = bli_cntx_l1f_kers_buf( cntx ); diff --git a/sandbox/ref99/cntl/blx_gemm_cntl.c b/sandbox/ref99/cntl/blx_gemm_cntl.c index b26096855..33c97716a 100644 --- a/sandbox/ref99/cntl/blx_gemm_cntl.c +++ b/sandbox/ref99/cntl/blx_gemm_cntl.c @@ -55,9 +55,9 @@ cntl_t* blx_gemmbp_cntl_create pack_t schema_b ) { - void* macro_kernel_fp; - void* packa_fp; - void* packb_fp; + void_fp macro_kernel_fp; + void_fp packa_fp; + void_fp packb_fp; macro_kernel_fp = blx_gemm_ker_var2; @@ -158,7 +158,7 @@ cntl_t* blx_gemm_cntl_create_node ( opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ) { diff --git a/sandbox/ref99/cntl/blx_gemm_cntl.h b/sandbox/ref99/cntl/blx_gemm_cntl.h index 80c26b8ac..223e8c8a9 100644 --- a/sandbox/ref99/cntl/blx_gemm_cntl.h +++ b/sandbox/ref99/cntl/blx_gemm_cntl.h @@ -62,7 +62,7 @@ cntl_t* blx_gemm_cntl_create_node ( opid_t family, bszid_t bszid, - void* var_func, + void_fp var_func, cntl_t* sub_node ); diff --git a/sandbox/ref99/cntl/blx_packm_cntl.c b/sandbox/ref99/cntl/blx_packm_cntl.c index 85a7c8578..2c2ba66ce 100644 --- a/sandbox/ref99/cntl/blx_packm_cntl.c +++ b/sandbox/ref99/cntl/blx_packm_cntl.c @@ -36,8 +36,8 @@ cntl_t* blx_packm_cntl_create_node ( - void* var_func, - void* packm_var_func, + void_fp var_func, + void_fp packm_var_func, bszid_t bmid_m, bszid_t bmid_n, bool_t does_invert_diag, diff --git a/sandbox/ref99/cntl/blx_packm_cntl.h b/sandbox/ref99/cntl/blx_packm_cntl.h index fbba97e1c..8776b913c 100644 --- a/sandbox/ref99/cntl/blx_packm_cntl.h +++ b/sandbox/ref99/cntl/blx_packm_cntl.h @@ -34,8 +34,8 @@ cntl_t* blx_packm_cntl_create_node ( - void* var_func, - void* packm_var_func, + void_fp var_func, + void_fp packm_var_func, bszid_t bmid_m, bszid_t bmid_n, bool_t does_invert_diag, diff --git a/so_version b/so_version index f851fd84b..79d9ce5e8 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 2 -0.0 +1.0 diff --git a/test/1m4m/Makefile b/test/1m4m/Makefile new file mode 100644 index 000000000..74c0804ca --- /dev/null +++ b/test/1m4m/Makefile @@ -0,0 +1,515 @@ +#!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2018, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + clean cleanx + + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + + +# +# --- BLAS implementations ----------------------------------------------------- +# + +# BLAS library path(s). This is where the BLAS libraries reside. +HOME_LIB_PATH := $(HOME)/flame/lib + +# OpenBLAS +OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a +OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a + +# ATLAS +#ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ +# $(HOME_LIB_PATH)/libatlas.a + +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + +# MKL +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 +MKL_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_sequential \ + -lpthread -lm -ldl +#MKLP_LIB := -L$(MKL_LIB_PATH) \ +# -lmkl_intel_thread \ +# -lmkl_core \ +# -lmkl_intel_ilp64 \ +# -L$(ICC_LIB_PATH) \ +# -liomp5 +MKLP_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_gnu_thread \ + -lpthread -lm -ldl -fopenmp + #-L$(ICC_LIB_PATH) \ + #-lgomp + +VENDOR_LIB := $(MKL_LIB) +VENDORP_LIB := $(MKLP_LIB) + + +# +# --- Problem size definitions ------------------------------------------------- +# + +# Single core (single-threaded) +PS_BEGIN := 48 +PS_MAX := 2400 +PS_INC := 48 + +# Single-socket (multithreaded) +P1_BEGIN := 96 +P1_MAX := 4800 +P1_INC := 96 + +# Dual-socket (multithreaded) +P2_BEGIN := 144 +P2_MAX := 7200 +P2_INC := 144 + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . 
+ +# Gather all local object files. +TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c))) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +# Use the "framework" CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS. +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) + + +# Which library? +BLI_DEF := -DBLIS +BLA_DEF := -DBLAS +EIG_DEF := -DEIGEN + +# Complex implementation type +D3MHW := -DIND=BLIS_3MH +D3M1 := -DIND=BLIS_3M1 +D4MHW := -DIND=BLIS_4MH +D4M1B := -DIND=BLIS_4M1B +D4M1A := -DIND=BLIS_4M1A +D1M := -DIND=BLIS_1M +DNAT := -DIND=BLIS_NAT + +# Implementation string +#STR_3MHW := -DSTR=\"3mhw\" +#STR_3M1 := -DSTR=\"3m1\" +#STR_4MHW := -DSTR=\"4mhw\" +#STR_4M1B := -DSTR=\"4m1b\" +STR_4M1A := -DSTR=\"4m1a_blis\" +STR_1M := -DSTR=\"1m_blis\" +STR_NAT := -DSTR=\"asm_blis\" +STR_OBL := -DSTR=\"openblas\" +STR_EIG := -DSTR=\"eigen\" +STR_VEN := -DSTR=\"vendor\" + +# Single or multithreaded string +STR_ST := -DTHR_STR=\"st\" +STR_1S := -DTHR_STR=\"1s\" +STR_2S := -DTHR_STR=\"2s\" + +# Problem size specification +PDEF_ST := -DP_BEGIN=$(PS_BEGIN) -DP_INC=$(PS_INC) -DP_MAX=$(PS_MAX) +PDEF_1S := -DP_BEGIN=$(P1_BEGIN) -DP_INC=$(P1_INC) -DP_MAX=$(P1_MAX) +PDEF_2S := -DP_BEGIN=$(P2_BEGIN) -DP_INC=$(P2_INC) -DP_MAX=$(P2_MAX) + + + +# +# --- Targets/rules ------------------------------------------------------------ +# + +all: all-st all-1s all-2s +blis: blis-st blis-1s blis-2s +openblas: openblas-st openblas-1s openblas-2s +eigen: eigen-st eigen-1s eigen-2s +vendor: vendor-st vendor-1s vendor-2s +mkl: vendor +armpl: vendor + +all-st: blis-st openblas-st mkl-st +all-1s: blis-1s openblas-1s mkl-1s +all-2s: blis-2s openblas-2s mkl-2s + +blis-st: blis-nat-st blis-1m-st blis-4m1a-st +blis-1s: blis-nat-1s blis-1m-1s blis-4m1a-1s +blis-2s: blis-nat-2s blis-1m-2s blis-4m1a-2s + +#blis-ind: blis-ind-st blis-ind-mt +blis-nat: blis-nat-st blis-nat-1s blis-nat-2s +blis-1m: blis-1m-st blis-1m-1s blis-1m-2s +blis-4m1a: blis-4m1a-st blis-4m1a-1s blis-4m1a-2s + +# Define the datatypes, operations, and implementations. +DTS := s d c z +OPS := gemm +BIMPLS := asm_blis 4m1a_blis 1m_blis openblas vendor +EIMPLS := eigen + +# Define functions to construct object filenames from the datatypes and +# operations given an implementation. We define one function for single- +# threaded, single-socket, and dual-socket filenames. +get-st-objs = $(foreach dt,$(DTS),$(foreach op,$(OPS),test_$(dt)$(op)_$(PS_MAX)_$(1)_st.o)) +get-1s-objs = $(foreach dt,$(DTS),$(foreach op,$(OPS),test_$(dt)$(op)_$(P1_MAX)_$(1)_1s.o)) +get-2s-objs = $(foreach dt,$(DTS),$(foreach op,$(OPS),test_$(dt)$(op)_$(P2_MAX)_$(1)_2s.o)) + +# Construct object and binary names for single-threaded, single-socket, and +# dual-socket files for BLIS, OpenBLAS, and a vendor library (e.g. MKL). 
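+# (For example, $(call get-st-objs,1m_blis) expands to one object name per
+# datatype/operation pair, e.g. test_dgemm_2400_1m_blis_st.o for dt=d and
+# op=gemm, since PS_MAX is 2400 and OPS contains only gemm.)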
+BLIS_1M_ST_OBJS := $(call get-st-objs,1m_blis) +BLIS_1M_ST_BINS := $(patsubst %.o,%.x,$(BLIS_1M_ST_OBJS)) +BLIS_1M_1S_OBJS := $(call get-1s-objs,1m_blis) +BLIS_1M_1S_BINS := $(patsubst %.o,%.x,$(BLIS_1M_1S_OBJS)) +BLIS_1M_2S_OBJS := $(call get-2s-objs,1m_blis) +BLIS_1M_2S_BINS := $(patsubst %.o,%.x,$(BLIS_1M_2S_OBJS)) + +BLIS_4M1A_ST_OBJS := $(call get-st-objs,4m1a_blis) +BLIS_4M1A_ST_BINS := $(patsubst %.o,%.x,$(BLIS_4M1A_ST_OBJS)) +BLIS_4M1A_1S_OBJS := $(call get-1s-objs,4m1a_blis) +BLIS_4M1A_1S_BINS := $(patsubst %.o,%.x,$(BLIS_4M1A_1S_OBJS)) +BLIS_4M1A_2S_OBJS := $(call get-2s-objs,4m1a_blis) +BLIS_4M1A_2S_BINS := $(patsubst %.o,%.x,$(BLIS_4M1A_2S_OBJS)) + +BLIS_NAT_ST_OBJS := $(call get-st-objs,asm_blis) +BLIS_NAT_ST_BINS := $(patsubst %.o,%.x,$(BLIS_NAT_ST_OBJS)) +BLIS_NAT_1S_OBJS := $(call get-1s-objs,asm_blis) +BLIS_NAT_1S_BINS := $(patsubst %.o,%.x,$(BLIS_NAT_1S_OBJS)) +BLIS_NAT_2S_OBJS := $(call get-2s-objs,asm_blis) +BLIS_NAT_2S_BINS := $(patsubst %.o,%.x,$(BLIS_NAT_2S_OBJS)) + +OPENBLAS_ST_OBJS := $(call get-st-objs,openblas) +OPENBLAS_ST_BINS := $(patsubst %.o,%.x,$(OPENBLAS_ST_OBJS)) +OPENBLAS_1S_OBJS := $(call get-1s-objs,openblas) +OPENBLAS_1S_BINS := $(patsubst %.o,%.x,$(OPENBLAS_1S_OBJS)) +OPENBLAS_2S_OBJS := $(call get-2s-objs,openblas) +OPENBLAS_2S_BINS := $(patsubst %.o,%.x,$(OPENBLAS_2S_OBJS)) + +EIGEN_ST_OBJS := $(call get-st-objs,eigen) +EIGEN_ST_BINS := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS)) +EIGEN_1S_OBJS := $(call get-1s-objs,eigen) +EIGEN_1S_BINS := $(patsubst %.o,%.x,$(EIGEN_1S_OBJS)) +EIGEN_2S_OBJS := $(call get-2s-objs,eigen) +EIGEN_2S_BINS := $(patsubst %.o,%.x,$(EIGEN_2S_OBJS)) + +VENDOR_ST_OBJS := $(call get-st-objs,vendor) +VENDOR_ST_BINS := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS)) +VENDOR_1S_OBJS := $(call get-1s-objs,vendor) +VENDOR_1S_BINS := $(patsubst %.o,%.x,$(VENDOR_1S_OBJS)) +VENDOR_2S_OBJS := $(call get-2s-objs,vendor) +VENDOR_2S_BINS := $(patsubst %.o,%.x,$(VENDOR_2S_OBJS)) + +# Define some targets associated with the above object/binary files. +blis-nat-st: $(BLIS_NAT_ST_BINS) +blis-nat-1s: $(BLIS_NAT_1S_BINS) +blis-nat-2s: $(BLIS_NAT_2S_BINS) + +blis-1m-st: $(BLIS_1M_ST_BINS) +blis-1m-1s: $(BLIS_1M_1S_BINS) +blis-1m-2s: $(BLIS_1M_2S_BINS) + +blis-4m1a-st: $(BLIS_4M1A_ST_BINS) +blis-4m1a-1s: $(BLIS_4M1A_1S_BINS) +blis-4m1a-2s: $(BLIS_4M1A_2S_BINS) + +openblas-st: $(OPENBLAS_ST_BINS) +openblas-1s: $(OPENBLAS_1S_BINS) +openblas-2s: $(OPENBLAS_2S_BINS) + +eigen-st: $(EIGEN_ST_BINS) +eigen-1s: $(EIGEN_1S_BINS) +eigen-2s: $(EIGEN_2S_BINS) + +vendor-st: $(VENDOR_ST_BINS) +vendor-1s: $(VENDOR_1S_BINS) +vendor-2s: $(VENDOR_2S_BINS) + +mkl-st: vendor-st +mkl-1s: vendor-1s +mkl-2s: vendor-2s + +armpl-st: vendor-st +armpl-1s: vendor-1s +armpl-2s: vendor-2s + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend. +.INTERMEDIATE: $(BLIS_NAT_ST_OBJS) $(BLIS_NAT_1S_OBJS) $(BLIS_NAT_2S_OBJS) +.INTERMEDIATE: $(BLIS_1M_ST_OBJS) $(BLIS_1M_1S_OBJS) $(BLIS_1M_2S_OBJS) +.INTERMEDIATE: $(BLIS_4M1A_ST_OBJS) $(BLIS_4M1A_1S_OBJS) $(BLIS_4M1A_2S_OBJS) +.INTERMEDIATE: $(OPENBLAS_ST_OBJS) $(OPENBLAS_1S_OBJS) $(OPENBLAS_2S_OBJS) +.INTERMEDIATE: $(EIGEN_ST_OBJS) $(EIGEN_1S_OBJS) $(EIGEN_2S_OBJS) +.INTERMEDIATE: $(VENDOR_ST_OBJS) $(VENDOR_1S_OBJS) $(VENDOR_2S_OBJS) + + +# --Object file rules -- + +#$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c +# $(CC) $(CFLAGS) -c $< -o $@ + +# A function to return the datatype cpp macro def from the datatype +# character. 
+get-dt-cpp = $(strip \ + $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT -DIS_FLOAT,\ + $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE -DIS_DOUBLE,\ + $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\ + -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX)))) + +get-in-cpp = $(strip \ + $(if $(findstring 1m_blis,$(1)),-DIND=BLIS_1M,\ + $(if $(findstring 4m1a_blis,$(1)),-DIND=BLIS_4M1A,\ + -DIND=BLIS_NAT))) + +# A function to return other cpp macros that help the test driver +# identify the implementation. +#get-bl-cpp = $(strip \ +# $(if $(findstring blis,$(1)),$(STR_NAT) $(BLI_DEF),\ +# $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\ +# $(if $(findstring eigen,$(1)),$(STR_EIG) $(EIG_DEF),\ +# $(STR_VEN) $(BLA_DEF))))) + +get-bl-cpp = $(strip \ + $(if $(findstring 1m_blis,$(1)),$(STR_1M) $(BLI_DEF),\ + $(if $(findstring 4m1a_blis,$(1)),$(STR_4M1A) $(BLI_DEF),\ + $(if $(findstring asm_blis,$(1)),$(STR_NAT) $(BLI_DEF),\ + $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\ + $(if $(and $(findstring eigen,$(1)),\ + $(findstring gemm,$(2))),\ + $(STR_EIG) $(EIG_DEF),\ + $(if $(findstring eigen,$(1)),\ + $(STR_EIG) $(BLA_DEF),\ + $(STR_VEN) $(BLA_DEF)))))))) + + +# Rules for BLIS and BLAS libraries. +define make-st-rule +test_$(1)$(2)_$(PS_MAX)_$(3)_st.o: test_$(op).c Makefile + $(CC) $(CFLAGS) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(call get-in-cpp,$(3)) $(STR_ST) -c $$< -o $$@ +endef + +define make-1s-rule +test_$(1)$(2)_$(P1_MAX)_$(3)_1s.o: test_$(op).c Makefile + $(CC) $(CFLAGS) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(call get-in-cpp,$(3)) $(STR_1S) -c $$< -o $$@ +endef + +define make-2s-rule +test_$(1)$(2)_$(P2_MAX)_$(3)_2s.o: test_$(op).c Makefile + $(CC) $(CFLAGS) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(call get-in-cpp,$(3)) $(STR_2S) -c $$< -o $$@ +endef + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(BIMPLS),$(eval $(call make-st-rule,$(dt),$(op),$(im)))))) + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(BIMPLS),$(eval $(call make-1s-rule,$(dt),$(op),$(im)))))) + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(BIMPLS),$(eval $(call make-2s-rule,$(dt),$(op),$(im)))))) + +# Rules for Eigen. +define make-eigst-rule +test_$(1)$(2)_$(PS_MAX)_$(3)_st.o: test_$(op).c Makefile + $(CXX) $(CXXFLAGS_ST) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_ST) -c $$< -o $$@ +endef + +define make-eig1s-rule +test_$(1)$(2)_$(P1_MAX)_$(3)_1s.o: test_$(op).c Makefile + $(CXX) $(CXXFLAGS_MT) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_1S) -c $$< -o $$@ +endef + +define make-eig2s-rule +test_$(1)$(2)_$(P2_MAX)_$(3)_2s.o: test_$(op).c Makefile + $(CXX) $(CXXFLAGS_MT) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_2S) -c $$< -o $$@ +endef + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(EIMPLS),$(eval $(call make-eigst-rule,$(dt),$(op),$(im)))))) + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(EIMPLS),$(eval $(call make-eig1s-rule,$(dt),$(op),$(im)))))) + +$(foreach dt,$(DTS), \ +$(foreach op,$(OPS), \ +$(foreach im,$(EIMPLS),$(eval $(call make-eig2s-rule,$(dt),$(op),$(im)))))) + + +# -- Executable file rules -- + +# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS +# on the link command line in case BLIS was configured with the BLAS +# compatibility layer. 
This prevents BLIS from inadvertently getting called
+# for the BLAS routines we are trying to test with.

+test_%_$(PS_MAX)_1m_blis_st.x: test_%_$(PS_MAX)_1m_blis_st.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P1_MAX)_1m_blis_1s.x: test_%_$(P1_MAX)_1m_blis_1s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P2_MAX)_1m_blis_2s.x: test_%_$(P2_MAX)_1m_blis_2s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+
+test_%_$(PS_MAX)_4m1a_blis_st.x: test_%_$(PS_MAX)_4m1a_blis_st.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P1_MAX)_4m1a_blis_1s.x: test_%_$(P1_MAX)_4m1a_blis_1s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P2_MAX)_4m1a_blis_2s.x: test_%_$(P2_MAX)_4m1a_blis_2s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+
+test_%_$(PS_MAX)_asm_blis_st.x: test_%_$(PS_MAX)_asm_blis_st.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P1_MAX)_asm_blis_1s.x: test_%_$(P1_MAX)_asm_blis_1s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P2_MAX)_asm_blis_2s.x: test_%_$(P2_MAX)_asm_blis_2s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+
+test_%_$(PS_MAX)_openblas_st.x: test_%_$(PS_MAX)_openblas_st.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P1_MAX)_openblas_1s.x: test_%_$(P1_MAX)_openblas_1s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P2_MAX)_openblas_2s.x: test_%_$(P2_MAX)_openblas_2s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+
+test_%_$(PS_MAX)_eigen_st.x: test_%_$(PS_MAX)_eigen_st.o $(LIBBLIS_LINK)
+	$(CXX) $(strip $< $(EIGEN_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P1_MAX)_eigen_1s.x: test_%_$(P1_MAX)_eigen_1s.o $(LIBBLIS_LINK)
+	$(CXX) $(strip $< $(EIGENP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P2_MAX)_eigen_2s.x: test_%_$(P2_MAX)_eigen_2s.o $(LIBBLIS_LINK)
+	$(CXX) $(strip $< $(EIGENP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+
+test_%_$(PS_MAX)_vendor_st.x: test_%_$(PS_MAX)_vendor_st.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P1_MAX)_vendor_1s.x: test_%_$(P1_MAX)_vendor_1s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+test_%_$(P2_MAX)_vendor_2s.x: test_%_$(P2_MAX)_vendor_2s.o $(LIBBLIS_LINK)
+	$(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@)
+
+
+# -- Clean rules --
+
+clean: cleanx
+
+cleanx:
+	- $(RM_F) *.o *.x
+
diff --git a/test/1m4m/runme.sh b/test/1m4m/runme.sh
new file mode 100755
index 000000000..d79d53925
--- /dev/null
+++ b/test/1m4m/runme.sh
@@ -0,0 +1,242 @@
+#!/bin/bash
+
+# File prefixes.
+exec_root="test"
+out_root="output"
+delay=0.1
+
+#sys="blis"
+#sys="stampede2"
+sys="lonestar5"
+#sys="ul252"
+#sys="ul264"
+
+# Bind threads to processors.
+#export OMP_PROC_BIND=true
+#export GOMP_CPU_AFFINITY="0 2 4 6 8 10 12 14 16 18 20 22 1 3 5 7 9 11 13 15 17 19 21 23"
+#export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103"
+
+if [ ${sys} = "blis" ]; then
+
+	export GOMP_CPU_AFFINITY="0 1 2 3"
+
+	threads="jc1ic1jr1_2400
+	         jc2ic3jr2_6000
+	         jc4ic3jr2_8000"
+
+elif [ ${sys} = "stampede2" ]; then
+
+	echo "Need to set GOMP_CPU_AFFINITY."
+	exit 1
+
+	threads="jc1ic1jr1_2400
+	         jc4ic6jr1_6000
+	         jc4ic12jr1_8000"
+
+elif [ ${sys} = "lonestar5" ]; then
+
+	export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23"
+
+	# A hack to use libiomp5 with gcc.
+	#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/apps/intel/16.0.1.150/compilers_and_libraries_2016.1.150/linux/compiler/lib/intel64"
+
+	#threads="jc1ic1jr1_2400
+	#         jc2ic3jr2_4800
+	#         jc4ic3jr2_9600"
+	threads="jc1ic1jr1_2400
+	         jc4ic3jr2_7200"
+	threads="jc4ic3jr2_7200"
+
+elif [ ${sys} = "ul252" ]; then
+
+	export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64"
+	export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51"
+
+	threads="jc1ic1jr1_2400
+	         jc2ic13jr1_6000
+	         jc4ic13jr1_8000"
+
+elif [ ${sys} = "ul264" ]; then
+
+	export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64"
+	export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63"
+
+	threads="jc1ic1jr1_2400
+	         jc1ic8jr4_6000
+	         jc2ic8jr4_8000"
+
+fi
+
+# Datatypes to test.
+test_dts="s d c z"
+
+# Operations to test.
+#test_ops="gemm hemm herk trmm trsm"
+test_ops="gemm"
+
+# Implementations to test.
+#impls="blis"
+#impls="other"
+#impls="eigen"
+impls="all"
+
+if [ "${impls}" = "blis" ]; then
+
+	test_impls="asm_blis"
+
+elif [ "${impls}" = "eigen" ]; then
+
+	test_impls="eigen"
+
+elif [ "${impls}" = "other" ]; then
+
+	test_impls="openblas vendor"
+
+else
+
+	test_impls="openblas vendor asm_blis 4m1a_blis 1m_blis"
+	#test_impls="openblas"
+	#test_impls="asm_blis 4m1a_blis 1m_blis"
+	#test_impls="asm_blis 1m_blis"
+fi
+
+# Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can
+# restore the value.
+GOMP_CPU_AFFINITYsave=${GOMP_CPU_AFFINITY}
+
+
+# First perform real test cases.
+for th in ${threads}; do
+
+	# Start with one way of parallelism in each loop. We will now begin
+	# parsing the 'th' variable to update one or more of these threading
+	# parameters.
+	jc_nt=1; pc_nt=1; ic_nt=1; jr_nt=1; ir_nt=1
+
+	# Strip everything before and after the underscore so that what remains
+	# is the problem size and threading parameter string, respectively.
+	psize=${th##*_}; thinfo=${th%%_*}
+
+	# Identify each threading parameter and insert a space before it.
+	thsep=$(echo -e ${thinfo} | sed -e "s/\([jip][cr]\)/ \1/g" )
+
+	nt=1
+
+	for loopnum in ${thsep}; do
+
+		# Given the current string, which identifies a loop and the
+		# number of ways of parallelism for that loop, strip out
+		# the ways and loop separately to identify each.
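+		# (For example, with th=jc2ic3jr2_6000: psize is 6000, thsep is
+		# "jc2 ic3 jr2", and this loop sets jc_nt=2, ic_nt=3, and jr_nt=2,
+		# giving nt=12.)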
+ loop=$(echo -e ${loopnum} | sed -e "s/[0-9]//g" ) + num=$(echo -e ${loopnum} | sed -e "s/[a-z]//g" ) + + # Construct a string that we can evaluate to set the number + # of ways of parallelism for the current loop. + loop_nt_eq_num="${loop}_nt=${num}" + + # Update the total number of threads. + nt=$(expr ${nt} \* ${num}) + + # Evaluate the string to assign the ways to the variable. + eval ${loop_nt_eq_num} + + done + + echo "Switching to: jc${jc_nt} pc${pc_nt} ic${ic_nt} jr${jr_nt} ir${ir_nt} (nt = ${nt}) p_max${psize}" + + + for dt in ${test_dts}; do + + for im in ${test_impls}; do + + if [ "${dt}" = "s" -o "${dt}" = "d" ] && \ + [ "${im}" = "1m_blis" -o "${im}" = "4m1a_blis" ]; then + continue + fi + + for op in ${test_ops}; do + + # Eigen does not support multithreading for hemm, herk, trmm, + # or trsm. So if we're getting ready to execute an Eigen driver + # for one of these operations and nt > 1, we skip this test. + if [ "${im}" = "eigen" ] && \ + [ "${op}" != "gemm" ] && \ + [ "${nt}" != "1" ]; then + continue; + fi + + # Find the threading suffix by probing the executable. + binname=$(ls ${exec_root}_${dt}${op}_${psize}_${im}_*.x) + suf_ext=${binname##*_} + suf=${suf_ext%%.*} + + #echo "found file: ${binname} with suffix ${suf}" + + # Set the number of threads according to th. + if [ "${suf}" = "1s" ] || [ "${suf}" = "2s" ]; then + + # Set the threading parameters based on the implementation + # that we are preparing to run. + if [ "${im}" = "asm_blis" ]; then + unset OMP_NUM_THREADS + export BLIS_JC_NT=${jc_nt} + export BLIS_PC_NT=${pc_nt} + export BLIS_IC_NT=${ic_nt} + export BLIS_JR_NT=${jr_nt} + export BLIS_IR_NT=${ir_nt} + elif [ "${im}" = "openblas" ]; then + unset OMP_NUM_THREADS + export OPENBLAS_NUM_THREADS=${nt} + elif [ "${im}" = "eigen" ]; then + export OMP_NUM_THREADS=${nt} + elif [ "${im}" = "vendor" ]; then + unset OMP_NUM_THREADS + export MKL_NUM_THREADS=${nt} + fi + export nt_use=${nt} + + # Multithreaded OpenBLAS seems to have a problem running + # properly if GOMP_CPU_AFFINITY is set. So we temporarily + # unset it here if we are about to execute OpenBLAS, but + # otherwise restore it. + if [ ${im} = "openblas" ]; then + unset GOMP_CPU_AFFINITY + else + export GOMP_CPU_AFFINITY="${GOMP_CPU_AFFINITYsave}" + fi + else + + export BLIS_JC_NT=1 + export BLIS_PC_NT=1 + export BLIS_IC_NT=1 + export BLIS_JR_NT=1 + export BLIS_IR_NT=1 + export OMP_NUM_THREADS=1 + export OPENBLAS_NUM_THREADS=1 + export MKL_NUM_THREADS=1 + export nt_use=1 + fi + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${psize}_${im}_${suf}.x" + + # Construct the name of the output file. + out_file="${out_root}_${suf}_${dt}${op}_${im}.m" + + #echo "Running (nt = ${nt_use}) ./${exec_name} > ${out_file}" + echo "Running ./${exec_name} > ${out_file}" + + # Run executable. + ./${exec_name} > ${out_file} + + sleep ${delay} + + done + done + done +done + diff --git a/test/1m4m/test_gemm.c b/test/1m4m/test_gemm.c new file mode 100644 index 000000000..741503f5c --- /dev/null +++ b/test/1m4m/test_gemm.c @@ -0,0 +1,425 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+   - Redistributions in binary form must reproduce the above copyright
+     notice, this list of conditions and the following disclaimer in the
+     documentation and/or other materials provided with the distribution.
+   - Neither the name of The University of Texas nor the names of its
+     contributors may be used to endorse or promote products derived
+     from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <unistd.h>
+#ifdef EIGEN
+  #define BLIS_DISABLE_BLAS_DEFS
+  #include "blis.h"
+  #include <Eigen/Core>
+  #include <Eigen/src/misc/blas.h>
+  using namespace Eigen;
+#else
+  #include "blis.h"
+#endif
+
+#define COL_STORAGE
+//#define ROW_STORAGE
+
+//#define PRINT
+
+int main( int argc, char** argv )
+{
+	obj_t a, b, c;
+	obj_t c_save;
+	obj_t alpha, beta;
+	dim_t m, n, k;
+	dim_t p;
+	dim_t p_begin, p_max, p_inc;
+	int   m_input, n_input, k_input;
+	ind_t ind;
+	num_t dt;
+	char  dt_ch;
+	int   r, n_repeats;
+	trans_t transa;
+	trans_t transb;
+	f77_char f77_transa;
+	f77_char f77_transb;
+
+	double dtime;
+	double dtime_save;
+	double gflops;
+
+	//bli_init();
+
+	bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
+
+	n_repeats = 3;
+
+	dt = DT;
+
+	ind = IND;
+
+#if 1
+	p_begin = P_BEGIN;
+	p_max   = P_MAX;
+	p_inc   = P_INC;
+
+	m_input = -1;
+	n_input = -1;
+	k_input = -1;
+#else
+	p_begin = 40;
+	p_max   = 2000;
+	p_inc   = 40;
+
+	m_input = -1;
+	n_input = -1;
+	k_input = -1;
+#endif
+
+
+	// Suppress compiler warnings about unused variable 'ind'.
+	( void )ind;
+
+#if 0
+
+	cntx_t* cntx;
+
+	ind_t ind_mod = ind;
+
+	// A hack to use 3m1 as 1mpb (with 1m as 1mbp).
+	if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M;
+
+	// Initialize a context for the current induced method and datatype.
+	cntx = bli_gks_query_ind_cntx( ind_mod, dt );
+
+	// Set k to the kc blocksize for the current datatype.
+	k_input = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx );
+
+#elif 0
+
+	#ifdef BLIS
+	if      ( ind == BLIS_4M1A ) k_input = 128;
+	else if ( ind == BLIS_1M   ) k_input = 128;
+	else                         k_input = 256;
+	#else
+	k_input = 192;
+	#endif
+
+#endif
+
+	// Choose the char corresponding to the requested datatype.
+	if      ( bli_is_float( dt ) )    dt_ch = 's';
+	else if ( bli_is_double( dt ) )   dt_ch = 'd';
+	else if ( bli_is_scomplex( dt ) ) dt_ch = 'c';
+	else                              dt_ch = 'z';
+
+	transa = BLIS_NO_TRANSPOSE;
+	transb = BLIS_NO_TRANSPOSE;
+
+	bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
+	bli_param_map_blis_to_netlib_trans( transb, &f77_transb );
+
+	// Begin with initializing the last entry to zero so that
+	// matlab allocates space for the entire array once up-front.
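+	// (With the single-threaded defaults from the Makefile, P_BEGIN=48,
+	// P_INC=48, P_MAX=2400, the loop below leaves p = 2400, so this first
+	// printf writes row ( 2400 - 48 + 1 )/48 + 1 = 50, the last row of
+	// the output array.)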
+	for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ;
+
+	printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR );
+	printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n",
+	        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
+	        ( unsigned long )0,
+	        ( unsigned long )0,
+	        ( unsigned long )0, 0.0 );
+
+
+	//for ( p = p_begin; p <= p_max; p += p_inc )
+	for ( p = p_max; p_begin <= p; p -= p_inc )
+	{
+
+		if ( m_input < 0 ) m = p / ( dim_t )abs(m_input);
+		else               m = ( dim_t )    m_input;
+		if ( n_input < 0 ) n = p / ( dim_t )abs(n_input);
+		else               n = ( dim_t )    n_input;
+		if ( k_input < 0 ) k = p / ( dim_t )abs(k_input);
+		else               k = ( dim_t )    k_input;
+
+		bli_obj_create( dt, 1, 1, 0, 0, &alpha );
+		bli_obj_create( dt, 1, 1, 0, 0, &beta );
+
+		#ifdef COL_STORAGE
+		bli_obj_create( dt, m, k, 0, 0, &a );
+		bli_obj_create( dt, k, n, 0, 0, &b );
+		bli_obj_create( dt, m, n, 0, 0, &c );
+		bli_obj_create( dt, m, n, 0, 0, &c_save );
+		#else
+		bli_obj_create( dt, m, k, k, 1, &a );
+		bli_obj_create( dt, k, n, n, 1, &b );
+		bli_obj_create( dt, m, n, n, 1, &c );
+		bli_obj_create( dt, m, n, n, 1, &c_save );
+		#endif
+
+		bli_randm( &a );
+		bli_randm( &b );
+		bli_randm( &c );
+
+		bli_obj_set_conjtrans( transa, &a );
+		bli_obj_set_conjtrans( transb, &b );
+
+		bli_setsc( (1.0/1.0), 0.0, &alpha );
+		bli_setsc( (1.0/1.0), 0.0, &beta );
+
+		bli_copym( &c, &c_save );
+
+#ifdef BLIS
+		bli_ind_disable_all_dt( dt );
+		bli_ind_enable_dt( ind, dt );
+#endif
+
+#ifdef EIGEN
+		double alpha_r, alpha_i;
+
+		bli_getsc( &alpha, &alpha_r, &alpha_i );
+
+		void* ap = bli_obj_buffer_at_off( &a );
+		void* bp = bli_obj_buffer_at_off( &b );
+		void* cp = bli_obj_buffer_at_off( &c );
+
+		#ifdef COL_STORAGE
+		const int os_a = bli_obj_col_stride( &a );
+		const int os_b = bli_obj_col_stride( &b );
+		const int os_c = bli_obj_col_stride( &c );
+		#else
+		const int os_a = bli_obj_row_stride( &a );
+		const int os_b = bli_obj_row_stride( &b );
+		const int os_c = bli_obj_row_stride( &c );
+		#endif
+
+		Stride<Dynamic,1> stride_a( os_a, 1 );
+		Stride<Dynamic,1> stride_b( os_b, 1 );
+		Stride<Dynamic,1> stride_c( os_c, 1 );
+
+		#ifdef COL_STORAGE
+		#if defined(IS_FLOAT)
+		typedef Matrix<float, Dynamic, Dynamic, ColMajor> MatrixXf_;
+		#elif defined (IS_DOUBLE)
+		typedef Matrix<double, Dynamic, Dynamic, ColMajor> MatrixXd_;
+		#elif defined (IS_SCOMPLEX)
+		typedef Matrix<std::complex<float>, Dynamic, Dynamic, ColMajor> MatrixXcf_;
+		#elif defined (IS_DCOMPLEX)
+		typedef Matrix<std::complex<double>, Dynamic, Dynamic, ColMajor> MatrixXcd_;
+		#endif
+		#else
+		#if defined(IS_FLOAT)
+		typedef Matrix<float, Dynamic, Dynamic, RowMajor> MatrixXf_;
+		#elif defined (IS_DOUBLE)
+		typedef Matrix<double, Dynamic, Dynamic, RowMajor> MatrixXd_;
+		#elif defined (IS_SCOMPLEX)
+		typedef Matrix<std::complex<float>, Dynamic, Dynamic, RowMajor> MatrixXcf_;
+		#elif defined (IS_DCOMPLEX)
+		typedef Matrix<std::complex<double>, Dynamic, Dynamic, RowMajor> MatrixXcd_;
+		#endif
+		#endif
+		#if defined(IS_FLOAT)
+		Map<MatrixXf_, 0, Stride<Dynamic,1> > A( ( float* )ap, m, k, stride_a );
+		Map<MatrixXf_, 0, Stride<Dynamic,1> > B( ( float* )bp, k, n, stride_b );
+		Map<MatrixXf_, 0, Stride<Dynamic,1> > C( ( float* )cp, m, n, stride_c );
+		#elif defined (IS_DOUBLE)
+		Map<MatrixXd_, 0, Stride<Dynamic,1> > A( ( double* )ap, m, k, stride_a );
+		Map<MatrixXd_, 0, Stride<Dynamic,1> > B( ( double* )bp, k, n, stride_b );
+		Map<MatrixXd_, 0, Stride<Dynamic,1> > C( ( double* )cp, m, n, stride_c );
+		#elif defined (IS_SCOMPLEX)
+		Map<MatrixXcf_, 0, Stride<Dynamic,1> > A( ( std::complex<float>* )ap, m, k, stride_a );
+		Map<MatrixXcf_, 0, Stride<Dynamic,1> > B( ( std::complex<float>* )bp, k, n, stride_b );
+		Map<MatrixXcf_, 0, Stride<Dynamic,1> > C( ( std::complex<float>* )cp, m, n, stride_c );
+		#elif defined (IS_DCOMPLEX)
+		Map<MatrixXcd_, 0, Stride<Dynamic,1> > A( ( std::complex<double>* )ap, m, k, stride_a );
+		Map<MatrixXcd_, 0, Stride<Dynamic,1> > B( ( std::complex<double>* )bp, k, n, stride_b );
+		Map<MatrixXcd_, 0, Stride<Dynamic,1> > C( ( std::complex<double>* )cp, m, n, stride_c );
+		#endif
+#endif
+
+		dtime_save = DBL_MAX;
+
+		for ( r = 0; r < n_repeats; ++r )
+		{
+			bli_copym( &c_save, &c );
+
+			dtime = bli_clock();
+
+#ifdef PRINT
+			bli_printm( "a", &a, "%4.1f", "" );
bli_printm( "b", &b, "%4.1f", "" ); + bli_printm( "c", &c, "%4.1f", "" ); +#endif + +#if defined(BLIS) + + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + +#elif defined(EIGEN) + + C.noalias() += alpha_r * A * B; + +#else // if defined(BLAS) + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + sgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + dgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + cgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + zgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.1f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, gflops ); + //fflush( stdout ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + 
bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/3/Makefile b/test/3/Makefile new file mode 100644 index 000000000..1a8aa3087 --- /dev/null +++ b/test/3/Makefile @@ -0,0 +1,462 @@ +#!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + clean cleanx + + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + + +# +# --- BLAS implementations ----------------------------------------------------- +# + +# BLAS library path(s). This is where the BLAS libraries reside. 
+HOME_LIB_PATH := $(HOME)/flame/lib + +# OpenBLAS +OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a +OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a + +# ATLAS +#ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ +# $(HOME_LIB_PATH)/libatlas.a + +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + +# MKL +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 +MKL_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_sequential \ + -lpthread -lm -ldl +#MKLP_LIB := -L$(MKL_LIB_PATH) \ +# -lmkl_intel_thread \ +# -lmkl_core \ +# -lmkl_intel_ilp64 \ +# -L$(ICC_LIB_PATH) \ +# -liomp5 +MKLP_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_gnu_thread \ + -lpthread -lm -ldl -fopenmp + #-L$(ICC_LIB_PATH) \ + #-lgomp + +VENDOR_LIB := $(MKL_LIB) +VENDORP_LIB := $(MKLP_LIB) + + +# +# --- Problem size definitions ------------------------------------------------- +# + +# Single core (single-threaded) +PS_BEGIN := 48 +PS_MAX := 2400 +PS_INC := 48 + +# Single-socket (multithreaded) +P1_BEGIN := 96 +P1_MAX := 4800 +P1_INC := 96 + +# Dual-socket (multithreaded) +P2_BEGIN := 144 +P2_MAX := 7200 +P2_INC := 144 + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c))) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +# Use the "framework" CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS. +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) + + +# Which library? 
+BLI_DEF := -DBLIS
+BLA_DEF := -DBLAS
+EIG_DEF := -DEIGEN
+
+# Complex implementation type
+D3MHW := -DIND=BLIS_3MH
+D3M1  := -DIND=BLIS_3M1
+D4MHW := -DIND=BLIS_4MH
+D4M1B := -DIND=BLIS_4M1B
+D4M1A := -DIND=BLIS_4M1A
+D1M   := -DIND=BLIS_1M
+DNAT  := -DIND=BLIS_NAT
+
+# Implementation string
+#STR_3MHW := -DSTR=\"3mhw\"
+#STR_3M1  := -DSTR=\"3m1\"
+#STR_4MHW := -DSTR=\"4mhw\"
+#STR_4M1B := -DSTR=\"4m1b\"
+#STR_4M1A := -DSTR=\"4m1a\"
+#STR_1M   := -DSTR=\"1m\"
+STR_NAT  := -DSTR=\"asm_blis\"
+STR_OBL  := -DSTR=\"openblas\"
+STR_EIG  := -DSTR=\"eigen\"
+STR_VEN  := -DSTR=\"vendor\"
+
+# Single or multithreaded string
+STR_ST := -DTHR_STR=\"st\"
+STR_1S := -DTHR_STR=\"1s\"
+STR_2S := -DTHR_STR=\"2s\"
+
+# Problem size specification
+PDEF_ST := -DP_BEGIN=$(PS_BEGIN) -DP_INC=$(PS_INC) -DP_MAX=$(PS_MAX)
+PDEF_1S := -DP_BEGIN=$(P1_BEGIN) -DP_INC=$(P1_INC) -DP_MAX=$(P1_MAX)
+PDEF_2S := -DP_BEGIN=$(P2_BEGIN) -DP_INC=$(P2_INC) -DP_MAX=$(P2_MAX)
+
+
+
+#
+# --- Targets/rules ------------------------------------------------------------
+#
+
+all:      all-st all-1s all-2s
+blis:     blis-st blis-1s blis-2s
+openblas: openblas-st openblas-1s openblas-2s
+eigen:    eigen-st eigen-1s eigen-2s
+vendor:   vendor-st vendor-1s vendor-2s
+mkl:      vendor
+armpl:    vendor
+
+all-st: blis-st openblas-st mkl-st
+all-1s: blis-1s openblas-1s mkl-1s
+all-2s: blis-2s openblas-2s mkl-2s
+
+blis-st: blis-nat-st
+blis-1s: blis-nat-1s
+blis-2s: blis-nat-2s
+
+#blis-ind: blis-ind-st blis-ind-mt
+blis-nat: blis-nat-st blis-nat-1s blis-nat-2s
+
+# Define the datatypes, operations, and implementations.
+DTS    := s d c z
+OPS    := gemm hemm herk trmm trsm
+BIMPLS := asm_blis openblas vendor
+EIMPLS := eigen
+
+# Define functions to construct object filenames from the datatypes and
+# operations given an implementation. We define one function each for the
+# single-threaded, single-socket, and dual-socket filenames.
+get-st-objs = $(foreach dt,$(DTS),$(foreach op,$(OPS),test_$(dt)$(op)_$(PS_MAX)_$(1)_st.o))
+get-1s-objs = $(foreach dt,$(DTS),$(foreach op,$(OPS),test_$(dt)$(op)_$(P1_MAX)_$(1)_1s.o))
+get-2s-objs = $(foreach dt,$(DTS),$(foreach op,$(OPS),test_$(dt)$(op)_$(P2_MAX)_$(1)_2s.o))
+
+# Construct object and binary names for single-threaded, single-socket, and
+# dual-socket files for BLIS, OpenBLAS, and a vendor library (e.g. MKL).
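+# For example, $(call get-st-objs,asm_blis) expands to the twenty names
+# test_sgemm_2400_asm_blis_st.o, test_shemm_2400_asm_blis_st.o, ...,
+# test_ztrsm_2400_asm_blis_st.o -- one per datatype/operation pair, with
+# the single-threaded problem size maximum (PS_MAX) embedded in each name.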
+BLIS_NAT_ST_OBJS  := $(call get-st-objs,asm_blis)
+BLIS_NAT_ST_BINS  := $(patsubst %.o,%.x,$(BLIS_NAT_ST_OBJS))
+BLIS_NAT_1S_OBJS  := $(call get-1s-objs,asm_blis)
+BLIS_NAT_1S_BINS  := $(patsubst %.o,%.x,$(BLIS_NAT_1S_OBJS))
+BLIS_NAT_2S_OBJS  := $(call get-2s-objs,asm_blis)
+BLIS_NAT_2S_BINS  := $(patsubst %.o,%.x,$(BLIS_NAT_2S_OBJS))
+
+OPENBLAS_ST_OBJS  := $(call get-st-objs,openblas)
+OPENBLAS_ST_BINS  := $(patsubst %.o,%.x,$(OPENBLAS_ST_OBJS))
+OPENBLAS_1S_OBJS  := $(call get-1s-objs,openblas)
+OPENBLAS_1S_BINS  := $(patsubst %.o,%.x,$(OPENBLAS_1S_OBJS))
+OPENBLAS_2S_OBJS  := $(call get-2s-objs,openblas)
+OPENBLAS_2S_BINS  := $(patsubst %.o,%.x,$(OPENBLAS_2S_OBJS))
+
+EIGEN_ST_OBJS     := $(call get-st-objs,eigen)
+EIGEN_ST_BINS     := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS))
+EIGEN_1S_OBJS     := $(call get-1s-objs,eigen)
+EIGEN_1S_BINS     := $(patsubst %.o,%.x,$(EIGEN_1S_OBJS))
+EIGEN_2S_OBJS     := $(call get-2s-objs,eigen)
+EIGEN_2S_BINS     := $(patsubst %.o,%.x,$(EIGEN_2S_OBJS))
+
+VENDOR_ST_OBJS    := $(call get-st-objs,vendor)
+VENDOR_ST_BINS    := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS))
+VENDOR_1S_OBJS    := $(call get-1s-objs,vendor)
+VENDOR_1S_BINS    := $(patsubst %.o,%.x,$(VENDOR_1S_OBJS))
+VENDOR_2S_OBJS    := $(call get-2s-objs,vendor)
+VENDOR_2S_BINS    := $(patsubst %.o,%.x,$(VENDOR_2S_OBJS))
+
+# Define some targets associated with the above object/binary files.
+blis-nat-st: $(BLIS_NAT_ST_BINS)
+blis-nat-1s: $(BLIS_NAT_1S_BINS)
+blis-nat-2s: $(BLIS_NAT_2S_BINS)
+
+openblas-st: $(OPENBLAS_ST_BINS)
+openblas-1s: $(OPENBLAS_1S_BINS)
+openblas-2s: $(OPENBLAS_2S_BINS)
+
+eigen-st: $(EIGEN_ST_BINS)
+eigen-1s: $(EIGEN_1S_BINS)
+eigen-2s: $(EIGEN_2S_BINS)
+
+vendor-st: $(VENDOR_ST_BINS)
+vendor-1s: $(VENDOR_1S_BINS)
+vendor-2s: $(VENDOR_2S_BINS)
+
+mkl-st: vendor-st
+mkl-1s: vendor-1s
+mkl-2s: vendor-2s
+
+armpl-st: vendor-st
+armpl-1s: vendor-1s
+armpl-2s: vendor-2s
+
+# Mark the object files as intermediate so that make will remove them
+# automatically after building the binaries that depend on them.
+.INTERMEDIATE: $(BLIS_NAT_ST_OBJS) $(BLIS_NAT_1S_OBJS) $(BLIS_NAT_2S_OBJS)
+.INTERMEDIATE: $(OPENBLAS_ST_OBJS) $(OPENBLAS_1S_OBJS) $(OPENBLAS_2S_OBJS)
+.INTERMEDIATE: $(EIGEN_ST_OBJS) $(EIGEN_1S_OBJS) $(EIGEN_2S_OBJS)
+.INTERMEDIATE: $(VENDOR_ST_OBJS) $(VENDOR_1S_OBJS) $(VENDOR_2S_OBJS)
+
+
+# -- Object file rules --
+
+#$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c
+#	$(CC) $(CFLAGS) -c $< -o $@
+
+# A function to return the datatype cpp macro def from the datatype
+# character.
+get-dt-cpp = $(strip \
+             $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT    -DIS_FLOAT,\
+             $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE   -DIS_DOUBLE,\
+             $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\
+                                       -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX))))
+
+# A function to return other cpp macros that help the test driver
+# identify the implementation.
+#get-bl-cpp = $(strip \
+#             $(if $(findstring blis,$(1)),$(STR_NAT) $(BLI_DEF),\
+#             $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\
+#             $(if $(findstring eigen,$(1)),$(STR_EIG) $(EIG_DEF),\
+#             $(STR_VEN) $(BLA_DEF)))))
+
+get-bl-cpp = $(strip \
+             $(if $(findstring blis,$(1)),$(STR_NAT) $(BLI_DEF),\
+             $(if $(findstring openblas,$(1)),$(STR_OBL) $(BLA_DEF),\
+             $(if $(and $(findstring eigen,$(1)),\
+                        $(findstring gemm,$(2))),\
+                  $(STR_EIG) $(EIG_DEF),\
+             $(if $(findstring eigen,$(1)),\
+                  $(STR_EIG) $(BLA_DEF),\
+                  $(STR_VEN) $(BLA_DEF))))))
+
+
+# Rules for BLIS and BLAS libraries.
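+# As a sketch, for (datatype,operation,implementation) = (d,gemm,asm_blis),
+# the make-st-rule template below evaluates to a rule of roughly this shape:
+#
+#   test_dgemm_2400_asm_blis_st.o: test_gemm.c Makefile
+#           $(CC) $(CFLAGS) -DP_BEGIN=48 -DP_INC=48 -DP_MAX=2400 \
+#                 -DDT=BLIS_DOUBLE -DIS_DOUBLE -DSTR=\"asm_blis\" -DBLIS \
+#                 -DIND=BLIS_NAT -DTHR_STR=\"st\" -c test_gemm.c -o $@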
+define make-st-rule
+test_$(1)$(2)_$(PS_MAX)_$(3)_st.o: test_$(2).c Makefile
+	$(CC) $(CFLAGS) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_ST) -c $$< -o $$@
+endef
+
+define make-1s-rule
+test_$(1)$(2)_$(P1_MAX)_$(3)_1s.o: test_$(2).c Makefile
+	$(CC) $(CFLAGS) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_1S) -c $$< -o $$@
+endef
+
+define make-2s-rule
+test_$(1)$(2)_$(P2_MAX)_$(3)_2s.o: test_$(2).c Makefile
+	$(CC) $(CFLAGS) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_2S) -c $$< -o $$@
+endef
+
+$(foreach dt,$(DTS), \
+$(foreach op,$(OPS), \
+$(foreach im,$(BIMPLS),$(eval $(call make-st-rule,$(dt),$(op),$(im))))))
+
+$(foreach dt,$(DTS), \
+$(foreach op,$(OPS), \
+$(foreach im,$(BIMPLS),$(eval $(call make-1s-rule,$(dt),$(op),$(im))))))
+
+$(foreach dt,$(DTS), \
+$(foreach op,$(OPS), \
+$(foreach im,$(BIMPLS),$(eval $(call make-2s-rule,$(dt),$(op),$(im))))))
+
+# Rules for Eigen.
+define make-eigst-rule
+test_$(1)$(2)_$(PS_MAX)_$(3)_st.o: test_$(2).c Makefile
+	$(CXX) $(CXXFLAGS_ST) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_ST) -c $$< -o $$@
+endef
+
+define make-eig1s-rule
+test_$(1)$(2)_$(P1_MAX)_$(3)_1s.o: test_$(2).c Makefile
+	$(CXX) $(CXXFLAGS_MT) $(PDEF_1S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_1S) -c $$< -o $$@
+endef
+
+define make-eig2s-rule
+test_$(1)$(2)_$(P2_MAX)_$(3)_2s.o: test_$(2).c Makefile
+	$(CXX) $(CXXFLAGS_MT) $(PDEF_2S) $(call get-dt-cpp,$(1)) $(call get-bl-cpp,$(3),$(2)) $(DNAT) $(STR_2S) -c $$< -o $$@
+endef
+
+$(foreach dt,$(DTS), \
+$(foreach op,$(OPS), \
+$(foreach im,$(EIMPLS),$(eval $(call make-eigst-rule,$(dt),$(op),$(im))))))
+
+$(foreach dt,$(DTS), \
+$(foreach op,$(OPS), \
+$(foreach im,$(EIMPLS),$(eval $(call make-eig1s-rule,$(dt),$(op),$(im))))))
+
+$(foreach dt,$(DTS), \
+$(foreach op,$(OPS), \
+$(foreach im,$(EIMPLS),$(eval $(call make-eig2s-rule,$(dt),$(op),$(im))))))
+
+
+# -- Executable file rules --
+
+# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS
+# on the link command line in case BLIS was configured with the BLAS
+# compatibility layer. This prevents BLIS from inadvertently getting called
+# for the BLAS routines we are trying to test.
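+# As a sketch, the single-threaded OpenBLAS rule below therefore produces a
+# link command of the form (library paths abbreviated):
+#
+#   $(CC) test_dgemm_2400_openblas_st.o libopenblas.a libblis.a $(LDFLAGS) \
+#         -o test_dgemm_2400_openblas_st.x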
+ +test_%_$(PS_MAX)_asm_blis_st.x: test_%_$(PS_MAX)_asm_blis_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P1_MAX)_asm_blis_1s.x: test_%_$(P1_MAX)_asm_blis_1s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P2_MAX)_asm_blis_2s.x: test_%_$(P2_MAX)_asm_blis_2s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +test_%_$(PS_MAX)_openblas_st.x: test_%_$(PS_MAX)_openblas_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P1_MAX)_openblas_1s.x: test_%_$(P1_MAX)_openblas_1s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P2_MAX)_openblas_2s.x: test_%_$(P2_MAX)_openblas_2s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +test_%_$(PS_MAX)_eigen_st.x: test_%_$(PS_MAX)_eigen_st.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(EIGEN_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P1_MAX)_eigen_1s.x: test_%_$(P1_MAX)_eigen_1s.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(EIGENP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P2_MAX)_eigen_2s.x: test_%_$(P2_MAX)_eigen_2s.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(EIGENP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +test_%_$(PS_MAX)_vendor_st.x: test_%_$(PS_MAX)_vendor_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P1_MAX)_vendor_1s.x: test_%_$(P1_MAX)_vendor_1s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_$(P2_MAX)_vendor_2s.x: test_%_$(P2_MAX)_vendor_2s.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDORP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.o *.x + diff --git a/test/3m4m/matlab/gen_opnames.m b/test/3/matlab/gen_opnames.m similarity index 100% rename from test/3m4m/matlab/gen_opnames.m rename to test/3/matlab/gen_opnames.m diff --git a/test/3m4m/matlab/plot_l3_perf.m b/test/3/matlab/plot_l3_perf.m similarity index 50% rename from test/3m4m/matlab/plot_l3_perf.m rename to test/3/matlab/plot_l3_perf.m index 9700faf01..c2ecf4d27 100644 --- a/test/3m4m/matlab/plot_l3_perf.m +++ b/test/3/matlab/plot_l3_perf.m @@ -1,9 +1,11 @@ function r_val = plot_l3_perf( opname, ... data_blis, ... data_open, ... - data_mkl, ... + data_eige, ... + data_vend, vend_str, ... nth, ... rows, cols, ... + with_eigen, ... cfreq, ... dfps, ... theid ) @@ -16,7 +18,8 @@ end % Set line properties. color_blis = 'k'; lines_blis = '-'; markr_blis = ''; color_open = 'r'; lines_open = '--'; markr_open = 'o'; -color_mkl = 'b'; lines_mkl = '--'; markr_mkl = '.'; +color_eige = 'm'; lines_eige = '-.'; markr_eige = 'x'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; % Compute the peak performance in terms of the number of double flops % executable per cycle and the clock rate. @@ -30,8 +33,13 @@ max_perf_core = (flopspercycle * cfreq) * 1; % Adjust title for real domain hemm and herk. title_opname = opname; if opname(1) == 's' || opname(1) == 'd' - if strcmp( extractAfter( opname, 1 ), 'hemm' ) || ... - strcmp( extractAfter( opname, 1 ), 'herk' ) +% if strcmp( extractAfter( opname, 1 ), 'hemm' ) || ... +% strcmp( extractAfter( opname, 1 ), 'herk' ) +% title_opname(2:3) = 'sy'; +% end + opname_u = opname; opname_u(1) = '_'; + if strcmp( opname_u, '_hemm' ) || ... 
+ strcmp( opname_u, '_herk' ) title_opname(2:3) = 'sy'; end end @@ -43,7 +51,10 @@ titlename = sprintf( titlename, title_opname ); % Set the legend strings. blis_legend = sprintf( 'BLIS' ); open_legend = sprintf( 'OpenBLAS' ); -mkl_legend = sprintf( 'MKL' ); +eige_legend = sprintf( 'Eigen' ); +%vend_legend = sprintf( 'MKL' ); +%vend_legend = sprintf( 'ARMPL' ); +vend_legend = vend_str; % Determine the final dimension. %n_points = size( data_blis, 1 ); @@ -89,59 +100,76 @@ blis_ln = line( x_axis( :, 1 ), data_blis( :, flopscol ) / nth, ... open_ln = line( x_axis( :, 1 ), data_open( :, flopscol ) / nth, ... 'Color',color_open, 'LineStyle',lines_open, ... 'LineWidth',linesize ); -mkl_ln = line( x_axis( :, 1 ), data_mkl( :, flopscol ) / nth, ... - 'Color',color_mkl, 'LineStyle',lines_mkl, ... +if data_eige(1,1) ~= -1 +eige_ln = line( x_axis( :, 1 ), data_eige( :, flopscol ) / nth, ... + 'Color',color_eige, 'LineStyle',lines_eige, ... + 'LineWidth',linesize ); +else +eige_ln = line( nan, nan, ... + 'Color',color_eige, 'LineStyle',lines_eige, ... + 'LineWidth',linesize ); +end +vend_ln = line( x_axis( :, 1 ), data_vend( :, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... 'LineWidth',linesize ); xlim( ax1, [x_begin x_end] ); ylim( ax1, [y_begin y_end] ); -if x_end == 10000 || x_end == 8000 +if 6000 <= x_end && x_end < 10000 x_tick2 = x_end - 2000; x_tick1 = x_tick2/2; xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); end -if rows == 4 && cols == 5 && ... - theid == 5 - if nth == 1 - leg = legend( ... - [ ... - blis_ln ... - open_ln ... - mkl_ln ... - ], ... - blis_legend, ... - open_legend, ... - mkl_legend, ... - 'Location', legend_loc ); - set( leg,'Box','off' ); - set( leg,'Color','none' ); - set( leg,'FontSize',fontsize-3 ); - set( leg,'Units','inches' ); - %set( leg,'Position',[3.15 10.2 0.7 0.3 ] ); % 1600 1200 - set( leg,'Position',[4.20 12.7 0.7 0.3 ] ); % 2000 1500 - else - leg = legend( ... - [ ... - blis_ln ... - open_ln ... - mkl_ln ... - ], ... - blis_legend, ... - open_legend, ... - mkl_legend, ... - 'Location', legend_loc ); - set( leg,'Box','off' ); - set( leg,'Color','none' ); - set( leg,'FontSize',fontsize-3 ); - set( leg,'Units','inches' ); - %set( leg,'Position',[3.15 10.2 0.7 0.3 ] ); % 1600 1200 - set( leg,'Position',[17.60 14.30 0.7 0.3 ] ); % 2000 1500 - end +if rows == 4 && cols == 5 + if nth == 1 && theid == 3 + if with_eigen == 1 + leg = legend( [ blis_ln open_ln eige_ln vend_ln ], ... + blis_legend, open_legend, eige_legend, vend_legend, ... + 'Location', legend_loc ); + else + leg = legend( [ blis_ln open_ln vend_ln ], ... + blis_legend, open_legend, vend_legend, ... + 'Location', legend_loc ); + end + set( leg,'Box','off','Color','none','Units','inches','FontSize',fontsize-3 ); + set( leg,'Position',[11.20 12.81 0.7 0.3 ] ); % (0,2br) + %set( leg,'Position',[ 4.20 12.81 0.7 0.3 ] ); % (0,0br) + elseif nth > 1 && theid == 4 + if with_eigen == 1 + leg = legend( [ blis_ln open_ln eige_ln vend_ln ], ... + blis_legend, open_legend, eige_legend, vend_legend, ... + 'Location', legend_loc ); + else + leg = legend( [ blis_ln open_ln vend_ln ], ... + blis_legend, open_legend, vend_legend, ... 
+ 'Location', legend_loc ); + end + set( leg,'Box','off','Color','none','Units','inches','FontSize',fontsize-3 ); + %set( leg,'Position',[7.70 12.81 0.7 0.3 ] ); % (0,1br) + %set( leg,'Position',[11.20 12.81 0.7 0.3 ] ); % (0,2br) + set( leg,'Position',[10.47 14.17 0.7 0.3 ] ); % (0,2tl) + end end + %set( leg,'Position',[ 4.20 12.75 0.7 0.3 ] ); % (0,0br) + %set( leg,'Position',[ 7.70 12.75 0.7 0.3 ] ); % (0,1br) + %set( leg,'Position',[10.47 14.28 0.7 0.3 ] ); % (0,2tl) + %set( leg,'Position',[11.20 12.75 0.7 0.3 ] ); % (0,2br) + %set( leg,'Position',[13.95 14.28 0.7 0.3 ] ); % (0,3tl) + %set( leg,'Position',[14.70 12.75 0.7 0.3 ] ); % (0,3br) + %set( leg,'Position',[17.45 14.28 0.7 0.3 ] ); % (0,4tl) + %set( leg,'Position',[18.22 12.75 0.7 0.3 ] ); % (0,4br) set( ax1,'FontSize',fontsize ); set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. diff --git a/test/3m4m/matlab/plot_panel_4x5.m b/test/3/matlab/plot_panel_4x5.m similarity index 64% rename from test/3m4m/matlab/plot_panel_4x5.m rename to test/3/matlab/plot_panel_4x5.m index 3735af74f..a5fc0b1c8 100644 --- a/test/3m4m/matlab/plot_panel_4x5.m +++ b/test/3/matlab/plot_panel_4x5.m @@ -1,7 +1,11 @@ function r_val = plot_panel_4x5( cfreq, ... dflopspercycle, ... nth, ... - dirpath ) + thr_str, ... + dirpath, ... + arch_str, ... + vend_str, ... + with_eigen ) %cfreq = 1.8; %dflopspercycle = 32; @@ -10,18 +14,13 @@ function r_val = plot_panel_4x5( cfreq, ... % results. filetemp_blis = '%s/output_%s_%s_asm_blis.m'; filetemp_open = '%s/output_%s_%s_openblas.m'; -filetemp_mkl = '%s/output_%s_%s_mkl.m'; +filetemp_eige = '%s/output_%s_%s_eigen.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; % Create a variable name "template" for the variables contained in the % files outlined above. vartemp = 'data_%s_%s_%s( :, : )'; -if nth == 1 - thr_str = 'st'; -else - thr_str = 'mt'; -end - % Define the datatypes and operations we will be plotting. dts = [ 's' 'd' 'c' 'z' ]; ops( 1, : ) = 'gemm'; @@ -35,20 +34,20 @@ ops( 5, : ) = 'trsm'; opnames = gen_opnames( ops, dts ); n_opnames = size( opnames, 1 ); -%fig = figure; -%fig = figure('Position', [100, 100, 1600, 1200]); fig = figure('Position', [100, 100, 2000, 1500]); orient( fig, 'portrait' ); -%set(gcf,'Position',[0 0 2000 900]); set(gcf,'PaperUnits', 'inches'); -%set(gcf,'PaperSize', [16 12.4]); -%set(gcf,'PaperPosition', [0 0 16 12.4]); -set(gcf,'PaperSize', [11 15.0]); -set(gcf,'PaperPosition', [0 0 11 15.0]); -%set(gcf,'PaperPositionMode','auto'); -set(gcf,'PaperPositionMode','manual'); +if 1 == 1 % matlab + set(gcf,'PaperSize', [11 15.0]); + set(gcf,'PaperPosition', [0 0 11 15.0]); + set(gcf,'PaperPositionMode','manual'); +else % octave 4.x + set(gcf,'PaperSize', [15 19.0]); + set(gcf,'PaperPositionMode','auto'); +end set(gcf,'PaperOrientation','landscape'); + % Iterate over the list of datatype-specific operation names. for opi = 1:n_opnames %for opi = 1:1 @@ -61,45 +60,64 @@ for opi = 1:n_opnames % Construct filenames for the data files from templates. file_blis = sprintf( filetemp_blis, dirpath, thr_str, opname ); file_open = sprintf( filetemp_open, dirpath, thr_str, opname ); - file_mkl = sprintf( filetemp_mkl, dirpath, thr_str, opname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opname ); % Load the data files. 
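% Each output_*.m file defines a single matrix whose name encodes the
% threading string, operation, and implementation (e.g.
% data_st_sgemm_asm_blis); run() brings that variable into scope so that
% eval() below can copy it to a simplified name.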
%str = sprintf( ' Loading %s', file_blis ); disp(str); run( file_blis ) %str = sprintf( ' Loading %s', file_open ); disp(str); run( file_open ) - %str = sprintf( ' Loading %s', file_mkl ); disp(str); - run( file_mkl ) + %str = sprintf( ' Loading %s', file_vend ); disp(str); + run( file_vend ) % Construct variable names for the variables in the data files. var_blis = sprintf( vartemp, thr_str, opname, 'asm_blis' ); var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); - var_mkl = sprintf( vartemp, thr_str, opname, 'mkl' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); % Use eval() to instantiate the variable names constructed above, % copying each to a simplified name. data_blis = eval( var_blis ); % e.g. data_st_sgemm_asm_blis( :, : ); data_open = eval( var_open ); % e.g. data_st_sgemm_openblas( :, : ); - data_mkl = eval( var_mkl ); % e.g. data_st_sgemm_mkl( :, : ); + data_vend = eval( var_vend ); % e.g. data_st_sgemm_vendor( :, : ); + + % Only read Eigen data in select cases. + if with_eigen == 1 + opname_u = opname; opname_u(1) = '_'; + if nth == 1 || strcmp( opname_u, '_gemm' ) + file_eige = sprintf( filetemp_eige, dirpath, thr_str, opname ); + run( file_eige ) + var_eige = sprintf( vartemp, thr_str, opname, 'eigen' ); + data_eige = eval( var_eige ); % e.g. data_st_sgemm_eigen( :, : ); + else + data_eige(1,1) = -1; + end + else + data_eige(1,1) = -1; + end % Plot one result in an m x n grid of plots, via the subplot() % function. plot_l3_perf( opname, ... data_blis, ... data_open, ... - data_mkl, ... + data_eige, ... + data_vend, vend_str, ... nth, ... 4, 5, ... + with_eigen, ... cfreq, ... dflopspercycle, ... opi ); end + % Construct the name of the file to which we will output the graph. -outfile = sprintf( 'l3_perf_panel_nt%d', nth ); +outfile = sprintf( 'l3_perf_%s_nt%d.pdf', arch_str, nth ); % Output the graph to pdf format. 
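% ('-dpdf' selects the PDF output driver; '-bestfit' scales the figure to
% the paper size set above without cropping.)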
-print(gcf, outfile,'-bestfit','-dpdf');
 %print(gcf, 'gemm_md','-fillpage','-dpdf');
+print(gcf, outfile,'-bestfit','-dpdf');
+end
diff --git a/test/3/matlab/runme.m b/test/3/matlab/runme.m
new file mode 100644
index 000000000..2e4685735
--- /dev/null
+++ b/test/3/matlab/runme.m
@@ -0,0 +1,35 @@
+% tx2
+plot_panel_4x5(2.20,8,1, 'st','../results/tx2/20190205/st',    'tx2',        'ARMPL',0); close; clear all;
+plot_panel_4x5(2.20,8,28,'1s','../results/tx2/20190205/jc4ic7','tx2_jc4ic7', 'ARMPL',0); close; clear all;
+plot_panel_4x5(2.20,8,56,'2s','../results/tx2/20190205/jc8ic7','tx2_jc8ic7', 'ARMPL',0); close; clear all;
+
+% skx
+% pre-eigen:
+%plot_panel_4x5(2.00,32,1, 'st','../results/skx/20190306/st',     'skx',         'MKL'); close; clear all;
+%plot_panel_4x5(2.00,32,26,'1s','../results/skx/20190306/jc2ic13','skx_jc2ic13', 'MKL'); close; clear all;
+%plot_panel_4x5(2.00,32,52,'2s','../results/skx/20190306/jc4ic13','skx_jc4ic13', 'MKL'); close; clear all;
+% with eigen:
+plot_panel_4x5(2.00,32,1, 'st','../results/skx/merged20190306_0328/st',     'skx',         'MKL',1); close; clear all;
+plot_panel_4x5(2.00,32,26,'1s','../results/skx/merged20190306_0328/jc2ic13','skx_jc2ic13', 'MKL',1); close; clear all;
+plot_panel_4x5(2.00,32,52,'2s','../results/skx/merged20190306_0328/jc4ic13','skx_jc4ic13', 'MKL',1); close; clear all;
+
+% has
+% pre-eigen:
+%plot_panel_4x5(3.25,16,1, 'st','../results/has/20190206/st',        'has',           'MKL',1); close; clear all;
+%plot_panel_4x5(3.00,16,12,'1s','../results/has/20190206/jc2ic3jr2','has_jc2ic3jr2', 'MKL',1); close; clear all;
+%plot_panel_4x5(3.00,16,24,'2s','../results/has/20190206/jc4ic3jr2','has_jc4ic3jr2', 'MKL',1); close; clear all;
+% with eigen:
+plot_panel_4x5(3.25,16,1, 'st','../results/has/merged20190206_0328/st',        'has',           'MKL',1); close; clear all;
+plot_panel_4x5(3.00,16,12,'1s','../results/has/merged20190206_0328/jc2ic3jr2','has_jc2ic3jr2', 'MKL',1); close; clear all;
+plot_panel_4x5(3.00,16,24,'2s','../results/has/merged20190206_0328/jc4ic3jr2','has_jc4ic3jr2', 'MKL',1); close; clear all;
+
+% epyc
+% pre-eigen:
+%plot_panel_4x5(3.00,8,1, 'st','../results/epyc/merged201903_0619/st',       'epyc',           'MKL'); close; clear all;
+%plot_panel_4x5(2.55,8,32,'1s','../results/epyc/merged201903_0619/jc1ic8jr4','epyc_jc1ic8jr4', 'MKL'); close; clear all;
+%plot_panel_4x5(2.55,8,64,'2s','../results/epyc/merged201903_0619/jc2ic8jr4','epyc_jc2ic8jr4', 'MKL'); close; clear all;
+% with eigen:
+plot_panel_4x5(3.00,8,1, 'st','../results/epyc/merged20190306_0319_0328/st',       'epyc',           'MKL',1); close; clear all;
+plot_panel_4x5(2.55,8,32,'1s','../results/epyc/merged20190306_0319_0328/jc1ic8jr4','epyc_jc1ic8jr4', 'MKL',1); close; clear all;
+plot_panel_4x5(2.55,8,64,'2s','../results/epyc/merged20190306_0319_0328/jc2ic8jr4','epyc_jc2ic8jr4', 'MKL',1); close; clear all;
+
diff --git a/test/3/runme.sh b/test/3/runme.sh
new file mode 100755
index 000000000..9933dd1e5
--- /dev/null
+++ b/test/3/runme.sh
@@ -0,0 +1,230 @@
+#!/bin/bash
+
+# File prefixes.
+exec_root="test"
+out_root="output"
+delay=0.1
+
+sys="blis"
+#sys="stampede2"
+#sys="lonestar5"
+#sys="ul252"
+#sys="ul264"
+
+# Bind threads to processors.
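+# (GOMP_CPU_AFFINITY binds OpenMP threads, in order, to the logical CPUs
+# listed; the commented-out lists below are examples for other machines.)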
+#export OMP_PROC_BIND=true
+#export GOMP_CPU_AFFINITY="0 2 4 6 8 10 12 14 16 18 20 22 1 3 5 7 9 11 13 15 17 19 21 23"
+#export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103"
+
+if [ ${sys} = "blis" ]; then
+
+	export GOMP_CPU_AFFINITY="0 1 2 3"
+
+	threads="jc1ic1jr1_2400
+	         jc2ic3jr2_6000
+	         jc4ic3jr2_8000"
+
+elif [ ${sys} = "stampede2" ]; then
+
+	echo "Need to set GOMP_CPU_AFFINITY."
+	exit 1
+
+	threads="jc1ic1jr1_2400
+	         jc4ic6jr1_6000
+	         jc4ic12jr1_8000"
+
+elif [ ${sys} = "lonestar5" ]; then
+
+	export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23"
+
+	# A hack to use libiomp5 with gcc.
+	#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/apps/intel/16.0.1.150/compilers_and_libraries_2016.1.150/linux/compiler/lib/intel64"
+
+	threads="jc1ic1jr1_2400
+	         jc2ic3jr2_6000
+	         jc4ic3jr2_8000"
+
+elif [ ${sys} = "ul252" ]; then
+
+	export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64"
+	export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51"
+
+	threads="jc1ic1jr1_2400
+	         jc2ic13jr1_6000
+	         jc4ic13jr1_8000"
+
+elif [ ${sys} = "ul264" ]; then
+
+	export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64"
+	export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63"
+
+	threads="jc1ic1jr1_2400
+	         jc1ic8jr4_6000
+	         jc2ic8jr4_8000"
+
+fi
+
+# Datatypes to test.
+test_dts="d s z c"
+
+# Operations to test.
+test_ops="gemm hemm herk trmm trsm"
+
+# Implementations to test.
+#impls="blis"
+#impls="other"
+impls="eigen"
+#impls="all"
+
+if [ "${impls}" = "blis" ]; then
+
+	test_impls="asm_blis"
+
+elif [ "${impls}" = "eigen" ]; then
+
+	test_impls="eigen"
+
+elif [ "${impls}" = "other" ]; then
+
+	test_impls="openblas vendor"
+
+else
+
+	test_impls="openblas asm_blis vendor"
+fi
+
+# Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can
+# restore the value.
+GOMP_CPU_AFFINITYsave=${GOMP_CPU_AFFINITY}
+
+
+# Now run the test cases.
+for th in ${threads}; do
+
+	# Start with one way of parallelism in each loop. We will now begin
+	# parsing the 'th' variable to update one or more of these threading
+	# parameters.
+	jc_nt=1; pc_nt=1; ic_nt=1; jr_nt=1; ir_nt=1
+
+	# Strip everything before and after the underscore so that what remains
+	# is the problem size and threading parameter string, respectively.
+	psize=${th##*_}; thinfo=${th%%_*}
+
+	# Identify each threading parameter and insert a space before it.
+	thsep=$(echo -e ${thinfo} | sed -e "s/\([jip][cr]\)/ \1/g" )
+
+	nt=1
+
+	for loopnum in ${thsep}; do
+
+		# Given the current string, which identifies a loop and its
+		# number of ways of parallelism, strip out the digits and the
+		# letters separately to isolate the loop name and its ways.
+		loop=$(echo -e ${loopnum} | sed -e "s/[0-9]//g" )
+		num=$(echo -e ${loopnum} | sed -e "s/[a-z]//g" )
+
+		# Construct a string that we can evaluate to set the number
+		# of ways of parallelism for the current loop.
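+		# For example, th="jc2ic3jr2_6000" yields psize=6000 and
+		# thinfo="jc2ic3jr2"; thsep becomes " jc2 ic3 jr2", and this
+		# loop then sets jc_nt=2, ic_nt=3, and jr_nt=2, for a total
+		# of nt=12 threads.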
+ loop_nt_eq_num="${loop}_nt=${num}" + + # Update the total number of threads. + nt=$(expr ${nt} \* ${num}) + + # Evaluate the string to assign the ways to the variable. + eval ${loop_nt_eq_num} + + done + + echo "Switching to: jc${jc_nt} pc${pc_nt} ic${ic_nt} jr${jr_nt} ir${ir_nt} (nt = ${nt}) p_max${psize}" + + + for dt in ${test_dts}; do + + for im in ${test_impls}; do + + for op in ${test_ops}; do + + # Eigen does not support multithreading for hemm, herk, trmm, + # or trsm. So if we're getting ready to execute an Eigen driver + # for one of these operations and nt > 1, we skip this test. + if [ "${im}" = "eigen" ] && \ + [ "${op}" != "gemm" ] && \ + [ "${nt}" != "1" ]; then + continue; + fi + + # Find the threading suffix by probing the executable. + binname=$(ls ${exec_root}_${dt}${op}_${psize}_${im}_*.x) + suf_ext=${binname##*_} + suf=${suf_ext%%.*} + + #echo "found file: ${binname} with suffix ${suf}" + + # Set the number of threads according to th. + if [ "${suf}" = "1s" ] || [ "${suf}" = "2s" ]; then + + # Set the threading parameters based on the implementation + # that we are preparing to run. + if [ "${im}" = "asm_blis" ]; then + unset OMP_NUM_THREADS + export BLIS_JC_NT=${jc_nt} + export BLIS_PC_NT=${pc_nt} + export BLIS_IC_NT=${ic_nt} + export BLIS_JR_NT=${jr_nt} + export BLIS_IR_NT=${ir_nt} + elif [ "${im}" = "openblas" ]; then + unset OMP_NUM_THREADS + export OPENBLAS_NUM_THREADS=${nt} + elif [ "${im}" = "eigen" ]; then + export OMP_NUM_THREADS=${nt} + elif [ "${im}" = "vendor" ]; then + unset OMP_NUM_THREADS + export MKL_NUM_THREADS=${nt} + fi + export nt_use=${nt} + + # Multithreaded OpenBLAS seems to have a problem running + # properly if GOMP_CPU_AFFINITY is set. So we temporarily + # unset it here if we are about to execute OpenBLAS, but + # otherwise restore it. + if [ ${im} = "openblas" ]; then + unset GOMP_CPU_AFFINITY + else + export GOMP_CPU_AFFINITY="${GOMP_CPU_AFFINITYsave}" + fi + else + + export BLIS_JC_NT=1 + export BLIS_PC_NT=1 + export BLIS_IC_NT=1 + export BLIS_JR_NT=1 + export BLIS_IR_NT=1 + export OMP_NUM_THREADS=1 + export OPENBLAS_NUM_THREADS=1 + export MKL_NUM_THREADS=1 + export nt_use=1 + fi + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${psize}_${im}_${suf}.x" + + # Construct the name of the output file. + out_file="${out_root}_${suf}_${dt}${op}_${im}.m" + + #echo "Running (nt = ${nt_use}) ./${exec_name} > ${out_file}" + echo "Running ./${exec_name} > ${out_file}" + + # Run executable. + ./${exec_name} > ${out_file} + + sleep ${delay} + + done + done + done +done + diff --git a/test/3/test_gemm.c b/test/3/test_gemm.c new file mode 100644 index 000000000..508de4fd9 --- /dev/null +++ b/test/3/test_gemm.c @@ -0,0 +1,417 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+   - Neither the name of The University of Texas nor the names of its
+     contributors may be used to endorse or promote products derived
+     from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <unistd.h>
+#ifdef EIGEN
+  #define BLIS_DISABLE_BLAS_DEFS
+  #include "blis.h"
+  #include <Eigen/Core>
+  #include <complex>
+  using namespace Eigen;
+#else
+  #include "blis.h"
+#endif
+
+#define COL_STORAGE
+//#define ROW_STORAGE
+
+//#define PRINT
+
+int main( int argc, char** argv )
+{
+	obj_t a, b, c;
+	obj_t c_save;
+	obj_t alpha, beta;
+	dim_t m, n, k;
+	dim_t p;
+	dim_t p_begin, p_max, p_inc;
+	int   m_input, n_input, k_input;
+	ind_t ind;
+	num_t dt;
+	char  dt_ch;
+	int   r, n_repeats;
+	trans_t transa;
+	trans_t transb;
+	f77_char f77_transa;
+	f77_char f77_transb;
+
+	double dtime;
+	double dtime_save;
+	double gflops;
+
+	//bli_init();
+
+	//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
+
+	n_repeats = 3;
+
+	dt = DT;
+
+	ind = IND;
+
+#if 1
+	p_begin = P_BEGIN;
+	p_max   = P_MAX;
+	p_inc   = P_INC;
+
+	m_input = -1;
+	n_input = -1;
+	k_input = -1;
+#else
+	p_begin = 40;
+	p_max   = 1000;
+	p_inc   = 40;
+
+	m_input = -1;
+	n_input = -1;
+	k_input = -1;
+#endif
+
+
+	// Suppress compiler warnings about unused variable 'ind'.
+	( void )ind;
+
+#if 0
+
+	cntx_t* cntx;
+
+	ind_t ind_mod = ind;
+
+	// A hack to use 3m1 as 1mpb (with 1m as 1mbp).
+	if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M;
+
+	// Initialize a context for the current induced method and datatype.
+	cntx = bli_gks_query_ind_cntx( ind_mod, dt );
+
+	// Set k to the kc blocksize for the current datatype.
+	k_input = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx );
+
+#elif 1
+
+	//k_input = 256;
+
+#endif
+
+	// Choose the char corresponding to the requested datatype.
+	if      ( bli_is_float( dt ) )    dt_ch = 's';
+	else if ( bli_is_double( dt ) )   dt_ch = 'd';
+	else if ( bli_is_scomplex( dt ) ) dt_ch = 'c';
+	else                              dt_ch = 'z';
+
+	transa = BLIS_NO_TRANSPOSE;
+	transb = BLIS_NO_TRANSPOSE;
+
+	bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
+	bli_param_map_blis_to_netlib_trans( transb, &f77_transb );
+
+	// Begin with initializing the last entry to zero so that
+	// matlab allocates space for the entire array once up-front.
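+	// (Printing the final index first this way makes matlab/octave
+	// allocate the full data_* matrix in one shot rather than growing
+	// it on every iteration; the empty loop below merely advances p to
+	// its final value.)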
+	for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ;
+
+	printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR );
+	printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n",
+	        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
+	        ( unsigned long )0,
+	        ( unsigned long )0,
+	        ( unsigned long )0, 0.0 );
+
+
+	for ( p = p_begin; p <= p_max; p += p_inc )
+	{
+
+		if ( m_input < 0 ) m = p / ( dim_t )abs(m_input);
+		else               m = ( dim_t )    m_input;
+		if ( n_input < 0 ) n = p / ( dim_t )abs(n_input);
+		else               n = ( dim_t )    n_input;
+		if ( k_input < 0 ) k = p / ( dim_t )abs(k_input);
+		else               k = ( dim_t )    k_input;
+
+		bli_obj_create( dt, 1, 1, 0, 0, &alpha );
+		bli_obj_create( dt, 1, 1, 0, 0, &beta );
+
+		#ifdef COL_STORAGE
+		bli_obj_create( dt, m, k, 0, 0, &a );
+		bli_obj_create( dt, k, n, 0, 0, &b );
+		bli_obj_create( dt, m, n, 0, 0, &c );
+		bli_obj_create( dt, m, n, 0, 0, &c_save );
+		#else
+		bli_obj_create( dt, m, k, k, 1, &a );
+		bli_obj_create( dt, k, n, n, 1, &b );
+		bli_obj_create( dt, m, n, n, 1, &c );
+		bli_obj_create( dt, m, n, n, 1, &c_save );
+		#endif
+
+		bli_randm( &a );
+		bli_randm( &b );
+		bli_randm( &c );
+
+		bli_obj_set_conjtrans( transa, &a );
+		bli_obj_set_conjtrans( transb, &b );
+
+		bli_setsc(  (2.0/1.0), 0.0, &alpha );
+		bli_setsc(  (1.0/1.0), 0.0, &beta );
+
+		bli_copym( &c, &c_save );
+
+#if 0 //def BLIS
+		bli_ind_disable_all_dt( dt );
+		bli_ind_enable_dt( ind, dt );
+#endif
+
+#ifdef EIGEN
+		double alpha_r, alpha_i;
+
+		bli_getsc( &alpha, &alpha_r, &alpha_i );
+
+		void* ap = bli_obj_buffer_at_off( &a );
+		void* bp = bli_obj_buffer_at_off( &b );
+		void* cp = bli_obj_buffer_at_off( &c );
+
+		#ifdef COL_STORAGE
+		const int os_a = bli_obj_col_stride( &a );
+		const int os_b = bli_obj_col_stride( &b );
+		const int os_c = bli_obj_col_stride( &c );
+		#else
+		const int os_a = bli_obj_row_stride( &a );
+		const int os_b = bli_obj_row_stride( &b );
+		const int os_c = bli_obj_row_stride( &c );
+		#endif
+
+		Stride<Dynamic,1> stride_a( os_a, 1 );
+		Stride<Dynamic,1> stride_b( os_b, 1 );
+		Stride<Dynamic,1> stride_c( os_c, 1 );
+
+		#ifdef COL_STORAGE
+		#if defined(IS_FLOAT)
+		typedef Matrix<float,                Dynamic, Dynamic, ColMajor> MatrixXf_;
+		#elif defined (IS_DOUBLE)
+		typedef Matrix<double,               Dynamic, Dynamic, ColMajor> MatrixXd_;
+		#elif defined (IS_SCOMPLEX)
+		typedef Matrix<std::complex<float>,  Dynamic, Dynamic, ColMajor> MatrixXcf_;
+		#elif defined (IS_DCOMPLEX)
+		typedef Matrix<std::complex<double>, Dynamic, Dynamic, ColMajor> MatrixXcd_;
+		#endif
+		#else
+		#if defined(IS_FLOAT)
+		typedef Matrix<float,                Dynamic, Dynamic, RowMajor> MatrixXf_;
+		#elif defined (IS_DOUBLE)
+		typedef Matrix<double,               Dynamic, Dynamic, RowMajor> MatrixXd_;
+		#elif defined (IS_SCOMPLEX)
+		typedef Matrix<std::complex<float>,  Dynamic, Dynamic, RowMajor> MatrixXcf_;
+		#elif defined (IS_DCOMPLEX)
+		typedef Matrix<std::complex<double>, Dynamic, Dynamic, RowMajor> MatrixXcd_;
+		#endif
+		#endif
+		#if defined(IS_FLOAT)
+		Map<MatrixXf_,  0, Stride<Dynamic,1> > A( ( float*  )ap, m, k, stride_a );
+		Map<MatrixXf_,  0, Stride<Dynamic,1> > B( ( float*  )bp, k, n, stride_b );
+		Map<MatrixXf_,  0, Stride<Dynamic,1> > C( ( float*  )cp, m, n, stride_c );
+		#elif defined (IS_DOUBLE)
+		Map<MatrixXd_,  0, Stride<Dynamic,1> > A( ( double* )ap, m, k, stride_a );
+		Map<MatrixXd_,  0, Stride<Dynamic,1> > B( ( double* )bp, k, n, stride_b );
+		Map<MatrixXd_,  0, Stride<Dynamic,1> > C( ( double* )cp, m, n, stride_c );
+		#elif defined (IS_SCOMPLEX)
+		Map<MatrixXcf_, 0, Stride<Dynamic,1> > A( ( std::complex<float>*  )ap, m, k, stride_a );
+		Map<MatrixXcf_, 0, Stride<Dynamic,1> > B( ( std::complex<float>*  )bp, k, n, stride_b );
+		Map<MatrixXcf_, 0, Stride<Dynamic,1> > C( ( std::complex<float>*  )cp, m, n, stride_c );
+		#elif defined (IS_DCOMPLEX)
+		Map<MatrixXcd_, 0, Stride<Dynamic,1> > A( ( std::complex<double>* )ap, m, k, stride_a );
+		Map<MatrixXcd_, 0, Stride<Dynamic,1> > B( ( std::complex<double>* )bp, k, n, stride_b );
+		Map<MatrixXcd_, 0, Stride<Dynamic,1> > C( ( std::complex<double>* )cp, m, n, stride_c );
+		#endif
+#endif
+
+		dtime_save = DBL_MAX;
+
+		for ( r = 0; r < n_repeats; ++r )
+		{
+			bli_copym( &c_save, &c );
+
+			dtime = bli_clock();
+
+#ifdef PRINT
+			bli_printm( "a", &a, "%4.1f", "" );
+			bli_printm( "b", &b, "%4.1f", "" );
+			bli_printm(
"c", &c, "%4.1f", "" ); +#endif + +#if defined(BLIS) + + bli_gemm( &alpha, + &a, + &b, + &beta, + &c ); + +#elif defined(EIGEN) + + C.noalias() += alpha_r * A * B; + +#else // if defined(BLAS) + + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + sgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + dgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + cgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + zgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.1f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )k, + ( unsigned long )n, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + 
diff --git a/test/3m4m/test_hemm.c b/test/3/test_hemm.c similarity index 60% rename from test/3m4m/test_hemm.c rename to test/3/test_hemm.c index bbf404379..73746ae4b 100644 --- a/test/3m4m/test_hemm.c +++ b/test/3/test_hemm.c @@ -44,7 +44,7 @@ int main( int argc, char** argv ) obj_t alpha, beta; dim_t m, n; dim_t p; - dim_t p_begin, p_end, p_inc; + dim_t p_begin, p_max, p_inc; int m_input, n_input; ind_t ind; num_t dt; @@ -70,7 +70,7 @@ int main( int argc, char** argv ) ind = IND; p_begin = P_BEGIN; - p_end = P_END; + p_max = P_MAX; p_inc = P_INC; m_input = -1; @@ -115,19 +115,16 @@ int main( int argc, char** argv ) // Begin with initializing the last entry to zero so that // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; -#ifdef BLIS - printf( "data_%s_%chemm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%chemm_%s", THR_STR, dt_ch, STR ); -#endif + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%chemm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_begin; p <= p_max; p += p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -161,7 +158,6 @@ int main( int argc, char** argv ) bli_setsc( (2.0/1.0), 0.0, &alpha ); bli_setsc( (1.0/1.0), 0.0, &beta ); - bli_copym( &c, &c_save ); #if 0 //def BLIS @@ -177,7 +173,6 @@ int main( int argc, char** argv ) dtime = bli_clock(); - #ifdef PRINT bli_printm( "a", &a, "%4.1f", "" ); bli_printm( "b", &b, "%4.1f", "" ); @@ -195,98 +190,114 @@ int main( int argc, char** argv ) #else - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); - ssymm_( &f77_side, - &f77_uploa, - &mm, - &nn, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + ssymm_( &f77_side, + &f77_uploa, + &mm, + &nn, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( 
double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); - dsymm_( &f77_side, - &f77_uploa, - &mm, - &nn, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + dsymm_( &f77_side, + &f77_uploa, + &mm, + &nn, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif - chemm_( &f77_side, - &f77_uploa, - &mm, - &nn, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + chemm_( &f77_side, + &f77_uploa, + &mm, + &nn, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif - zhemm_( &f77_side, - &f77_uploa, - &mm, - &nn, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } + zhemm_( &f77_side, + &f77_uploa, + &mm, + &nn, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } #endif #ifdef PRINT @@ -294,7 +305,6 @@ int main( int argc, char** argv ) exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); } @@ -305,11 +315,7 @@ int main( int 
argc, char** argv ) if ( bli_is_complex( dt ) ) gflops *= 4.0; -#ifdef BLIS - printf( "data_%s_%chemm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%chemm_%s", THR_STR, dt_ch, STR ); -#endif + printf( "data_%s_%chemm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )m, diff --git a/test/3m4m/test_herk.c b/test/3/test_herk.c similarity index 63% rename from test/3m4m/test_herk.c rename to test/3/test_herk.c index 1626bb6fa..623bec30e 100644 --- a/test/3m4m/test_herk.c +++ b/test/3/test_herk.c @@ -36,7 +36,6 @@ #include #include "blis.h" - //#define PRINT int main( int argc, char** argv ) @@ -46,7 +45,7 @@ int main( int argc, char** argv ) obj_t alpha, beta; dim_t m, k; dim_t p; - dim_t p_begin, p_end, p_inc; + dim_t p_begin, p_max, p_inc; int m_input, k_input; ind_t ind; num_t dt, dt_real; @@ -73,7 +72,7 @@ int main( int argc, char** argv ) ind = IND; p_begin = P_BEGIN; - p_end = P_END; + p_max = P_MAX; p_inc = P_INC; m_input = -1; @@ -118,19 +117,16 @@ int main( int argc, char** argv ) // Begin with initializing the last entry to zero so that // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; -#ifdef BLIS - printf( "data_%s_%cherk_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%cherk_%s", THR_STR, dt_ch, STR ); -#endif + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%cherk_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_begin; p <= p_max; p += p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -162,7 +158,6 @@ int main( int argc, char** argv ) bli_setsc( (2.0/1.0), 0.0, &alpha ); bli_setsc( (1.0/1.0), 0.0, &beta ); - bli_copym( &c, &c_save ); #if 0 //def BLIS @@ -176,10 +171,8 @@ int main( int argc, char** argv ) { bli_copym( &c_save, &c ); - dtime = bli_clock(); - #ifdef PRINT bli_printm( "a", &a, "%4.1f", "" ); bli_printm( "c", &c, "%4.1f", "" ); @@ -194,86 +187,100 @@ int main( int argc, char** argv ) #else - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); - ssyrk_( &f77_uploc, - &f77_transa, - &mm, - &kk, - alphap, - ap, &lda, - betap, - cp, &ldc ); - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + ssyrk_( &f77_uploc, + &f77_transa, + &mm, + &kk, + alphap, + ap, &lda, + 
betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); - dsyrk_( &f77_uploc, - &f77_transa, - &mm, - &kk, - alphap, - ap, &lda, - betap, - cp, &ldc ); - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - float* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + dsyrk_( &f77_uploc, + &f77_transa, + &mm, + &kk, + alphap, + ap, &lda, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + float* alphap = ( float* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + float* betap = ( float* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif - cherk_( &f77_uploc, - &f77_transa, - &mm, - &kk, - alphap, - ap, &lda, - betap, - cp, &ldc ); - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - double* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + cherk_( &f77_uploc, + &f77_transa, + &mm, + &kk, + alphap, + ap, &lda, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + double* alphap = ( double* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + double* betap = ( double* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif - zherk_( &f77_uploc, - &f77_transa, - &mm, - &kk, - alphap, - ap, &lda, - betap, - cp, &ldc ); - } + zherk_( &f77_uploc, + &f77_transa, + &mm, + &kk, + alphap, + ap, &lda, + betap, + cp, &ldc ); + } #endif #ifdef PRINT @@ -281,7 +288,6 @@ int main( int argc, char** argv ) exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); } @@ -289,11 +295,7 @@ int main( int argc, char** argv ) if ( bli_is_complex( dt ) ) gflops *= 4.0; -#ifdef BLIS - printf( "data_%s_%cherk_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%cherk_%s", THR_STR, dt_ch, STR ); -#endif + printf( "data_%s_%cherk_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) 
= [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )m, diff --git a/test/3m4m/test_trmm.c b/test/3/test_trmm.c similarity index 66% rename from test/3m4m/test_trmm.c rename to test/3/test_trmm.c index 20bfa11c0..5fa7f8e32 100644 --- a/test/3m4m/test_trmm.c +++ b/test/3/test_trmm.c @@ -36,7 +36,6 @@ #include #include "blis.h" - //#define PRINT int main( int argc, char** argv ) @@ -46,7 +45,7 @@ int main( int argc, char** argv ) obj_t alpha; dim_t m, n; dim_t p; - dim_t p_begin, p_end, p_inc; + dim_t p_begin, p_max, p_inc; int m_input, n_input; ind_t ind; num_t dt; @@ -76,7 +75,7 @@ int main( int argc, char** argv ) ind = IND; p_begin = P_BEGIN; - p_end = P_END; + p_max = P_MAX; p_inc = P_INC; m_input = -1; @@ -133,19 +132,16 @@ int main( int argc, char** argv ) // Begin with initializing the last entry to zero so that // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; -#ifdef BLIS - printf( "data_%s_%ctrmm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%ctrmm_%s", THR_STR, dt_ch, STR ); -#endif + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%ctrmm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_begin; p <= p_max; p += p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -155,7 +151,7 @@ int main( int argc, char** argv ) bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - if ( bli_does_trans( side ) ) + if ( bli_is_left( side ) ) bli_obj_create( dt, m, m, 0, 0, &a ); else bli_obj_create( dt, n, n, 0, 0, &a ); @@ -188,10 +184,8 @@ int main( int argc, char** argv ) { bli_copym( &c_save, &c ); - dtime = bli_clock(); - #ifdef PRINT bli_printm( "a", &a, "%4.1f", "" ); bli_printm( "c", &c, "%4.1f", "" ); @@ -206,86 +200,98 @@ int main( int argc, char** argv ) #else - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); - strmm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + strmm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); 
- dtrmm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); + dtrmm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif - ctrmm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + ctrmm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif - ztrmm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } + ztrmm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } #endif #ifdef PRINT @@ -293,7 +299,6 @@ int main( int argc, char** argv ) exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); } @@ -304,11 +309,7 @@ int main( int argc, char** argv ) if ( bli_is_complex( dt ) ) gflops *= 4.0; -#ifdef BLIS - printf( "data_%s_%ctrmm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%ctrmm_%s", THR_STR, dt_ch, STR ); -#endif + printf( "data_%s_%ctrmm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )m, diff --git a/test/3m4m/test_trsm.c b/test/3/test_trsm.c similarity index 66% rename from test/3m4m/test_trsm.c rename to test/3/test_trsm.c index a696a87a6..fe1481ddf 100644 --- a/test/3m4m/test_trsm.c +++ b/test/3/test_trsm.c @@ -36,7 +36,6 @@ #include #include "blis.h" - //#define PRINT int main( int argc, char** argv ) @@ -46,7 +45,7 @@ int main( int argc, char** argv ) obj_t alpha; dim_t m, n; dim_t p; - dim_t p_begin, p_end, p_inc; + dim_t p_begin, p_max, p_inc; int m_input, n_input; ind_t ind; num_t dt; @@ -76,7 +75,7 @@ int main( int argc, char** argv ) 
ind = IND; p_begin = P_BEGIN; - p_end = P_END; + p_max = P_MAX; p_inc = P_INC; m_input = -1; @@ -133,19 +132,16 @@ int main( int argc, char** argv ) // Begin with initializing the last entry to zero so that // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; -#ifdef BLIS - printf( "data_%s_%ctrsm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR ); -#endif + for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ; + + printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )0, ( unsigned long )0, 0.0 ); - for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_begin; p <= p_max; p += p_inc ) { if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); @@ -155,7 +151,7 @@ int main( int argc, char** argv ) bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - if ( bli_does_trans( side ) ) + if ( bli_is_left( side ) ) bli_obj_create( dt, m, m, 0, 0, &a ); else bli_obj_create( dt, n, n, 0, 0, &a ); @@ -192,10 +188,8 @@ int main( int argc, char** argv ) { bli_copym( &c_save, &c ); - dtime = bli_clock(); - #ifdef PRINT bli_printm( "a", &a, "%4.1f", "" ); bli_printm( "c", &c, "%4.1f", "" ); @@ -210,86 +204,98 @@ int main( int argc, char** argv ) #else - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); - strsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + strsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); - dtrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); + dtrsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = 
bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* cp = ( float* )bli_obj_buffer( &c ); +#else + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); +#endif - ctrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + ctrsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldc = bli_obj_col_stride( &c ); +#ifdef EIGEN + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* cp = ( double* )bli_obj_buffer( &c ); +#else + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); +#endif - ztrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &kk, - alphap, - ap, &lda, - cp, &ldc ); - } + ztrsm_( &f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &kk, + alphap, + ap, &lda, + cp, &ldc ); + } #endif #ifdef PRINT @@ -297,7 +303,6 @@ int main( int argc, char** argv ) exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); } @@ -308,11 +313,7 @@ int main( int argc, char** argv ) if ( bli_is_complex( dt ) ) gflops *= 4.0; -#ifdef BLIS - printf( "data_%s_%ctrsm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR ); -#endif + printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR ); printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )m, diff --git a/test/3m4m/Makefile b/test/3m4m/Makefile deleted file mode 100644 index b4ae45bb8..000000000 --- a/test/3m4m/Makefile +++ /dev/null @@ -1,586 +0,0 @@ -#!/bin/bash -# -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - -# -# Makefile -# -# Field G. Van Zee -# -# Makefile for standalone BLIS test drivers. -# - -# -# --- Makefile PHONY target definitions ---------------------------------------- -# - -.PHONY: all \ - blis-gemm-st openblas-gemm-st mkl-gemm-st acml-gemm-st \ - blis-gemm-mt openblas-gemm-mt mkl-gemm-mt acml-gemm-mt \ - clean cleanx - - - -# -# --- Determine makefile fragment location ------------------------------------- -# - -# Comments: -# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. -# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in -# the second case because CONFIG_NAME is not yet set. -ifneq ($(strip $(BLIS_INSTALL_PATH)),) -LIB_PATH := $(BLIS_INSTALL_PATH)/lib -INC_PATH := $(BLIS_INSTALL_PATH)/include/blis -SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis -else -DIST_PATH := ../.. -LIB_PATH = ../../lib/$(CONFIG_NAME) -INC_PATH = ../../include/$(CONFIG_NAME) -SHARE_PATH := ../.. -endif - - - -# -# --- Include common makefile definitions -------------------------------------- -# - -# Include the common makefile fragment. --include $(SHARE_PATH)/common.mk - - - -# -# --- BLAS and LAPACK implementations ------------------------------------------ -# - -# BLIS library and header path. This is simply wherever it was installed. -#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib -#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis - -# BLIS library. -#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a - -# BLAS library path(s). This is where the BLAS libraries reside. 
-HOME_LIB_PATH := $(HOME)/flame/lib -#MKL_LIB_PATH := /opt/apps/intel/13/composer_xe_2013.2.146/mkl/lib/intel64 -MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 -#MKL_LIB_PATH := ${MKLROOT}/lib/intel64 -#ICC_LIB_PATH := /opt/apps/intel/13/composer_xe_2013.2.146/compiler/lib/intel64 -ACML_LIB_PATH := $(HOME_LIB_PATH)/acml/5.3.1/gfortran64_fma4_int64/lib -ACMLP_LIB_PATH := $(HOME_LIB_PATH)/acml/5.3.1/gfortran64_fma4_mp_int64/lib - -# OpenBLAS -OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a -OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a - -# ATLAS -ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ - $(HOME_LIB_PATH)/libatlas.a - -# MKL -MKL_LIB := -L$(MKL_LIB_PATH) \ - -lmkl_intel_lp64 \ - -lmkl_core \ - -lmkl_sequential \ - -lpthread -lm -ldl -#MKLP_LIB := -L$(MKL_LIB_PATH) \ -# -lmkl_intel_thread \ -# -lmkl_core \ -# -lmkl_intel_ilp64 \ -# -L$(ICC_LIB_PATH) \ -# -liomp5 -MKLP_LIB := -L$(MKL_LIB_PATH) \ - -lmkl_intel_lp64 \ - -lmkl_core \ - -lmkl_gnu_thread \ - -lpthread -lm -ldl -fopenmp - #-L$(ICC_LIB_PATH) \ - #-lgomp - -# ACML -ACML_LIB := -L$(ACML_LIB_PATH) \ - -lgfortran -lm -lrt -ldl -lacml -ACMLP_LIB := -L$(ACMLP_LIB_PATH) \ - -lgfortran -lm -lrt -ldl -lacml_mp - - - -# -# --- General build definitions ------------------------------------------------ -# - -TEST_SRC_PATH := . -TEST_OBJ_PATH := . - -# Gather all local object files. -TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ - $(TEST_OBJ_PATH)/%.o, \ - $(wildcard $(TEST_SRC_PATH)/*.c))) - -# Override the value of CINCFLAGS so that the value of CFLAGS returned by -# get-user-cflags-for() is not cluttered up with include paths needed only -# while building BLIS. -CINCFLAGS := -I$(INC_PATH) - -# Use the "framework" CFLAGS for the configuration family. -CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) - -# Add local header paths to CFLAGS. -CFLAGS += -I$(TEST_SRC_PATH) - -# Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) - - -# Datatype -DT_S := -DDT=BLIS_FLOAT -DT_D := -DDT=BLIS_DOUBLE -DT_C := -DDT=BLIS_SCOMPLEX -DT_Z := -DDT=BLIS_DCOMPLEX - -# Which library? 
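These two defines select, at compile time, which code path the shared test_*.c drivers take; a minimal condensed sketch of the pattern (it appears in full in the test_gemm.c source reproduced later in this diff):

    #ifdef BLIS
        /* Object-based BLIS API. */
        bli_gemm( &alpha, &a, &b, &beta, &c );
    #else
        /* Netlib-style BLAS interface. */
        dgemm_( &f77_transa, &f77_transb, &mm, &nn, &kk,
                alphap, ap, &lda, bp, &ldb, betap, cp, &ldc );
    #endif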
-BLI_DEF := -DBLIS -BLA_DEF := -DBLAS - -# Complex implementation type -D3MHW := -DIND=BLIS_3MH -D3M1 := -DIND=BLIS_3M1 -D4MHW := -DIND=BLIS_4MH -D4M1B := -DIND=BLIS_4M1B -D4M1A := -DIND=BLIS_4M1A -D1M := -DIND=BLIS_1M -DNAT := -DIND=BLIS_NAT - -# Implementation string -STR_3MHW := -DSTR=\"3mhw\" -STR_3M1 := -DSTR=\"3m1\" -STR_4MHW := -DSTR=\"4mhw\" -STR_4M1B := -DSTR=\"4m1b\" -STR_4M1A := -DSTR=\"4m1a\" -STR_1M := -DSTR=\"1m\" -STR_NAT := -DSTR=\"asm\" -STR_OBL := -DSTR=\"openblas\" -STR_MKL := -DSTR=\"mkl\" -STR_ACML := -DSTR=\"acml\" - -# Single or multithreaded string -STR_ST := -DTHR_STR=\"st\" -STR_MT := -DTHR_STR=\"mt\" - -# Problem size specification -PDEF_ST := -DP_BEGIN=56 \ - -DP_END=2800 \ - -DP_INC=56 - -PDEF_MT := -DP_BEGIN=160 \ - -DP_END=8000 \ - -DP_INC=160 - - - -# -# --- Targets/rules ------------------------------------------------------------ -# - -all: all-st all-mt -blis: blis-st blis-mt -openblas: openblas-st openblas-mt -mkl: mkl-st mkl-mt - -all-st: blis-st openblas-st mkl-st -all-mt: blis-mt openblas-mt mkl-mt - -blis-st: blis-nat-st -blis-mt: blis-nat-mt - -blis-ind: blis-ind-st blis-ind-mt -blis-nat: blis-nat-st blis-nat-mt - -blis-ind-st: \ - test_cgemm_3mhw_blis_st.x \ - test_zgemm_3mhw_blis_st.x \ - test_cgemm_3m1_blis_st.x \ - test_zgemm_3m1_blis_st.x \ - test_cgemm_4mhw_blis_st.x \ - test_zgemm_4mhw_blis_st.x \ - test_cgemm_4m1b_blis_st.x \ - test_zgemm_4m1b_blis_st.x \ - test_cgemm_4m1a_blis_st.x \ - test_zgemm_4m1a_blis_st.x \ - test_cgemm_1m_blis_st.x \ - test_zgemm_1m_blis_st.x - -blis-ind-mt: \ - test_cgemm_3mhw_blis_mt.x \ - test_zgemm_3mhw_blis_mt.x \ - test_cgemm_3m1_blis_mt.x \ - test_zgemm_3m1_blis_mt.x \ - test_cgemm_4mhw_blis_mt.x \ - test_zgemm_4mhw_blis_mt.x \ - test_cgemm_4m1b_blis_mt.x \ - test_zgemm_4m1b_blis_mt.x \ - test_cgemm_4m1a_blis_mt.x \ - test_zgemm_4m1a_blis_mt.x \ - test_cgemm_1m_blis_mt.x \ - test_zgemm_1m_blis_mt.x - -blis-nat-st: \ - test_sgemm_asm_blis_st.x \ - test_dgemm_asm_blis_st.x \ - test_cgemm_asm_blis_st.x \ - test_zgemm_asm_blis_st.x \ - test_shemm_asm_blis_st.x \ - test_dhemm_asm_blis_st.x \ - test_chemm_asm_blis_st.x \ - test_zhemm_asm_blis_st.x \ - test_sherk_asm_blis_st.x \ - test_dherk_asm_blis_st.x \ - test_cherk_asm_blis_st.x \ - test_zherk_asm_blis_st.x \ - test_strmm_asm_blis_st.x \ - test_dtrmm_asm_blis_st.x \ - test_ctrmm_asm_blis_st.x \ - test_ztrmm_asm_blis_st.x \ - test_strsm_asm_blis_st.x \ - test_dtrsm_asm_blis_st.x \ - test_ctrsm_asm_blis_st.x \ - test_ztrsm_asm_blis_st.x - -blis-nat-mt: \ - test_sgemm_asm_blis_mt.x \ - test_dgemm_asm_blis_mt.x \ - test_cgemm_asm_blis_mt.x \ - test_zgemm_asm_blis_mt.x \ - test_shemm_asm_blis_mt.x \ - test_dhemm_asm_blis_mt.x \ - test_chemm_asm_blis_mt.x \ - test_zhemm_asm_blis_mt.x \ - test_sherk_asm_blis_mt.x \ - test_dherk_asm_blis_mt.x \ - test_cherk_asm_blis_mt.x \ - test_zherk_asm_blis_mt.x \ - test_strmm_asm_blis_mt.x \ - test_dtrmm_asm_blis_mt.x \ - test_ctrmm_asm_blis_mt.x \ - test_ztrmm_asm_blis_mt.x \ - test_strsm_asm_blis_mt.x \ - test_dtrsm_asm_blis_mt.x \ - test_ctrsm_asm_blis_mt.x \ - test_ztrsm_asm_blis_mt.x - -openblas-st: \ - test_sgemm_openblas_st.x \ - test_dgemm_openblas_st.x \ - test_cgemm_openblas_st.x \ - test_zgemm_openblas_st.x \ - test_shemm_openblas_st.x \ - test_dhemm_openblas_st.x \ - test_chemm_openblas_st.x \ - test_zhemm_openblas_st.x \ - test_sherk_openblas_st.x \ - test_dherk_openblas_st.x \ - test_cherk_openblas_st.x \ - test_zherk_openblas_st.x \ - test_strmm_openblas_st.x \ - test_dtrmm_openblas_st.x \ - test_ctrmm_openblas_st.x 
\ - test_ztrmm_openblas_st.x \ - test_strsm_openblas_st.x \ - test_dtrsm_openblas_st.x \ - test_ctrsm_openblas_st.x \ - test_ztrsm_openblas_st.x - -openblas-mt: \ - test_sgemm_openblas_mt.x \ - test_dgemm_openblas_mt.x \ - test_cgemm_openblas_mt.x \ - test_zgemm_openblas_mt.x \ - test_shemm_openblas_mt.x \ - test_dhemm_openblas_mt.x \ - test_chemm_openblas_mt.x \ - test_zhemm_openblas_mt.x \ - test_sherk_openblas_mt.x \ - test_dherk_openblas_mt.x \ - test_cherk_openblas_mt.x \ - test_zherk_openblas_mt.x \ - test_strmm_openblas_mt.x \ - test_dtrmm_openblas_mt.x \ - test_ctrmm_openblas_mt.x \ - test_ztrmm_openblas_mt.x \ - test_strsm_openblas_mt.x \ - test_dtrsm_openblas_mt.x \ - test_ctrsm_openblas_mt.x \ - test_ztrsm_openblas_mt.x - -mkl-st: \ - test_sgemm_mkl_st.x \ - test_dgemm_mkl_st.x \ - test_cgemm_mkl_st.x \ - test_zgemm_mkl_st.x \ - test_shemm_mkl_st.x \ - test_dhemm_mkl_st.x \ - test_chemm_mkl_st.x \ - test_zhemm_mkl_st.x \ - test_sherk_mkl_st.x \ - test_dherk_mkl_st.x \ - test_cherk_mkl_st.x \ - test_zherk_mkl_st.x \ - test_strmm_mkl_st.x \ - test_dtrmm_mkl_st.x \ - test_ctrmm_mkl_st.x \ - test_ztrmm_mkl_st.x \ - test_strsm_mkl_st.x \ - test_dtrsm_mkl_st.x \ - test_ctrsm_mkl_st.x \ - test_ztrsm_mkl_st.x - -mkl-mt: \ - test_sgemm_mkl_mt.x \ - test_dgemm_mkl_mt.x \ - test_cgemm_mkl_mt.x \ - test_zgemm_mkl_mt.x \ - test_shemm_mkl_mt.x \ - test_dhemm_mkl_mt.x \ - test_chemm_mkl_mt.x \ - test_zhemm_mkl_mt.x \ - test_sherk_mkl_mt.x \ - test_dherk_mkl_mt.x \ - test_cherk_mkl_mt.x \ - test_zherk_mkl_mt.x \ - test_strmm_mkl_mt.x \ - test_dtrmm_mkl_mt.x \ - test_ctrmm_mkl_mt.x \ - test_ztrmm_mkl_mt.x \ - test_strsm_mkl_mt.x \ - test_dtrsm_mkl_mt.x \ - test_ctrsm_mkl_mt.x \ - test_ztrsm_mkl_mt.x - - - - -# --Object file rules -- - -$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c - $(CC) $(CFLAGS) -c $< -o $@ - -# blis 3mhw -test_z%_3mhw_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(D3MHW) $(STR_3MHW) $(STR_ST) -c $< -o $@ - -test_c%_3mhw_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(D3MHW) $(STR_3MHW) $(STR_ST) -c $< -o $@ - -test_z%_3mhw_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(D3MHW) $(STR_3MHW) $(STR_MT) -c $< -o $@ - -test_c%_3mhw_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(D3MHW) $(STR_3MHW) $(STR_MT) -c $< -o $@ - -# blis 3m1 -test_z%_3m1_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(D3M1) $(STR_3M1) $(STR_ST) -c $< -o $@ - -test_c%_3m1_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(D3M1) $(STR_3M1) $(STR_ST) -c $< -o $@ - -test_z%_3m1_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(D3M1) $(STR_3M1) $(STR_MT) -c $< -o $@ - -test_c%_3m1_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(D3M1) $(STR_3M1) $(STR_MT) -c $< -o $@ - -# blis 4mhw -test_z%_4mhw_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(D4MHW) $(STR_4MHW) $(STR_ST) -c $< -o $@ - -test_c%_4mhw_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(D4MHW) $(STR_4MHW) $(STR_ST) -c $< -o $@ - -test_z%_4mhw_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(D4MHW) $(STR_4MHW) $(STR_MT) -c $< -o $@ - -test_c%_4mhw_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(D4MHW) $(STR_4MHW) $(STR_MT) -c $< -o $@ - -# blis 4m1b -test_z%_4m1b_blis_st.o: test_%.c Makefile - 
$(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(D4M1B) $(STR_4M1B) $(STR_ST) -c $< -o $@ - -test_c%_4m1b_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(D4M1B) $(STR_4M1B) $(STR_ST) -c $< -o $@ - -test_z%_4m1b_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(D4M1B) $(STR_4M1B) $(STR_MT) -c $< -o $@ - -test_c%_4m1b_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(D4M1B) $(STR_4M1B) $(STR_MT) -c $< -o $@ - -# blis 4m1a -test_z%_4m1a_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(D4M1A) $(STR_4M1A) $(STR_ST) -c $< -o $@ - -test_c%_4m1a_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(D4M1A) $(STR_4M1A) $(STR_ST) -c $< -o $@ - -test_z%_4m1a_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(D4M1A) $(STR_4M1A) $(STR_MT) -c $< -o $@ - -test_c%_4m1a_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(D4M1A) $(STR_4M1A) $(STR_MT) -c $< -o $@ - -# blis 1m -test_z%_1m_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(D1M) $(STR_1M) $(STR_ST) -c $< -o $@ - -test_c%_1m_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(D1M) $(STR_1M) $(STR_ST) -c $< -o $@ - -test_z%_1m_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(D1M) $(STR_1M) $(STR_MT) -c $< -o $@ - -test_c%_1m_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(D1M) $(STR_1M) $(STR_MT) -c $< -o $@ - -# blis asm -test_d%_asm_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_D) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_ST) -c $< -o $@ - -test_s%_asm_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_S) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_ST) -c $< -o $@ - -test_z%_asm_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_ST) -c $< -o $@ - -test_c%_asm_blis_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_ST) -c $< -o $@ - -test_d%_asm_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_D) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_MT) -c $< -o $@ - -test_s%_asm_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_S) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_MT) -c $< -o $@ - -test_z%_asm_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_MT) -c $< -o $@ - -test_c%_asm_blis_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLI_DEF) $(DNAT) $(STR_NAT) $(STR_MT) -c $< -o $@ - -# openblas -test_d%_openblas_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_D) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_ST) -c $< -o $@ - -test_s%_openblas_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_S) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_ST) -c $< -o $@ - -test_z%_openblas_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_ST) -c $< -o $@ - -test_c%_openblas_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_ST) -c $< -o $@ - -test_d%_openblas_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_D) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_MT) -c $< -o $@ - -test_s%_openblas_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_S) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_MT) -c $< -o $@ - -test_z%_openblas_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) 
$(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_MT) -c $< -o $@ - -test_c%_openblas_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLA_DEF) $(DNAT) $(STR_OBL) $(STR_MT) -c $< -o $@ - -# mkl -test_d%_mkl_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_D) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_ST) -c $< -o $@ - -test_s%_mkl_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_S) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_ST) -c $< -o $@ - -test_z%_mkl_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_Z) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_ST) -c $< -o $@ - -test_c%_mkl_st.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_ST) $(DT_C) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_ST) -c $< -o $@ - -test_d%_mkl_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_D) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_MT) -c $< -o $@ - -test_s%_mkl_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_S) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_MT) -c $< -o $@ - -test_z%_mkl_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_Z) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_MT) -c $< -o $@ - -test_c%_mkl_mt.o: test_%.c Makefile - $(CC) $(CFLAGS) $(PDEF_MT) $(DT_C) $(BLA_DEF) $(DNAT) $(STR_MKL) $(STR_MT) -c $< -o $@ - - -# -- Executable file rules -- - -# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS -# on the link command line in case BLIS was configured with the BLAS -# compatibility layer. This prevents BLIS from inadvertently getting called -# for the BLAS routines we are trying to test with. - -test_%_openblas_st.x: test_%_openblas_st.o $(LIBBLIS_LINK) - $(LINKER) $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - -test_%_openblas_mt.x: test_%_openblas_mt.o $(LIBBLIS_LINK) - $(LINKER) $< $(OPENBLASP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - -test_%_mkl_st.x: test_%_mkl_st.o $(LIBBLIS_LINK) - $(LINKER) $< $(MKL_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - -test_%_mkl_mt.x: test_%_mkl_mt.o $(LIBBLIS_LINK) - $(LINKER) $< $(MKLP_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - -test_%_blis_st.x: test_%_blis_st.o $(LIBBLIS_LINK) - $(LINKER) $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - -test_%_blis_mt.x: test_%_blis_mt.o $(LIBBLIS_LINK) - $(LINKER) $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - - -# -- Clean rules -- - -clean: cleanx - -cleanx: - - $(RM_F) *.o *.x - diff --git a/test/3m4m/runme.sh b/test/3m4m/runme.sh deleted file mode 100755 index a48cca989..000000000 --- a/test/3m4m/runme.sh +++ /dev/null @@ -1,209 +0,0 @@ -#!/bin/bash - -# File pefixes. -exec_root="test" -out_root="output" - -#sys="blis" -#sys="stampede" -#sys="stampede2" -#sys="lonestar5" -sys="ul252" - -# Bind threads to processors. -#export OMP_PROC_BIND=true -#export GOMP_CPU_AFFINITY="0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15" -#export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7" -#export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7" -#export GOMP_CPU_AFFINITY="0 2 4 6 1 3 5 7" -#export GOMP_CPU_AFFINITY="0 4 1 5 2 6 3 7" -#export GOMP_CPU_AFFINITY="0 1 4 5 8 9 12 13 16 17 20 21 24 25 28 29 32 33 36 37 40 41 44 45" -#export GOMP_CPU_AFFINITY="0 2 4 6 8 10 12 14 16 18 20 22 1 3 5 7 9 11 13 15 17 19 21 23" - -# Modify LD_LIBRARY_PATH. -if [ ${sys} = "blis" ]; then - - export GOMP_CPU_AFFINITY="0 1 2 3" - - jc_nt=1 # 5th loop - ic_nt=4 # 3rd loop - jr_nt=1 # 2nd loop - ir_nt=1 # 1st loop - nt=4 - -elif [ ${sys} = "stampede2" ]; then - - echo "Need to set GOMP_CPU_AFFINITY." - exit 1 - - jc_nt=4 # 5th loop - ic_nt=12 # 3rd loop - jr_nt=1 # 2nd loop - ir_nt=1 # 1st loop - nt=48 - -elif [ ${sys} = "lonestar5" ]; then - - echo "Need to set GOMP_CPU_AFFINITY." 
- exit 1 - - # A hack to use libiomp5 with gcc. - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/apps/intel/16.0.1.150/compilers_and_libraries_2016.1.150/linux/compiler/lib/intel64" - - jc_nt=2 # 5th loop - ic_nt=12 # 3rd loop - jr_nt=1 # 2nd loop - ir_nt=1 # 1st loop - nt=24 - -elif [ ${sys} = "ul252" ]; then - - export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/field/intel/mkl/lib/intel64" - #export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103" - export GOMP_CPU_AFFINITY="0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51" - - #jc_nt=4 # 5th loop - jc_nt=2 # 5th loop - ic_nt=13 # 3rd loop - jr_nt=1 # 2nd loop - ir_nt=1 # 1st loop - #nt=52 - nt=26 -fi - -# Save a copy of GOMP_CPU_AFFINITY so that if we have to unset it, we can -# restore the value. -GOMP_CPU_AFFINITYsave=${GOMP_CPU_AFFINITY} - -# Threadedness to test. -threads="mt" -threads_r="mt" -#threads="st" -#threads_r="st" - -# Datatypes to test. -dts="z c" -dts_r="d s" - -# Operations to test. -l3_ops="gemm hemm herk trmm trsm" -test_ops="${l3_ops}" -test_ops_r="${l3_ops}" - -# Complex domain implementations to test. -#test_impls="3mhw_blis 3m1_blis 4mhw_blis 4m1b_blis 4m1a_blis 1m_blis" -#test_impls="openblas mkl asm_blis" - - -# Real domain implementations to test. -test_impls_r="openblas asm_blis mkl" -test_impls="openblas asm_blis mkl" -#test_impls_r="asm_blis openblas" - -# First perform real test cases. -for th in ${threads_r}; do - - for dt in ${dts_r}; do - - for im in ${test_impls_r}; do - - for op in ${test_ops_r}; do - - # Set the number of threads according to th. - if [ ${th} = "mt" ]; then - - export BLIS_JC_NT=${jc_nt} - export BLIS_IC_NT=${ic_nt} - export BLIS_JR_NT=${jr_nt} - export BLIS_IR_NT=${ir_nt} - export OPENBLAS_NUM_THREADS=${nt} - export MKL_NUM_THREADS=${nt} - - # Unset GOMP_CPU_AFFINITY for OpenBLAS. - if [ ${im} = "openblas" ]; then - - unset GOMP_CPU_AFFINITY - else - export GOMP_CPU_AFFINITY=${GOMP_CPU_AFFINITYsave} - fi - else - - export BLIS_JC_NT=1 - export BLIS_IC_NT=1 - export BLIS_JR_NT=1 - export BLIS_IR_NT=1 - export OPENBLAS_NUM_THREADS=1 - export MKL_NUM_THREADS=1 - fi - - # Construct the name of the test executable. - exec_name="${exec_root}_${dt}${op}_${im}_${th}.x" - - # Construct the name of the output file. - out_file="${out_root}_${th}_${dt}${op}_${im}.m" - - echo "Running (nt = ${nt}) ./${exec_name} > ${out_file}" - - # Run executable. - ./${exec_name} > ${out_file} - - sleep 1 - - done - done - done -done - -# Now perform complex test cases. -for th in ${threads}; do - - for dt in ${dts}; do - - for im in ${test_impls}; do - - for op in ${test_ops}; do - - # Set the number of threads according to th. - if [ ${th} = "mt" ]; then - - export BLIS_JC_NT=${jc_nt} - export BLIS_IC_NT=${ic_nt} - export BLIS_JR_NT=${jr_nt} - export BLIS_IR_NT=${ir_nt} - export OPENBLAS_NUM_THREADS=${nt} - export MKL_NUM_THREADS=${nt} - - # Unset GOMP_CPU_AFFINITY for OpenBLAS. 
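# (Presumably because OpenBLAS performs its own thread placement; an
# inherited GOMP_CPU_AFFINITY mask could otherwise pin its threads to an
# unintended subset of cores and skew the multithreaded timings.)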
- if [ ${im} = "openblas" ]; then - - unset GOMP_CPU_AFFINITY - else - export GOMP_CPU_AFFINITY=${GOMP_CPU_AFFINITYsave} - fi - else - - export BLIS_JC_NT=1 - export BLIS_IC_NT=1 - export BLIS_JR_NT=1 - export BLIS_IR_NT=1 - export OPENBLAS_NUM_THREADS=1 - export MKL_NUM_THREADS=1 - fi - - # Construct the name of the test executable. - exec_name="${exec_root}_${dt}${op}_${im}_${th}.x" - - # Construct the name of the output file. - out_file="${out_root}_${th}_${dt}${op}_${im}.m" - - echo "Running (nt = ${nt}) ./${exec_name} > ${out_file}" - - # Run executable. - ./${exec_name} > ${out_file} - - sleep 1 - - done - done - done -done diff --git a/test/3m4m/test_gemm.c b/test/3m4m/test_gemm.c deleted file mode 100644 index 8ba53d63a..000000000 --- a/test/3m4m/test_gemm.c +++ /dev/null @@ -1,333 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "blis.h" - -//#define PRINT - -int main( int argc, char** argv ) -{ - obj_t a, b, c; - obj_t c_save; - obj_t alpha, beta; - dim_t m, n, k; - dim_t p; - dim_t p_begin, p_end, p_inc; - int m_input, n_input, k_input; - ind_t ind; - num_t dt; - char dt_ch; - int r, n_repeats; - trans_t transa; - trans_t transb; - f77_char f77_transa; - f77_char f77_transb; - - double dtime; - double dtime_save; - double gflops; - - //bli_init(); - - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); - - n_repeats = 3; - - dt = DT; - - ind = IND; - - p_begin = P_BEGIN; - p_end = P_END; - p_inc = P_INC; - - m_input = -1; - n_input = -1; - k_input = -1; - - - // Supress compiler warnings about unused variable 'ind'. - ( void )ind; - -#if 0 - - cntx_t* cntx; - - ind_t ind_mod = ind; - - // A hack to use 3m1 as 1mpb (with 1m as 1mbp). - if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M; - - // Initialize a context for the current induced method and datatype. - cntx = bli_gks_query_ind_cntx( ind_mod, dt ); - - // Set k to the kc blocksize for the current datatype. 
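/* (Holding k at KC this way would turn each timed gemm into a single
   rank-kc update; a plausible reading is that this disabled block was
   used to isolate macro-kernel behavior from the k-dimension loop as
   m and n grow.) */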
- k_input = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); - -#elif 1 - - //k_input = 256; - -#endif - - // Choose the char corresponding to the requested datatype. - if ( bli_is_float( dt ) ) dt_ch = 's'; - else if ( bli_is_double( dt ) ) dt_ch = 'd'; - else if ( bli_is_scomplex( dt ) ) dt_ch = 'c'; - else dt_ch = 'z'; - - transa = BLIS_NO_TRANSPOSE; - transb = BLIS_NO_TRANSPOSE; - - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); - - // Begin with initializing the last entry to zero so that - // matlab allocates space for the entire array once up-front. - for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; -#ifdef BLIS - printf( "data_%s_%cgemm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); -#endif - printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, - ( unsigned long )0, - ( unsigned long )0, - ( unsigned long )0, 0.0 ); - - - for ( p = p_begin; p <= p_end; p += p_inc ) - { - - if ( m_input < 0 ) m = p / ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p / ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; - if ( k_input < 0 ) k = p / ( dim_t )abs(k_input); - else k = ( dim_t ) k_input; - - bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - bli_obj_create( dt, 1, 1, 0, 0, &beta ); - - bli_obj_create( dt, m, k, 0, 0, &a ); - bli_obj_create( dt, k, n, 0, 0, &b ); - bli_obj_create( dt, m, n, 0, 0, &c ); - bli_obj_create( dt, m, n, 0, 0, &c_save ); - - bli_randm( &a ); - bli_randm( &b ); - bli_randm( &c ); - - bli_obj_set_conjtrans( transa, &a ); - bli_obj_set_conjtrans( transb, &b ); - - bli_setsc( (2.0/1.0), 0.0, &alpha ); - bli_setsc( (1.0/1.0), 0.0, &beta ); - - - bli_copym( &c, &c_save ); - -#if 0 //def BLIS - bli_ind_disable_all_dt( dt ); - bli_ind_enable_dt( ind, dt ); -#endif - - dtime_save = DBL_MAX; - - for ( r = 0; r < n_repeats; ++r ) - { - bli_copym( &c_save, &c ); - - dtime = bli_clock(); - - -#ifdef PRINT - bli_printm( "a", &a, "%4.1f", "" ); - bli_printm( "b", &b, "%4.1f", "" ); - bli_printm( "c", &c, "%4.1f", "" ); -#endif - -#ifdef BLIS - - bli_gemm( &alpha, - &a, - &b, - &beta, - &c ); - -#else - - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); - - sgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); - - dgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = 
bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); - - cgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldb = bli_obj_col_stride( &b ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); - - zgemm_( &f77_transa, - //zgemm3m_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, &lda, - bp, &ldb, - betap, - cp, &ldc ); - } -#endif - -#ifdef PRINT - bli_printm( "c after", &c, "%4.1f", "" ); - exit(1); -#endif - - - dtime_save = bli_clock_min_diff( dtime_save, dtime ); - } - - gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); - - if ( bli_is_complex( dt ) ) gflops *= 4.0; - -#ifdef BLIS - printf( "data_%s_%cgemm_%s_blis", THR_STR, dt_ch, STR ); -#else - printf( "data_%s_%cgemm_%s", THR_STR, dt_ch, STR ); -#endif - printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin + 1)/p_inc + 1, - ( unsigned long )m, - ( unsigned long )k, - ( unsigned long )n, gflops ); - - bli_obj_free( &alpha ); - bli_obj_free( &beta ); - - bli_obj_free( &a ); - bli_obj_free( &b ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); - } - - //bli_finalize(); - - return 0; -} - diff --git a/test/Makefile b/test/Makefile index 53d1f9803..732ef0dd0 100644 --- a/test/Makefile +++ b/test/Makefile @@ -150,7 +150,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) diff --git a/test/exec_sizes/Makefile b/test/exec_sizes/Makefile index ca8486353..eefc89918 100644 --- a/test/exec_sizes/Makefile +++ b/test/exec_sizes/Makefile @@ -143,7 +143,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) diff --git a/test/mixeddt/Makefile b/test/mixeddt/Makefile index 7ae4cb934..87568825a 100644 --- a/test/mixeddt/Makefile +++ b/test/mixeddt/Makefile @@ -127,7 +127,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Which library? diff --git a/test/studies/skx/Makefile b/test/studies/skx/Makefile index 29134a4ff..83c29f876 100644 --- a/test/studies/skx/Makefile +++ b/test/studies/skx/Makefile @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -168,7 +169,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -g -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Datatype diff --git a/test/studies/skx/test_syrk.c b/test/studies/skx/test_syrk.c index 5e1c43159..074d1e708 100644 --- a/test/studies/skx/test_syrk.c +++ b/test/studies/skx/test_syrk.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/test/studies/skx/test_trmm.c b/test/studies/skx/test_trmm.c index 1c7db7956..85cdee37f 100644 --- a/test/studies/skx/test_trmm.c +++ b/test/studies/skx/test_trmm.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/test/studies/thunderx2/Makefile b/test/studies/thunderx2/Makefile index c812161d4..ba45ebbe4 100644 --- a/test/studies/thunderx2/Makefile +++ b/test/studies/thunderx2/Makefile @@ -158,7 +158,7 @@ CFLAGS := $(call get-frame-cflags-for,$(CONFIG_NAME)) CFLAGS += -g -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -lIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Datatype diff --git a/test/studies/thunderx2/test_trmm.c b/test/studies/thunderx2/test_trmm.c index 0fb153444..2f2c12386 100644 --- a/test/studies/thunderx2/test_trmm.c +++ b/test/studies/thunderx2/test_trmm.c @@ -4,7 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/test/sup/Makefile b/test/sup/Makefile new file mode 100644 index 000000000..d2b3c7170 --- /dev/null +++ b/test/sup/Makefile @@ -0,0 +1,460 @@ +#!/bin/bash +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2019, Advanced Micro Devices, Inc. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +# +# Makefile +# +# Field G. Van Zee +# +# Makefile for standalone BLIS test drivers. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all all-st all-mt \ + blis blis-st blis-mt \ + clean cleanx + + + +# +# --- Determine makefile fragment location ------------------------------------- +# + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + + + +# +# --- BLAS and LAPACK implementations ------------------------------------------ +# + +# BLIS library and header path. This is simply wherever it was installed. +#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib +#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis + +# BLIS library. +#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a + +# BLAS library path(s). This is where the BLAS libraries reside. +HOME_LIB_PATH := $(HOME)/flame/lib +MKL_LIB_PATH := $(HOME)/intel/mkl/lib/intel64 + +# OpenBLAS +OPENBLAS_LIB := $(HOME_LIB_PATH)/libopenblas.a +OPENBLASP_LIB := $(HOME_LIB_PATH)/libopenblasp.a + +# BLASFEO +BLASFEO_LIB := $(HOME_LIB_PATH)/libblasfeo.a + +# ATLAS +ATLAS_LIB := $(HOME_LIB_PATH)/libf77blas.a \ + $(HOME_LIB_PATH)/libatlas.a + +# Eigen +EIGEN_INC := $(HOME)/flame/eigen/include/eigen3 +EIGEN_LIB := $(HOME_LIB_PATH)/libeigen_blas_static.a +EIGENP_LIB := $(EIGEN_LIB) + +# MKL +MKL_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_sequential \ + -lpthread -lm -ldl +MKLP_LIB := -L$(MKL_LIB_PATH) \ + -lmkl_intel_lp64 \ + -lmkl_core \ + -lmkl_gnu_thread \ + -lpthread -lm -ldl -fopenmp + #-L$(ICC_LIB_PATH) \ + #-lgomp + +VENDOR_LIB := $(MKL_LIB) +VENDORP_LIB := $(MKLP_LIB) + + +# +# --- Problem size definitions ------------------------------------------------- +# + +# Single core +PS_BEGIN := 4 +PS_MAX := 800 +PS_INC := 4 + +# Multicore +P1_BEGIN := 120 +P1_MAX := 6000 +P1_INC := 120 + + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. 
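# (The patsubst/wildcard idiom below maps each ./foo.c to ./foo.o; the
# surrounding sort also deduplicates the wildcard expansion.)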
+TEST_OBJS := $(sort $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c))) + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-frame-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) + +# Use the "framework" CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS. +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# Define a set of CFLAGS for use with C++ and Eigen. +CXXFLAGS := $(subst -std=c99,-std=c++11,$(CFLAGS)) +CXXFLAGS += -I$(EIGEN_INC) + +# Create a copy of CXXFLAGS without -fopenmp in order to disable multithreading. +CXXFLAGS_ST := -march=native $(subst -fopenmp,,$(CXXFLAGS)) +CXXFLAGS_MT := -march=native $(CXXFLAGS) + +# Single or multithreaded string +STR_ST := -DTHR_STR=\"st\" +STR_MT := -DTHR_STR=\"mt\" + +# Number of trials per problem size. +N_TRIALS := -DN_TRIALS=3 + +# Problem size specification +PDEF_ST := -DP_BEGIN=$(PS_BEGIN) \ + -DP_MAX=$(PS_MAX) \ + -DP_INC=$(PS_INC) + +PDEF_MT := -DP_BEGIN=$(P1_BEGIN) \ + -DP_MAX=$(P1_MAX) \ + -DP_INC=$(P1_INC) + +ifeq ($(E),1) +ERRCHK := -DERROR_CHECK +else +ERRCHK := -DNO_ERROR_CHECK +endif + +# Enumerate possible datatypes and computation precisions. +#dts := s d c z +DTS := d + +TRANS := n_n \ + n_t \ + t_n \ + t_t + +STORS := r_r_r \ + r_r_c \ + r_c_r \ + r_c_c \ + c_r_r \ + c_r_c \ + c_c_r \ + c_c_c + +SHAPES := l_l_s \ + l_s_l \ + s_l_l \ + s_s_l \ + s_l_s \ + l_s_s \ + l_l_l + +SMS := 6 +SNS := 8 +SKS := 4 + + +# +# --- Function definitions ----------------------------------------------------- +# + +# A function to strip the underscores from a list of strings. +stripu = $(subst _,,$(1)) + +# Various functions that help us construct the datatype combinations and then +# extract the needed datatype strings and C preprocessor define flags. +get-1of2 = $(word 1,$(subst _, ,$(1))) +get-2of2 = $(word 2,$(subst _, ,$(1))) + +get-1of3 = $(word 1,$(subst _, ,$(1))) +get-2of3 = $(word 2,$(subst _, ,$(1))) +get-3of3 = $(word 3,$(subst _, ,$(1))) + +# Datatype defs. +get-dt-cpp = $(strip \ + $(if $(findstring s,$(1)),-DDT=BLIS_FLOAT -DIS_FLOAT,\ + $(if $(findstring d,$(1)),-DDT=BLIS_DOUBLE -DIS_DOUBLE,\ + $(if $(findstring c,$(1)),-DDT=BLIS_SCOMPLEX -DIS_SCOMPLEX,\ + -DDT=BLIS_DCOMPLEX -DIS_DCOMPLEX)))) + +# Transpose defs. +get-tra-defs-a = $(strip $(subst n,-DTRANSA=BLIS_NO_TRANSPOSE -DA_NOTRANS, \ + $(subst t,-DTRANSA=BLIS_TRANSPOSE -DA_TRANS,$(call get-1of2,$(1))))) +get-tra-defs-b = $(strip $(subst n,-DTRANSB=BLIS_NO_TRANSPOSE -DB_NOTRANS, \ + $(subst t,-DTRANSB=BLIS_TRANSPOSE -DB_TRANS,$(call get-2of2,$(1))))) +get-tra-defs = $(call get-tra-defs-a,$(1)) $(call get-tra-defs-b,$(1)) + +# Storage defs. +get-sto-uch-a = $(strip $(subst r,R, \ + $(subst c,C,$(call get-1of3,$(1))))) +get-sto-uch-b = $(strip $(subst r,R, \ + $(subst c,C,$(call get-2of3,$(1))))) +get-sto-uch-c = $(strip $(subst r,R, \ + $(subst c,C,$(call get-3of3,$(1))))) +get-sto-defs = $(strip \ + -DSTOR3=BLIS_$(call get-sto-uch-a,$(1))$(call get-sto-uch-b,$(1))$(call get-sto-uch-c,$(1)) \ + -DA_STOR_$(call get-sto-uch-a,$(1)) \ + -DB_STOR_$(call get-sto-uch-b,$(1)) \ + -DC_STOR_$(call get-sto-uch-c,$(1))) + +# Dimension defs. 
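# (A worked example of the helpers defined below, matching the
# commented-out $(error ...) probe a few lines down:
#   $(call get-shape-defs,l_l_s,6,8,4)
# expands to "-DM_DIM=-1 -DN_DIM=-1 -DK_DIM=4"; an 'l' (large) dimension
# becomes -1, which the drivers read as "bind this dimension to the
# problem size p", while an 's' (small) dimension becomes its fixed
# constant.)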
+get-shape-defs-cm = $(if $(findstring l,$(1)),-DM_DIM=-1,-DM_DIM=$(2)) +get-shape-defs-cn = $(if $(findstring l,$(1)),-DN_DIM=-1,-DN_DIM=$(2)) +get-shape-defs-ck = $(if $(findstring l,$(1)),-DK_DIM=-1,-DK_DIM=$(2)) +get-shape-defs-m = $(call get-shape-defs-cm,$(call get-1of3,$(1)),$(2)) +get-shape-defs-n = $(call get-shape-defs-cn,$(call get-2of3,$(1)),$(2)) +get-shape-defs-k = $(call get-shape-defs-ck,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-defs = $(strip $(call get-shape-defs-m,$(1),$(2)) \ + $(call get-shape-defs-n,$(1),$(3)) \ + $(call get-shape-defs-k,$(1),$(4))) + +#$(error l_l_s 6 8 4 = $(call get-shape-defs,l_l_s,6,8,4)) + +# Shape-dimension string. +get-shape-str-ch = $(if $(findstring l,$(1)),p,$(2)) +get-shape-str-m = $(call get-shape-str-ch,$(call get-1of3,$(1)),$(2)) +get-shape-str-n = $(call get-shape-str-ch,$(call get-2of3,$(1)),$(2)) +get-shape-str-k = $(call get-shape-str-ch,$(call get-3of3,$(1)),$(2)) + +# arguments: 1: shape (w/ underscores) 2: smallm 3: smalln 4: smallk +get-shape-dim-str = m$(call get-shape-str-m,$(1),$(2))n$(call get-shape-str-n,$(1),$(3))k$(call get-shape-str-k,$(1),$(4)) + +# Implementation defs. +# Define a function to return the appropriate -DSTR= and -D[BLIS|BLAS] flags. +get-imp-defs = $(strip $(subst blissup,-DSTR=\"$(1)\" -DBLIS -DSUP, \ + $(subst blislpab,-DSTR=\"$(1)\" -DBLIS, \ + $(subst eigen,-DSTR=\"$(1)\" -DEIGEN, \ + $(subst openblas,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst blasfeo,-DSTR=\"$(1)\" -DCBLAS, \ + $(subst vendor,-DSTR=\"$(1)\" -DCBLAS,$(1))))))))) + +TRANS0 = $(call stripu,$(TRANS)) +STORS0 = $(call stripu,$(STORS)) + +# Limit BLAS and Eigen to using either all row-stored or all column-stored matrices. +BSTORS0 = rrr ccc +ESTORS0 = rrr ccc + + +# +# --- Object and binary file definitions ---------------------------------------- +# + +get-st-objs = $(foreach dt,$(1),$(foreach tr,$(2),$(foreach st,$(3),$(foreach sh,$(4),$(foreach sm,$(5),$(foreach sn,$(6),$(foreach sk,$(7),test_$(dt)gemm_$(tr)_$(st)_$(call get-shape-dim-str,$(sh),$(sm),$(sn),$(sk))_$(8)_st.o))))))) + +# Build a list of object files and binaries for each single-threaded +# implementation using the get-st-objs() function defined above. +BLISSUP_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),blissup) +BLISSUP_ST_BINS := $(patsubst %.o,%.x,$(BLISSUP_ST_OBJS)) + +BLISLPAB_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(STORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),blislpab) +BLISLPAB_ST_BINS := $(patsubst %.o,%.x,$(BLISLPAB_ST_OBJS)) + +EIGEN_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(ESTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),eigen) +EIGEN_ST_BINS := $(patsubst %.o,%.x,$(EIGEN_ST_OBJS)) + +OPENBLAS_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),openblas) +OPENBLAS_ST_BINS := $(patsubst %.o,%.x,$(OPENBLAS_ST_OBJS)) + +BLASFEO_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),blasfeo) +BLASFEO_ST_BINS := $(patsubst %.o,%.x,$(BLASFEO_ST_OBJS)) + +VENDOR_ST_OBJS := $(call get-st-objs,$(DTS),$(TRANS0),$(BSTORS0),$(SHAPES),$(SMS),$(SNS),$(SKS),vendor) +VENDOR_ST_BINS := $(patsubst %.o,%.x,$(VENDOR_ST_OBJS)) + +#$(error "objs = $(EIGEN_ST_BINS)" ) + +# Mark the object files as intermediate so that make will remove them +# automatically after building the binaries on which they depend.
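# (For example, once `make test_dgemm_nn_rrr_mpnpkp_blissup_st.x`
# completes, GNU make deletes the matching intermediate .o itself, and a
# later invocation will not recompile it merely because the .o is gone.)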
+.INTERMEDIATE: $(BLISSUP_ST_OBJS) \ + $(BLISLPAB_ST_OBJS) \ + $(EIGEN_ST_OBJS) \ + $(OPENBLAS_ST_OBJS) \ + $(BLASFEO_ST_OBJS) \ + $(VENDOR_ST_OBJS) + + +# +# --- Targets/rules ------------------------------------------------------------ +# + +all: st + +blissup: blissup-st +blislpab: blislpab-st +eigen: eigen-st +openblas: openblas-st +blasfeo: blasfeo-st +vendor: vendor-st + +st: blissup-st blislpab-st eigen-st openblas-st blasfeo-st vendor-st +blis: blissup-st blislpab-st + +blissup-st: $(BLISSUP_ST_BINS) +blislpab-st: $(BLISLPAB_ST_BINS) +eigen-st: $(EIGEN_ST_BINS) +openblas-st: $(OPENBLAS_ST_BINS) +blasfeo-st: $(BLASFEO_ST_BINS) +vendor-st: $(VENDOR_ST_BINS) + + +# --Object file rules -- + +# Define the implementations for which we will instantiate compilation rules. +BIMPLS := blissup blislpab openblas blasfeo vendor +EIMPLS := eigen + +# 1 2 3 4 567 8 +# test_dgemm_nn_rrr_mpn6kp_blissup_st.x + +# Define the function that will be used to instantiate compilation rules +# for the various implementations. +define make-st-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_st.o: test_gemm.c Makefile + $(CC) $(CFLAGS) $(ERRCHK) $(N_TRIALS) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each BLIS/BLAS/CBLAS +# implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS), \ +$(foreach sn,$(SNS), \ +$(foreach sk,$(SKS), \ +$(foreach impl,$(BIMPLS), \ +$(eval $(call make-st-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + +# Define the function that will be used to instantiate compilation rules +# for the various implementations. +define make-eigst-rule +test_$(1)gemm_$(call stripu,$(2))_$(call stripu,$(3))_$(call get-shape-dim-str,$(4),$(5),$(6),$(7))_$(8)_st.o: test_gemm.c Makefile + $(CXX) $(CXXFLAGS_ST) $(ERRCHK) $(N_TRIALS) $(PDEF_ST) $(call get-dt-cpp,$(1)) $(call get-tra-defs,$(2)) $(call get-sto-defs,$(3)) $(call get-shape-defs,$(4),$(5),$(6),$(7)) $(call get-imp-defs,$(8)) $(STR_ST) -c $$< -o $$@ +endef + +# Instantiate the rule function make-st-rule() for each Eigen implementation. +$(foreach dt,$(DTS), \ +$(foreach tr,$(TRANS), \ +$(foreach st,$(STORS), \ +$(foreach sh,$(SHAPES), \ +$(foreach sm,$(SMS), \ +$(foreach sn,$(SNS), \ +$(foreach sk,$(SKS), \ +$(foreach impl,$(EIMPLS), \ +$(eval $(call make-eigst-rule,$(dt),$(tr),$(st),$(sh),$(sm),$(sn),$(sk),$(impl))))))))))) + + +# -- Executable file rules -- + +# NOTE: For the BLAS test drivers, we place the BLAS libraries before BLIS +# on the link command line in case BLIS was configured with the BLAS +# compatibility layer. This prevents BLIS from inadvertently getting called +# for the BLAS routines we are trying to test with. 
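# For instance (paths illustrative), a vendor (MKL) driver links roughly as
#   cc test_dgemm_nn_rrr_mpnpkp_vendor_st.o \
#      -L$(HOME)/intel/mkl/lib/intel64 -lmkl_intel_lp64 -lmkl_core \
#      -lmkl_sequential -lpthread -lm -ldl ../../lib/<config>/libblis.a \
#      -o test_dgemm_nn_rrr_mpnpkp_vendor_st.x
# so that dgemm_ resolves from MKL before libblis.a is ever searched.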
+ +test_%_blissup_st.x: test_%_blissup_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blislpab_st.x: test_%_blislpab_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_eigen_st.x: test_%_eigen_st.o $(LIBBLIS_LINK) + $(CXX) $(strip $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_openblas_st.x: test_%_openblas_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(OPENBLAS_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_blasfeo_st.x: test_%_blasfeo_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(BLASFEO_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + +test_%_vendor_st.x: test_%_vendor_st.o $(LIBBLIS_LINK) + $(CC) $(strip $< $(VENDOR_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $@) + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.x + diff --git a/test/sup/octave/gen_opsupnames.m b/test/sup/octave/gen_opsupnames.m new file mode 100644 index 000000000..a1a226a76 --- /dev/null +++ b/test/sup/octave/gen_opsupnames.m @@ -0,0 +1,37 @@ +function [ r_val1, r_val2 ] = gen_opsupnames( ops, stor, smalldims ) + +nops = size( ops, 1 ); + +smallm = smalldims( 1 ); +smalln = smalldims( 2 ); +smallk = smalldims( 3 ); + +i = 1; + +for io = 1:nops + + op = ops( io, : ); + + opsupnames( i+0, : ) = sprintf( '%s_%s_m%dnpkp', op, stor, smallm ); + opsupnames( i+1, : ) = sprintf( '%s_%s_mpn%dkp', op, stor, smalln ); + opsupnames( i+2, : ) = sprintf( '%s_%s_mpnpk%d', op, stor, smallk ); + opsupnames( i+3, : ) = sprintf( '%s_%s_mpn%dk%d', op, stor, smalln, smallk ); + opsupnames( i+4, : ) = sprintf( '%s_%s_m%dnpk%d', op, stor, smallm, smallk ); + opsupnames( i+5, : ) = sprintf( '%s_%s_m%dn%dkp', op, stor, smallm, smalln ); + opsupnames( i+6, : ) = sprintf( '%s_%s_mpnpkp', op, stor ); + + opnames( i+0, : ) = sprintf( '%s', op ); + opnames( i+1, : ) = sprintf( '%s', op ); + opnames( i+2, : ) = sprintf( '%s', op ); + opnames( i+3, : ) = sprintf( '%s', op ); + opnames( i+4, : ) = sprintf( '%s', op ); + opnames( i+5, : ) = sprintf( '%s', op ); + opnames( i+6, : ) = sprintf( '%s', op ); + + i = i + 7; +end + +r_val1 = opsupnames; +r_val2 = opnames; + +end diff --git a/test/sup/octave/plot_l3sup_perf.m b/test/sup/octave/plot_l3sup_perf.m new file mode 100644 index 000000000..327a152ab --- /dev/null +++ b/test/sup/octave/plot_l3sup_perf.m @@ -0,0 +1,262 @@ +function r_val = plot_l3sup_perf( opname, ... + data_blissup, ... + data_blislpab, ... + data_eigen, ... + data_open, ... + data_bfeo, ... + data_vend, vend_str, ... + nth, ... + rows, cols, ... + cfreq, ... + dfps, ... + theid, impl ) +%if ... %mod(theid-1,cols) == 2 || ... +% ... %mod(theid-1,cols) == 3 || ... +% ... %mod(theid-1,cols) == 4 || ... +% 0 == 1 ... %theid >= 19 +% show_plot = 0; +%else + show_plot = 1; +%end + +%legend_plot_id = 11; +legend_plot_id = 1*cols + 1*5; + +if 1 +ax1 = subplot( rows, cols, theid ); +hold( ax1, 'on' ); +end + +% Set line properties. +color_blissup = 'k'; lines_blissup = '-'; markr_blissup = ''; +color_blislpab = 'k'; lines_blislpab = ':'; markr_blislpab = ''; +color_eigen = 'm'; lines_eigen = '-.'; markr_eigen = 'o'; +color_open = 'r'; lines_open = '--'; markr_open = 'o'; +color_bfeo = 'c'; lines_bfeo = '-'; markr_bfeo = 'o'; +color_vend = 'b'; lines_vend = '-.'; markr_vend = '.'; + +% Compute the peak performance in terms of the number of double flops +% executable per cycle and the clock rate. +if opname(1) == 's' || opname(1) == 'c' + flopspercycle = dfps * 2; +else + flopspercycle = dfps; +end +max_perf_core = (flopspercycle * cfreq) * 1; + +% Escape underscores in the title. 
+title_opname = strrep( opname, '_', '\_' ); + +% Print the title to a string. +titlename = '%s'; +titlename = sprintf( titlename, title_opname ); + +% Set the legend strings. +blissup_legend = sprintf( 'BLIS sup' ); +blislpab_legend = sprintf( 'BLIS conv' ); +eigen_legend = sprintf( 'Eigen' ); +open_legend = sprintf( 'OpenBLAS' ); +bfeo_legend = sprintf( 'BLASFEO' ); +%vend_legend = sprintf( 'MKL' ); +%vend_legend = sprintf( 'ARMPL' ); +vend_legend = vend_str; + +% Set axes range values. +y_scale = 1.00; +x_begin = 0; +%x_end is set below. +y_begin = 0; +y_end = max_perf_core * y_scale; + +% Set axes names. +if nth == 1 + yaxisname = 'GFLOPS'; +else + yaxisname = 'GFLOPS/core'; +end + + +%flopscol = 4; +flopscol = size( data_blissup, 2 ); +msize = 5; +if 1 +fontsize = 11; +else +fontsize = 16; +end +linesize = 0.5; +legend_loc = 'southeast'; + +% -------------------------------------------------------------------- + +% Automatically detect a column with the increasing problem size. +% Then set the maximum x-axis value. +for psize_col = 1:3 + if data_blissup( 1, psize_col ) ~= data_blissup( 2, psize_col ) + break; + end +end +x_axis( :, 1 ) = data_blissup( :, psize_col ); + +% Compute the number of data points we have in the x-axis. Note that +% we only use half the data points for the m = n = k column of graphs. +if mod(theid-1,cols) == 6 + np = size( data_blissup, 1 ) / 2; +else + np = size( data_blissup, 1 ); +end + +% Grab the last x-axis value. +x_end = data_blissup( np, psize_col ); + +%data_peak( 1, 1:2 ) = [ 0 max_perf_core ]; +%data_peak( 2, 1:2 ) = [ x_end max_perf_core ]; + +if show_plot == 1 +blissup_ln = line( x_axis( 1:np, 1 ), data_blissup( 1:np, flopscol ) / nth, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( x_axis( 1:np, 1 ), data_blislpab( 1:np, flopscol ) / nth, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( x_axis( 1:np, 1 ), data_eigen( 1:np, flopscol ) / nth, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( x_axis( 1:np, 1 ), data_open( 1:np, flopscol ) / nth, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +bfeo_ln = line( x_axis( 1:np, 1 ), data_bfeo( 1:np, flopscol ) / nth, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +vend_ln = line( x_axis( 1:np, 1 ), data_vend( 1:np, flopscol ) / nth, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... + 'LineWidth',linesize ); +else +if theid == legend_plot_id +blissup_ln = line( nan, nan, ... + 'Color',color_blissup, 'LineStyle',lines_blissup, ... + 'LineWidth',linesize ); +blislpab_ln = line( nan, nan, ... + 'Color',color_blislpab, 'LineStyle',lines_blislpab, ... + 'LineWidth',linesize ); +eigen_ln = line( nan, nan, ... + 'Color',color_eigen, 'LineStyle',lines_eigen, ... + 'LineWidth',linesize ); +open_ln = line( nan, nan, ... + 'Color',color_open, 'LineStyle',lines_open, ... + 'LineWidth',linesize ); +bfeo_ln = line( nan, nan, ... + 'Color',color_bfeo, 'LineStyle',lines_bfeo, ... + 'LineWidth',linesize ); +vend_ln = line( nan, nan, ... + 'Color',color_vend, 'LineStyle',lines_vend, ... 
+ 'LineWidth',linesize ); +end +end + + +xlim( ax1, [x_begin x_end] ); +ylim( ax1, [y_begin y_end] ); + +if 6000 <= x_end && x_end < 10000 + x_tick2 = x_end - 2000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 4000 <= x_end && x_end < 6000 + x_tick2 = x_end - 1000; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 2000 <= x_end && x_end < 3000 + x_tick2 = x_end - 400; + x_tick1 = x_tick2/2; + xticks( ax1, [ x_tick1 x_tick2 ] ); +elseif 500 <= x_end && x_end < 1000 + x_tick3 = x_end*(3/4); + x_tick2 = x_end*(2/4); + x_tick1 = x_end*(1/4); + xticks( ax1, [ x_tick1 x_tick2 x_tick3 ] ); +end + +if show_plot == 1 || theid == legend_plot_id +if rows == 4 && cols == 7 + if nth == 1 && theid == legend_plot_id + leg = legend( ... + [ ... + blissup_ln ... + blislpab_ln ... + eigen_ln ... + open_ln ... + bfeo_ln ... + vend_ln ... + ], ... + blissup_legend, ... + blislpab_legend, ... + eigen_legend, ... + open_legend, ... + bfeo_legend, ... + vend_legend, ... + 'Location', legend_loc ); + set( leg,'Box','off' ); + set( leg,'Color','none' ); + set( leg,'Units','inches' ); + % xpos ypos + %set( leg,'Position',[11.32 6.36 1.15 0.7 ] ); % (1,4tl) + if impl == 'octave' + set( leg,'FontSize',fontsize ); + set( leg,'Position',[11.92 6.54 1.15 0.7 ] ); % (1,4tl) + else + set( leg,'FontSize',fontsize-1 ); + set( leg,'Position',[18.24 10.15 1.15 0.7 ] ); % (1,4tl) + end + elseif nth > 1 && theid == legend_plot_id + end +end +end + +set( ax1,'FontSize',fontsize ); +set( ax1,'TitleFontSizeMultiplier',1.0 ); % default is 1.1. +box( ax1, 'on' ); + +titl = title( titlename ); +set( titl, 'FontWeight', 'normal' ); % default font style is now 'bold'. + +if impl == 'octave' +tpos = get( titl, 'Position' ); % default is to align across whole figure, not box. +tpos(1) = tpos(1) + -40; +set( titl, 'Position', tpos ); % here we nudge it back to centered with box. +end + +if theid > (rows-1)*cols +%xlab = xlabel( ax1,xaxisname ); +%tpos = get( xlab, 'Position' ) +%tpos(2) = tpos(2) + 10; +%set( xlab, 'Position', tpos ); + if theid == rows*cols - 6 + xlab = xlabel( ax1, 'm = 6; n = k' ); + elseif theid == rows*cols - 5 + xlab = xlabel( ax1, 'n = 8; m = k' ); + elseif theid == rows*cols - 4 + xlab = xlabel( ax1, 'k = 4; m = n' ); + elseif theid == rows*cols - 3 + xlab = xlabel( ax1, 'm; n = 8, k = 4' ); + elseif theid == rows*cols - 2 + xlab = xlabel( ax1, 'n; m = 6, k = 4' ); + elseif theid == rows*cols - 1 + xlab = xlabel( ax1, 'k; m = 6, n = 8' ); + elseif theid == rows*cols - 0 + xlab = xlabel( ax1, 'm = n = k' ); + end +end + +if mod(theid-1,cols) == 0 +ylab = ylabel( ax1,yaxisname ); +end + +%export_fig( filename, colorflag, '-pdf', '-m2', '-painters', '-transparent' ); +%saveas( fig, filename_png ); + +%hold( ax1, 'off' ); + +r_val = 0; + +end diff --git a/test/sup/octave/plot_panel_trxsh.m b/test/sup/octave/plot_panel_trxsh.m new file mode 100644 index 000000000..e5d282bc8 --- /dev/null +++ b/test/sup/octave/plot_panel_trxsh.m @@ -0,0 +1,160 @@ +function r_val = plot_panel_trxsh ... + ( ... + cfreq, ... + dflopspercycle, ... + nth, ... + thr_str, ... + dt_ch, ... + stor_str, ... + smalldims, ... + dirpath, ... + arch_str, ... + vend_str, ... + impl ... + ) + +%cfreq = 1.8; +%dflopspercycle = 32; + +% Create filename "templates" for the files that contain the performance +% results. 
+filetemp_blissup = '%s/output_%s_%s_blissup.m'; +filetemp_blislpab = '%s/output_%s_%s_blislpab.m'; +filetemp_eigen = '%s/output_%s_%s_eigen.m'; +filetemp_open = '%s/output_%s_%s_openblas.m'; +filetemp_bfeo = '%s/output_%s_%s_blasfeo.m'; +filetemp_vend = '%s/output_%s_%s_vendor.m'; + +% Create a variable name "template" for the variables contained in the +% files outlined above. +vartemp = 'data_%s_%s_%s( :, : )'; + +% Define the datatypes and operations we will be plotting. +oproot = sprintf( '%cgemm', dt_ch ); +ops( 1, : ) = sprintf( '%s_nn', oproot ); +ops( 2, : ) = sprintf( '%s_nt', oproot ); +ops( 3, : ) = sprintf( '%s_tn', oproot ); +ops( 4, : ) = sprintf( '%s_tt', oproot ); + +% Generate datatype-specific operation names from the set of operations +% and datatypes. +[ opsupnames, opnames ] = gen_opsupnames( ops, stor_str, smalldims ); +n_opsupnames = size( opsupnames, 1 ); + +%opsupnames +%opnames +%return + +if 1 == 1 + %fig = figure('Position', [100, 100, 2400, 1500]); + fig = figure('Position', [100, 100, 2800, 1500]); + orient( fig, 'portrait' ); + set(gcf,'PaperUnits', 'inches'); + if impl == 'matlab' + set(gcf,'PaperSize', [11.5 20.4]); + set(gcf,'PaperPosition', [0 0 11.5 20.4]); + set(gcf,'PaperPositionMode','manual'); + else % impl == 'octave' % octave 4.x + set(gcf,'PaperSize', [10 17.5]); + set(gcf,'PaperPositionMode','auto'); + end + set(gcf,'PaperOrientation','landscape'); +end + + +% Iterate over the list of datatype-specific operation names. +for opi = 1:n_opsupnames +%for opi = 1:1 + + % Grab the current datatype combination. + opsupname = opsupnames( opi, : ); + opname = opnames( opi, : ); + + str = sprintf( 'Plotting %2d: %s', opi, opsupname ); disp(str); + + % Construct filenames for the data files from templates. + file_blissup = sprintf( filetemp_blissup, dirpath, thr_str, opsupname ); + file_blislpab = sprintf( filetemp_blislpab, dirpath, thr_str, opsupname ); + file_eigen = sprintf( filetemp_eigen, dirpath, thr_str, opsupname ); + file_open = sprintf( filetemp_open, dirpath, thr_str, opsupname ); + file_bfeo = sprintf( filetemp_bfeo, dirpath, thr_str, opsupname ); + file_vend = sprintf( filetemp_vend, dirpath, thr_str, opsupname ); + + % Load the data files. + %str = sprintf( ' Loading %s', file_blissup ); disp(str); + run( file_blissup ) + %str = sprintf( ' Loading %s', file_blislpab ); disp(str); + run( file_blislpab ) + %str = sprintf( ' Loading %s', file_eigen ); disp(str); + run( file_eigen ) + %str = sprintf( ' Loading %s', file_open ); disp(str); + run( file_open ) + %str = sprintf( ' Loading %s', file_open ); disp(str); + run( file_bfeo ) + %str = sprintf( ' Loading %s', file_vend ); disp(str); + run( file_vend ) + + % Construct variable names for the variables in the data files. + var_blissup = sprintf( vartemp, thr_str, opname, 'blissup' ); + var_blislpab = sprintf( vartemp, thr_str, opname, 'blislpab' ); + var_eigen = sprintf( vartemp, thr_str, opname, 'eigen' ); + var_open = sprintf( vartemp, thr_str, opname, 'openblas' ); + var_bfeo = sprintf( vartemp, thr_str, opname, 'blasfeo' ); + var_vend = sprintf( vartemp, thr_str, opname, 'vendor' ); + + % Use eval() to instantiate the variable names constructed above, + % copying each to a simplified name. + data_blissup = eval( var_blissup ); % e.g. data_st_dgemm_blissup( :, : ); + data_blislpab = eval( var_blislpab ); % e.g. data_st_dgemm_blislpab( :, : ); + data_eigen = eval( var_eigen ); % e.g. data_st_dgemm_eigen( :, : ); + data_open = eval( var_open ); % e.g. 
data_st_dgemm_openblas( :, : );
+  data_bfeo     = eval( var_bfeo );     % e.g. data_st_dgemm_blasfeo( :, : );
+  data_vend     = eval( var_vend );     % e.g. data_st_dgemm_vendor( :, : );
+
+  %str = sprintf( '  Reading %s', var_blissup ); disp(str);
+  %str = sprintf( '  Reading %s', var_blislpab ); disp(str);
+  %str = sprintf( '  Reading %s', var_eigen ); disp(str);
+  %str = sprintf( '  Reading %s', var_open ); disp(str);
+  %str = sprintf( '  Reading %s', var_bfeo ); disp(str);
+  %str = sprintf( '  Reading %s', var_vend ); disp(str);
+
+  % Plot one result in an m x n grid of plots, via the subplot()
+  % function.
+  if 1 == 1
+  plot_l3sup_perf( opsupname, ...
+                   data_blissup, ...
+                   data_blislpab, ...
+                   data_eigen, ...
+                   data_open, ...
+                   data_bfeo, ...
+                   data_vend, vend_str, ...
+                   nth, ...
+                   4, 7, ...
+                   cfreq, ...
+                   dflopspercycle, ...
+                   opi, impl );
+
+  clear data_st_*gemm_*;
+  clear data_blissup;
+  clear data_blislpab;
+  clear data_eigen;
+  clear data_open;
+  clear data_bfeo;
+  clear data_vend;
+
+  end
+
+end
+
+% Construct the name of the file to which we will output the graph.
+outfile = sprintf( 'l3sup_%s_%s_%s_nt%d.pdf', oproot, stor_str, arch_str, nth );
+
+% Output the graph to pdf format.
+%print(gcf, 'gemm_md','-fillpage','-dpdf');
+%print(gcf, outfile,'-bestfit','-dpdf');
+if impl == 'octave'
+print(gcf, outfile);
+else % if impl == 'matlab'
+print(gcf, outfile,'-bestfit','-dpdf');
+end
+
diff --git a/test/sup/octave/runme.m b/test/sup/octave/runme.m
new file mode 100644
index 000000000..a3628b28f
--- /dev/null
+++ b/test/sup/octave/runme.m
@@ -0,0 +1,8 @@
+
+% kabylake
+plot_panel_trxsh(3.8,16,1,'st','d','rrr',[ 6 8 4 ],'../results/kabylake/20190619/4_800_4_mt201','kbl','MKL','matlab'); close; clear all;
+plot_panel_trxsh(3.8,16,1,'st','d','ccc',[ 6 8 4 ],'../results/kabylake/20190619/4_800_4_mt201','kbl','MKL','matlab'); close; clear all;
+
+% epyc
+plot_panel_trxsh(3.0,8,1,'st','d','rrr',[ 6 8 4 ],'../results/epyc/20190619/4_800_4_mt256','epyc','MKL','matlab'); close; clear all;
+plot_panel_trxsh(3.0,8,1,'st','d','ccc',[ 6 8 4 ],'../results/epyc/20190619/4_800_4_mt256','epyc','MKL','matlab'); close; clear all;
diff --git a/test/sup/runme.sh b/test/sup/runme.sh
new file mode 100755
index 000000000..9646e3ccc
--- /dev/null
+++ b/test/sup/runme.sh
@@ -0,0 +1,129 @@
+#!/bin/bash
+
+# File prefixes.
+exec_root="test"
+out_root="output"
+
+# Placeholder until we add multithreading.
+nt=1
+
+# Delay between test cases.
+delay=0.02
+
+# Threadedness to test.
+threads="st"
+
+# Datatypes to test.
+#dts="d s"
+dts="d"
+
+# Operations to test.
+ops="gemm"
+
+# Transpose combinations to test.
+trans="nn nt tn tt"
+
+# Storage combinations to test.
+#stors="rrr rrc rcr rcc crr crc ccr ccc"
+stors="rrr ccc"
+
+# Problem shapes to test.
+shapes="sll lsl lls lss sls ssl lll"
+
+# FGVZ: figure out how to probe what's in the directory and
+# execute everything that's there?
+sms="6"
+sns="8"
+sks="4"
+
+# Implementations to test.
+impls="vendor blissup blislpab openblas eigen"
+#impls="vendor openblas eigen"
+#impls="blislpab blissup"
+#impls="openblas eigen vendor"
+#impls="eigen"
+#impls="blissup"
+#impls="blasfeo"
+
+# Example: test_dgemm_nn_rrc_m6npkp_blissup_st.x
+
+for th in ${threads}; do
+
+    for dt in ${dts}; do
+
+        for op in ${ops}; do
+
+            for tr in ${trans}; do
+
+                for st in ${stors}; do
+
+                    for sh in ${shapes}; do
+
+                        for sm in ${sms}; do
+
+                            for sn in ${sns}; do
+
+                                for sk in ${sks}; do
+
+                                    for im in ${impls}; do
+
+                                        # Limit execution of non-BLIS implementations to
+                                        # rrr/ccc storage cases.
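                                        # (Editor's aside: the prefix test below,
                                        # ${im:0:4} != "blis", matches both "blissup"
                                        # and "blislpab", so only the BLIS drivers are
                                        # run for the mixed-storage cases.)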
+ if [ "${im:0:4}" != "blis" ] && \ + [ "${st}" != "rrr" ] && \ + [ "${st}" != "ccc" ]; then + continue; + fi + + # Extract the shape chars for m, n, k. + chm=${sh:0:1} + chn=${sh:1:1} + chk=${sh:2:1} + + # Construct the shape substring (e.g. m6npkp) + shstr="" + + if [ ${chm} = "s" ]; then + shstr="${shstr}m${sm}" + else + shstr="${shstr}mp" + fi + + if [ ${chn} = "s" ]; then + shstr="${shstr}n${sn}" + else + shstr="${shstr}np" + fi + + if [ ${chk} = "s" ]; then + shstr="${shstr}k${sk}" + else + shstr="${shstr}kp" + fi + + # Ex: test_dgemm_nn_rrc_m6npkp_blissup_st.x + + # Construct the name of the test executable. + exec_name="${exec_root}_${dt}${op}_${tr}_${st}_${shstr}_${im}_${th}.x" + + # Construct the name of the output file. + out_file="${out_root}_${th}_${dt}${op}_${tr}_${st}_${shstr}_${im}.m" + + echo "Running (nt = ${nt}) ./${exec_name} > ${out_file}" + + # Run executable. + ./${exec_name} > ${out_file} + + sleep ${delay} + + done + done + done + done + done + done + done + done + done +done + diff --git a/test/sup/test_gemm.c b/test/sup/test_gemm.c new file mode 100644 index 000000000..311e8552a --- /dev/null +++ b/test/sup/test_gemm.c @@ -0,0 +1,566 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+*/
+
+#include <unistd.h>
+#ifdef EIGEN
+  #define BLIS_DISABLE_BLAS_DEFS
+  #include "blis.h"
+  #include <Eigen/Core>
+  //#include <Eigen/Dense>
+  using namespace Eigen;
+#else
+  #include "blis.h"
+#endif
+
+//#define PRINT
+
+int main( int argc, char** argv )
+{
+
+	bli_init();
+
+#ifndef ERROR_CHECK
+	bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
+#endif
+
+
+	dim_t n_trials = N_TRIALS;
+
+	num_t dt = DT;
+
+#if 1
+	dim_t p_begin = P_BEGIN;
+	dim_t p_max   = P_MAX;
+	dim_t p_inc   = P_INC;
+#else
+	dim_t p_begin = 4;
+	dim_t p_max   = 40;
+	dim_t p_inc   = 4;
+#endif
+
+#if 1
+	dim_t m_input = M_DIM;
+	dim_t n_input = N_DIM;
+	dim_t k_input = K_DIM;
+#else
+	p_begin = p_inc = 32;
+	dim_t m_input = 6;
+	dim_t n_input = -1;
+	dim_t k_input = -1;
+#endif
+
+#if 1
+	trans_t transa = TRANSA;
+	trans_t transb = TRANSB;
+#else
+	trans_t transa = BLIS_NO_TRANSPOSE;
+	trans_t transb = BLIS_NO_TRANSPOSE;
+#endif
+
+#if 1
+	stor3_t sc = STOR3;
+#else
+	stor3_t sc = BLIS_RRR;
+#endif
+
+
+	inc_t rs_c, cs_c;
+	inc_t rs_a, cs_a;
+	inc_t rs_b, cs_b;
+
+	if      ( sc == BLIS_RRR ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b = -1; }
+	else if ( sc == BLIS_RRC ) { rs_c = cs_c = -1; rs_a = cs_a = -1; rs_b = cs_b =  0; }
+	else if ( sc == BLIS_RCR ) { rs_c = cs_c = -1; rs_a = cs_a =  0; rs_b = cs_b = -1; }
+	else if ( sc == BLIS_RCC ) { rs_c = cs_c = -1; rs_a = cs_a =  0; rs_b = cs_b =  0; }
+	else if ( sc == BLIS_CRR ) { rs_c = cs_c =  0; rs_a = cs_a = -1; rs_b = cs_b = -1; }
+	else if ( sc == BLIS_CRC ) { rs_c = cs_c =  0; rs_a = cs_a = -1; rs_b = cs_b =  0; }
+	else if ( sc == BLIS_CCR ) { rs_c = cs_c =  0; rs_a = cs_a =  0; rs_b = cs_b = -1; }
+	else if ( sc == BLIS_CCC ) { rs_c = cs_c =  0; rs_a = cs_a =  0; rs_b = cs_b =  0; }
+	else                       { bli_abort(); }
+
+	f77_int cbla_storage;
+
+	if      ( sc == BLIS_RRR ) cbla_storage = CblasRowMajor;
+	else if ( sc == BLIS_CCC ) cbla_storage = CblasColMajor;
+	else                       cbla_storage = -1;
+
+	( void )cbla_storage;
+
+
+	char dt_ch;
+
+	// Choose the char corresponding to the requested datatype.
+	if      ( bli_is_float( dt ) )    dt_ch = 's';
+	else if ( bli_is_double( dt ) )   dt_ch = 'd';
+	else if ( bli_is_scomplex( dt ) ) dt_ch = 'c';
+	else                              dt_ch = 'z';
+
+	f77_char f77_transa;
+	f77_char f77_transb;
+	char     transal, transbl;
+
+	bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
+	bli_param_map_blis_to_netlib_trans( transb, &f77_transb );
+
+	transal = tolower( f77_transa );
+	transbl = tolower( f77_transb );
+
+	f77_int cbla_transa = ( transal == 'n' ? CblasNoTrans : CblasTrans );
+	f77_int cbla_transb = ( transbl == 'n' ? CblasNoTrans : CblasTrans );
+
+	( void )cbla_transa;
+	( void )cbla_transb;
+
+	dim_t p;
+
+	// Begin with initializing the last entry to zero so that
+	// matlab allocates space for the entire array once up-front.
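	// (Editor's aside, not part of the commit: the empty for loop below
	// simply advances p to the last problem size in the sweep; printing a
	// zeroed row at that final index first causes matlab/octave to allocate
	// the whole data_* array in one shot instead of growing it on every
	// subsequent assignment.)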
+	for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ;
+
+	printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch,
+	        transal, transbl, STR );
+	printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n",
+	        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
+	        ( unsigned long )0,
+	        ( unsigned long )0,
+	        ( unsigned long )0, 0.0 );
+
+
+	for ( p = p_begin; p <= p_max; p += p_inc )
+	{
+		obj_t a, b, c;
+		obj_t c_save;
+		obj_t alpha, beta;
+		dim_t m, n, k;
+
+		if ( m_input < 0 ) m = p / ( dim_t )abs(m_input);
+		else               m = ( dim_t )    m_input;
+		if ( n_input < 0 ) n = p / ( dim_t )abs(n_input);
+		else               n = ( dim_t )    n_input;
+		if ( k_input < 0 ) k = p / ( dim_t )abs(k_input);
+		else               k = ( dim_t )    k_input;
+
+		bli_obj_create( dt, 1, 1, 0, 0, &alpha );
+		bli_obj_create( dt, 1, 1, 0, 0, &beta );
+
+		bli_obj_create( dt, m, n, rs_c, cs_c, &c );
+		bli_obj_create( dt, m, n, rs_c, cs_c, &c_save );
+
+		if ( bli_does_notrans( transa ) )
+			bli_obj_create( dt, m, k, rs_a, cs_a, &a );
+		else
+			bli_obj_create( dt, k, m, rs_a, cs_a, &a );
+
+		if ( bli_does_notrans( transb ) )
+			bli_obj_create( dt, k, n, rs_b, cs_b, &b );
+		else
+			bli_obj_create( dt, n, k, rs_b, cs_b, &b );
+
+		bli_randm( &a );
+		bli_randm( &b );
+		bli_randm( &c );
+
+		bli_obj_set_conjtrans( transa, &a );
+		bli_obj_set_conjtrans( transb, &b );
+
+		bli_setsc( (2.0/1.0), 0.0, &alpha );
+		bli_setsc( (1.0/1.0), 0.0, &beta );
+
+		bli_copym( &c, &c_save );
+
+#ifdef EIGEN
+		double alpha_r, alpha_i;
+
+		bli_getsc( &alpha, &alpha_r, &alpha_i );
+
+		void* ap = bli_obj_buffer_at_off( &a );
+		void* bp = bli_obj_buffer_at_off( &b );
+		void* cp = bli_obj_buffer_at_off( &c );
+
+		const int os_a = ( bli_obj_is_col_stored( &a ) ? bli_obj_col_stride( &a )
+		                                               : bli_obj_row_stride( &a ) );
+		const int os_b = ( bli_obj_is_col_stored( &b ) ? bli_obj_col_stride( &b )
+		                                               : bli_obj_row_stride( &b ) );
+		const int os_c = ( bli_obj_is_col_stored( &c ) ? bli_obj_col_stride( &c )
+		                                               : bli_obj_row_stride( &c ) );
+
+		Stride<Dynamic,1> stride_a( os_a, 1 );
+		Stride<Dynamic,1> stride_b( os_b, 1 );
+		Stride<Dynamic,1> stride_c( os_c, 1 );
+
+		#if defined(IS_FLOAT)
+		#elif defined (IS_DOUBLE)
+		#ifdef A_STOR_R
+		typedef Matrix<double, Dynamic, Dynamic, RowMajor> MatrixXd_A;
+		#else
+		typedef Matrix<double, Dynamic, Dynamic, ColMajor> MatrixXd_A;
+		#endif
+		#ifdef B_STOR_R
+		typedef Matrix<double, Dynamic, Dynamic, RowMajor> MatrixXd_B;
+		#else
+		typedef Matrix<double, Dynamic, Dynamic, ColMajor> MatrixXd_B;
+		#endif
+		#ifdef C_STOR_R
+		typedef Matrix<double, Dynamic, Dynamic, RowMajor> MatrixXd_C;
+		#else
+		typedef Matrix<double, Dynamic, Dynamic, ColMajor> MatrixXd_C;
+		#endif
+
+		#ifdef A_NOTRANS // A is not transposed
+		Map<MatrixXd_A, 0, Stride<Dynamic,1> > A( ( double* )ap, m, k, stride_a );
+		#else // A is transposed
+		Map<MatrixXd_A, 0, Stride<Dynamic,1> > A( ( double* )ap, k, m, stride_a );
+		#endif
+
+		#ifdef B_NOTRANS // B is not transposed
+		Map<MatrixXd_B, 0, Stride<Dynamic,1> > B( ( double* )bp, k, n, stride_b );
+		#else // B is transposed
+		Map<MatrixXd_B, 0, Stride<Dynamic,1> > B( ( double* )bp, n, k, stride_b );
+		#endif
+
+		Map<MatrixXd_C, 0, Stride<Dynamic,1> > C( ( double* )cp, m, n, stride_c );
+		#endif
+#endif
+
+
+		double dtime_save = DBL_MAX;
+
+		for ( dim_t r = 0; r < n_trials; ++r )
+		{
+			bli_copym( &c_save, &c );
+
+
+			double dtime = bli_clock();
+
+
+#ifdef EIGEN
+
+	#ifdef A_NOTRANS
+	#ifdef B_NOTRANS
+			C.noalias() += alpha_r * A * B;
+	#else // B_TRANS
+			C.noalias() += alpha_r * A * B.transpose();
+	#endif
+	#else // A_TRANS
+	#ifdef B_NOTRANS
+			C.noalias() += alpha_r * A.transpose() * B;
+	#else // B_TRANS
+			C.noalias() += alpha_r * A.transpose() * B.transpose();
+	#endif
+	#endif
+
+#endif
+#ifdef BLIS
+	#ifdef SUP
+			// Allow sup.
+			bli_gemm( &alpha,
+			          &a,
+			          &b,
+			          &beta,
+			          &c );
+	#else
+			// Disable sup and use the expert interface.
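			// (Editor's aside, not part of the commit: bli_gemm_ex() is the
			// expert-interface counterpart of bli_gemm(); seeding a local
			// rntm_t and disabling its l3 sup switch, as below, forces the
			// conventional packed gemm path, which is what the "blislpab"
			// drivers are intended to measure.)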
+ rntm_t rntm = BLIS_RNTM_INITIALIZER; + bli_rntm_disable_l3_sup( &rntm ); + + bli_gemm_ex( &alpha, + &a, + &b, + &beta, + &c, NULL, &rntm ); + #endif +#endif +#ifdef BLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = ( float* )bli_obj_buffer( &alpha ); + float* ap = ( float* )bli_obj_buffer( &a ); + float* bp = ( float* )bli_obj_buffer( &b ); + float* betap = ( float* )bli_obj_buffer( &beta ); + float* cp = ( float* )bli_obj_buffer( &c ); + + sgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = ( double* )bli_obj_buffer( &alpha ); + double* ap = ( double* )bli_obj_buffer( &a ); + double* bp = ( double* )bli_obj_buffer( &b ); + double* betap = ( double* )bli_obj_buffer( &beta ); + double* cp = ( double* )bli_obj_buffer( &c ); + + dgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha ); + scomplex* ap = ( scomplex* )bli_obj_buffer( &a ); + scomplex* bp = ( scomplex* )bli_obj_buffer( &b ); + scomplex* betap = ( scomplex* )bli_obj_buffer( &beta ); + scomplex* cp = ( scomplex* )bli_obj_buffer( &c ); + + cgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha ); + dcomplex* ap = ( dcomplex* )bli_obj_buffer( &a ); + dcomplex* bp = ( dcomplex* )bli_obj_buffer( &b ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( &beta ); + dcomplex* cp = ( dcomplex* )bli_obj_buffer( &c ); + + zgemm_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + } +#endif +#ifdef CBLAS + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + cblas_sgemm( cbla_storage, + cbla_transa, + 
cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + cblas_dgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cblas_cgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + #ifdef C_STOR_R + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + #else + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + #endif + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + cblas_zgemm( cbla_storage, + cbla_transa, + cbla_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + double gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%s_%cgemm_%c%c_%s", THR_STR, dt_ch, + transal, transbl, STR ); + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin + 1)/p_inc + 1, + ( unsigned long )m, + ( unsigned long )n, + ( unsigned long )k, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + + return 0; +} + diff --git a/test/test_gemm.c b/test/test_gemm.c index 2d650260e..0ab9b4c1b 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -147,6 +147,10 @@ int main( int argc, char** argv ) bli_obj_create( dt, m, n, 1, cs_c, &c ); bli_obj_create( dt, m, n, 1, cs_c, &c_save ); + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + bli_obj_set_conjtrans( transa, &a); bli_obj_set_conjtrans( transb, &b); diff --git a/test/thread_ranges/Makefile 
b/test/thread_ranges/Makefile index 2ed155be1..5af2ce533 100644 --- a/test/thread_ranges/Makefile +++ b/test/thread_ranges/Makefile @@ -104,7 +104,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Datatype diff --git a/testsuite/Makefile b/testsuite/Makefile index 1e97cdcf4..57c1c748d 100644 --- a/testsuite/Makefile +++ b/testsuite/Makefile @@ -103,7 +103,7 @@ CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) CFLAGS += -I$(TEST_SRC_PATH) # Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # Binary executable name. TESTSUITE_BIN := test_libblis.x diff --git a/testsuite/src/test_addm.c b/testsuite/src/test_addm.c index c957e4dc9..f7c21b733 100644 --- a/testsuite/src/test_addm.c +++ b/testsuite/src/test_addm.c @@ -275,7 +275,7 @@ void libblis_test_addm_check // // is functioning correctly if // - // normfv(y) - sqrt( absqsc( beta + conjx(alpha) ) * m * n ) + // normfm(y) - sqrt( absqsc( beta + conjx(alpha) ) * m * n ) // // is negligible. // diff --git a/testsuite/src/test_amaxv.c b/testsuite/src/test_amaxv.c index 3d58a34ed..ed2a821fe 100644 --- a/testsuite/src/test_amaxv.c +++ b/testsuite/src/test_amaxv.c @@ -351,11 +351,18 @@ void PASTEMAC0(opname) \ \ void* buf_index = bli_obj_buffer_at_off( index ); \ \ +/* + FGVZ: Disabling this code since bli_amaxv_check() is supposed to be a + non-public API function, and therefore unavailable unless all symbols + are scheduled to be exported at configure-time (which is not currently + the default behavior). + if ( bli_error_checking_is_enabled() ) \ bli_amaxv_check( x, index ); \ +*/ \ \ /* Query a type-specific function pointer, except one that uses - void* instead of typed pointers. */ \ + void* for function arguments instead of typed pointers. */ \ PASTECH(tname,_vft) f = \ PASTEMAC(opname,_qfp)( dt ); \ \ diff --git a/testsuite/src/test_axpbyv.c b/testsuite/src/test_axpbyv.c index 2cfdd416c..a82ff6e25 100644 --- a/testsuite/src/test_axpbyv.c +++ b/testsuite/src/test_axpbyv.c @@ -296,7 +296,7 @@ void libblis_test_axpbyv_check // // is functioning correctly if // - // normf( y - ( beta * y_orig + alpha * conjx(x) ) ) + // normfv( y - ( beta * y_orig + alpha * conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_axpy2v.c b/testsuite/src/test_axpy2v.c index f310d5cb6..eeebf15e7 100644 --- a/testsuite/src/test_axpy2v.c +++ b/testsuite/src/test_axpy2v.c @@ -314,7 +314,7 @@ void libblis_test_axpy2v_check // // is functioning correctly if // - // normf( z - v ) + // normfv( z - v ) // // is negligible, where v contains z as computed by two calls to axpyv. // diff --git a/testsuite/src/test_axpyf.c b/testsuite/src/test_axpyf.c index 7572b3a48..7a85b2212 100644 --- a/testsuite/src/test_axpyf.c +++ b/testsuite/src/test_axpyf.c @@ -319,7 +319,7 @@ void libblis_test_axpyf_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where v contains y as computed by repeated calls to // axpyv. 
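[Editor's note: the comment edits in the surrounding testsuite hunks replace the
generic "normf" with the operand-specific names "normfv" (vector operands) and
"normfm" (matrix operands), matching the object API the checks actually call. A
minimal sketch of the two flavors, with arbitrary dimensions, might look like:

    obj_t x, a, norm;
    bli_obj_create( BLIS_DOUBLE, 10,  1, 0, 0, &x );  // a 10x1 vector
    bli_obj_create( BLIS_DOUBLE, 10, 10, 0, 0, &a );  // a 10x10 matrix
    bli_obj_create_1x1( BLIS_DOUBLE, &norm );
    bli_normfv( &x, &norm );  // Frobenius norm of a vector
    bli_normfm( &a, &norm );  // Frobenius norm of a matrix
    bli_obj_free( &x ); bli_obj_free( &a ); bli_obj_free( &norm );
]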
diff --git a/testsuite/src/test_axpym.c b/testsuite/src/test_axpym.c index 9097043a4..222fda33d 100644 --- a/testsuite/src/test_axpym.c +++ b/testsuite/src/test_axpym.c @@ -289,7 +289,7 @@ void libblis_test_axpym_check // // is functioning correctly if // - // normf( y - ( y_orig + alpha * conjx(x) ) ) + // normfm( y - ( y_orig + alpha * conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_axpyv.c b/testsuite/src/test_axpyv.c index 5f3f991ef..81d4f3770 100644 --- a/testsuite/src/test_axpyv.c +++ b/testsuite/src/test_axpyv.c @@ -286,7 +286,7 @@ void libblis_test_axpyv_check // // is functioning correctly if // - // normf( y - ( y_orig + alpha * conjx(x) ) ) + // normfv( y - ( y_orig + alpha * conjx(x) ) ) // // is negligible. // diff --git a/testsuite/src/test_dotaxpyv.c b/testsuite/src/test_dotaxpyv.c index 28d6c6916..391c119bb 100644 --- a/testsuite/src/test_dotaxpyv.c +++ b/testsuite/src/test_dotaxpyv.c @@ -345,7 +345,7 @@ void libblis_test_dotaxpyv_check // // and // - // normf( z - z_temp ) + // normfv( z - z_temp ) // // are negligible, where rho_temp and z_temp contain rho and z as // computed by dotv and axpyv, respectively. diff --git a/testsuite/src/test_dotv.c b/testsuite/src/test_dotv.c index bc4ad54f9..347ce9e62 100644 --- a/testsuite/src/test_dotv.c +++ b/testsuite/src/test_dotv.c @@ -278,7 +278,7 @@ void libblis_test_dotv_check // // is functioning correctly if // - // sqrtsc( rho.real ) - normf( x ) + // sqrtsc( rho.real ) - normfv( x ) // // and // diff --git a/testsuite/src/test_dotxaxpyf.c b/testsuite/src/test_dotxaxpyf.c index ca57ca39b..c73ab6c9d 100644 --- a/testsuite/src/test_dotxaxpyf.c +++ b/testsuite/src/test_dotxaxpyf.c @@ -366,11 +366,11 @@ void libblis_test_dotxaxpyf_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // and // - // normf( z - q ) + // normfv( z - q ) // // are negligible, where v and q contain y and z as computed by repeated // calls to dotxv and axpyv, respectively. diff --git a/testsuite/src/test_dotxf.c b/testsuite/src/test_dotxf.c index eefe2bb77..8a1eca4eb 100644 --- a/testsuite/src/test_dotxf.c +++ b/testsuite/src/test_dotxf.c @@ -324,7 +324,7 @@ void libblis_test_dotxf_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where v contains y as computed by repeated calls to // dotxv. diff --git a/testsuite/src/test_dotxv.c b/testsuite/src/test_dotxv.c index fb677a06d..da42e6ae4 100644 --- a/testsuite/src/test_dotxv.c +++ b/testsuite/src/test_dotxv.c @@ -304,7 +304,7 @@ void libblis_test_dotxv_check // // is functioning correctly if // - // sqrtsc( rho.real ) - sqrtsc( alpha ) * normf( x ) + // sqrtsc( rho.real ) - sqrtsc( alpha ) * normfv( x ) // // and // diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index 2ec117464..e941946e1 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -236,11 +236,11 @@ void libblis_test_gemm_experiment libblis_test_mobj_create( params, datatype, transa, sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], k, n, &b ); + sc_str[2], k, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. 
if ( bli_obj_is_real( &c ) ) @@ -259,6 +259,9 @@ void libblis_test_gemm_experiment libblis_test_mobj_randomize( params, TRUE, &b ); libblis_test_mobj_randomize( params, TRUE, &c ); bli_copym( &c, &c_save ); +//bli_setm( &BLIS_ONE, &a ); +//bli_setsc( 1.0, 0.0, &alpha ); +//bli_setsc( 0.0, 0.0, &beta ); // Apply the parameters. bli_obj_set_conjtrans( transa, &a ); @@ -349,13 +352,13 @@ void libblis_test_gemm_md // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, dt_a, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, dt_b, transb, - sc_str[1], k, n, &b ); + sc_str[2], k, n, &b ); libblis_test_mobj_create( params, dt_c, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, dt_c, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // For mixed-precision, set the computation precision of C. if ( params->mixed_precision ) @@ -449,15 +452,20 @@ void libblis_test_gemm_impl { case BLIS_TEST_SEQ_FRONT_END: #if 0 +//bli_printm( "alpha", alpha, "%5.2f", "" ); +//bli_printm( "beta", beta, "%5.2f", "" ); bli_printm( "a", a, "%5.2f", "" ); bli_printm( "b", b, "%5.2f", "" ); bli_printm( "c", c, "%5.2f", "" ); -bli_printm( "alpha", alpha, "%5.2f", "" ); -bli_printm( "beta", beta, "%5.2f", "" ); #endif +//if ( bli_obj_length( b ) == 16 && +// bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) +//bli_printm( "c before", c, "%6.3f", "" ); bli_gemm( alpha, a, b, beta, c ); #if 0 -bli_printm( "c after", c, "%5.2f", "" ); +if ( bli_obj_length( c ) == 12 && + bli_obj_stor3_from_strides( c, a, b ) == BLIS_RRR ) +bli_printm( "c after", c, "%6.3f", "" ); #endif break; @@ -617,7 +625,7 @@ void libblis_test_gemm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_gemm_ukr.c b/testsuite/src/test_gemm_ukr.c index 8de3f144b..66e84d644 100644 --- a/testsuite/src/test_gemm_ukr.c +++ b/testsuite/src/test_gemm_ukr.c @@ -390,7 +390,7 @@ void libblis_test_gemm_ukr_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_gemmtrsm_ukr.c b/testsuite/src/test_gemmtrsm_ukr.c index 58da5410a..e2cf10ab3 100644 --- a/testsuite/src/test_gemmtrsm_ukr.c +++ b/testsuite/src/test_gemmtrsm_ukr.c @@ -465,7 +465,7 @@ void libblis_test_gemmtrsm_ukr_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_gemv.c b/testsuite/src/test_gemv.c index 7df4c3fc8..e6090e1c5 100644 --- a/testsuite/src/test_gemv.c +++ b/testsuite/src/test_gemv.c @@ -324,7 +324,7 @@ void libblis_test_gemv_check // // is functioning correctly if // - // normf( y - z ) + // normfv( y - z ) // // is negligible, where // diff --git a/testsuite/src/test_ger.c b/testsuite/src/test_ger.c index 961247e84..b44fe6ba6 100644 --- a/testsuite/src/test_ger.c +++ b/testsuite/src/test_ger.c @@ -303,7 +303,7 @@ void libblis_test_ger_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_hemm.c b/testsuite/src/test_hemm.c index 8768af8c7..0145dd0df 100644 --- a/testsuite/src/test_hemm.c +++ b/testsuite/src/test_hemm.c @@ -202,13 +202,13 @@ void libblis_test_hemm_experiment // Create test operands (vectors and/or matrices). 
bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, n, &b ); + sc_str[2], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -338,7 +338,7 @@ void libblis_test_hemm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_hemv.c b/testsuite/src/test_hemv.c index 5aca22239..02e205392 100644 --- a/testsuite/src/test_hemv.c +++ b/testsuite/src/test_hemv.c @@ -322,7 +322,7 @@ void libblis_test_hemv_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where // diff --git a/testsuite/src/test_her.c b/testsuite/src/test_her.c index 855679cf8..c122f6ce5 100644 --- a/testsuite/src/test_her.c +++ b/testsuite/src/test_her.c @@ -301,7 +301,7 @@ void libblis_test_her_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_her2.c b/testsuite/src/test_her2.c index ee35cc93f..1ed6b3bb9 100644 --- a/testsuite/src/test_her2.c +++ b/testsuite/src/test_her2.c @@ -311,7 +311,7 @@ void libblis_test_her2_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_her2k.c b/testsuite/src/test_her2k.c index c5cf5dbeb..0158e25a2 100644 --- a/testsuite/src/test_her2k.c +++ b/testsuite/src/test_her2k.c @@ -195,13 +195,13 @@ void libblis_test_her2k_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, k, &b ); + sc_str[2], m, k, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -336,7 +336,7 @@ void libblis_test_her2k_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_herk.c b/testsuite/src/test_herk.c index af192f06b..abe4e70b1 100644 --- a/testsuite/src/test_herk.c +++ b/testsuite/src/test_herk.c @@ -192,11 +192,11 @@ void libblis_test_herk_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. 
if ( bli_obj_is_real( &c ) ) @@ -323,7 +323,7 @@ void libblis_test_herk_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 780c4b9e6..5c2dd2ed9 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -130,22 +130,22 @@ void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_pthread_t* pthread = bli_malloc_intl( sizeof( bli_pthread_t ) * nt ); + bli_pthread_t* pthread = bli_malloc_user( sizeof( bli_pthread_t ) * nt ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - thread_data_t* tdata = bli_malloc_intl( sizeof( thread_data_t ) * nt ); + thread_data_t* tdata = bli_malloc_user( sizeof( thread_data_t ) * nt ); // Allocate a mutex for the threads to share. - //bli_pthread_mutex_t* mutex = bli_malloc_intl( sizeof( bli_pthread_mutex_t ) ); + //bli_pthread_mutex_t* mutex = bli_malloc_user( sizeof( bli_pthread_mutex_t ) ); // Allocate a barrier for the threads to share. #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_pthread_barrier_t* barrier = bli_malloc_intl( sizeof( bli_pthread_barrier_t ) ); + bli_pthread_barrier_t* barrier = bli_malloc_user( sizeof( bli_pthread_barrier_t ) ); // Initialize the mutex. //bli_pthread_mutex_init( mutex, NULL ); @@ -191,18 +191,18 @@ void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_free_intl( pthread ); + bli_free_user( pthread ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_free_intl( tdata ); + bli_free_user( tdata ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - //bli_free_intl( mutex ); - bli_free_intl( barrier ); + //bli_free_user( mutex ); + bli_free_user( barrier ); } diff --git a/testsuite/src/test_normfm.c b/testsuite/src/test_normfm.c index 8cee0332e..c4b9a0105 100644 --- a/testsuite/src/test_normfm.c +++ b/testsuite/src/test_normfm.c @@ -259,7 +259,7 @@ void libblis_test_normfm_check // // Under these conditions, we assume that the implementation for // - // norm := normf( x ) + // norm := normfm( x ) // // is functioning correctly if // diff --git a/testsuite/src/test_normfv.c b/testsuite/src/test_normfv.c index 8a473f5fd..3bcce35af 100644 --- a/testsuite/src/test_normfv.c +++ b/testsuite/src/test_normfv.c @@ -256,7 +256,7 @@ void libblis_test_normfv_check // // Under these conditions, we assume that the implementation for // - // norm := normf( x ) + // norm := normfv( x ) // // is functioning correctly if // diff --git a/testsuite/src/test_scal2m.c b/testsuite/src/test_scal2m.c index 3dcff3d78..e8440fc46 100644 --- a/testsuite/src/test_scal2m.c +++ b/testsuite/src/test_scal2m.c @@ -288,7 +288,7 @@ void libblis_test_scal2m_check // // is functioning correctly if // - // normf( y - alpha * conjx(x) ) + // normfm( y - alpha * conjx(x) ) // // is negligible. 
// diff --git a/testsuite/src/test_scal2v.c b/testsuite/src/test_scal2v.c index 94af24502..c200e13fc 100644 --- a/testsuite/src/test_scal2v.c +++ b/testsuite/src/test_scal2v.c @@ -285,7 +285,7 @@ void libblis_test_scal2v_check // // is functioning correctly if // - // normf( y - alpha * conjx(x) ) + // normfv( y - alpha * conjx(x) ) // // is negligible. // diff --git a/testsuite/src/test_scalm.c b/testsuite/src/test_scalm.c index b3f2066e0..6219c71df 100644 --- a/testsuite/src/test_scalm.c +++ b/testsuite/src/test_scalm.c @@ -280,7 +280,7 @@ void libblis_test_scalm_check // // is functioning correctly if // - // normf( y + -conjbeta(beta) * y_orig ) + // normfm( y + -conjbeta(beta) * y_orig ) // // is negligible. // diff --git a/testsuite/src/test_scalv.c b/testsuite/src/test_scalv.c index fefb23b4a..142b5e410 100644 --- a/testsuite/src/test_scalv.c +++ b/testsuite/src/test_scalv.c @@ -276,7 +276,7 @@ void libblis_test_scalv_check // // is functioning correctly if // - // normf( y + -conjbeta(beta) * y_orig ) + // normfv( y + -conjbeta(beta) * y_orig ) // // is negligible. // diff --git a/testsuite/src/test_subm.c b/testsuite/src/test_subm.c index b2de8cfad..63b48eedc 100644 --- a/testsuite/src/test_subm.c +++ b/testsuite/src/test_subm.c @@ -275,7 +275,7 @@ void libblis_test_subm_check // // is functioning correctly if // - // normfv(y) - sqrt( absqsc( beta - conjx(alpha) ) * m * n ) + // normfm(y) - sqrt( absqsc( beta - conjx(alpha) ) * m * n ) // // is negligible. // diff --git a/testsuite/src/test_symm.c b/testsuite/src/test_symm.c index ded4bb143..2ac7b4106 100644 --- a/testsuite/src/test_symm.c +++ b/testsuite/src/test_symm.c @@ -202,13 +202,13 @@ void libblis_test_symm_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, n, &b ); + sc_str[2], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. 
if ( bli_obj_is_real( &c ) ) @@ -338,7 +338,7 @@ void libblis_test_symm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_symv.c b/testsuite/src/test_symv.c index c99941ea0..5ae5f30be 100644 --- a/testsuite/src/test_symv.c +++ b/testsuite/src/test_symv.c @@ -322,7 +322,7 @@ void libblis_test_symv_check // // is functioning correctly if // - // normf( y - v ) + // normfv( y - v ) // // is negligible, where // diff --git a/testsuite/src/test_syr.c b/testsuite/src/test_syr.c index b3a6a356d..69376b970 100644 --- a/testsuite/src/test_syr.c +++ b/testsuite/src/test_syr.c @@ -301,7 +301,7 @@ void libblis_test_syr_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_syr2.c b/testsuite/src/test_syr2.c index 6b5d72e76..42d65c00e 100644 --- a/testsuite/src/test_syr2.c +++ b/testsuite/src/test_syr2.c @@ -313,7 +313,7 @@ void libblis_test_syr2_check // // is functioning correctly if // - // normf( v - w ) + // normfv( v - w ) // // is negligible, where // diff --git a/testsuite/src/test_syr2k.c b/testsuite/src/test_syr2k.c index 69dfbda7a..4d83bb88c 100644 --- a/testsuite/src/test_syr2k.c +++ b/testsuite/src/test_syr2k.c @@ -195,13 +195,13 @@ void libblis_test_syr2k_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, k, &b ); + sc_str[2], m, k, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -335,7 +335,7 @@ void libblis_test_syr2k_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_syrk.c b/testsuite/src/test_syrk.c index a21bb7af5..65d978bb0 100644 --- a/testsuite/src/test_syrk.c +++ b/testsuite/src/test_syrk.c @@ -192,11 +192,11 @@ void libblis_test_syrk_experiment // Create test operands (vectors and/or matrices). libblis_test_mobj_create( params, datatype, transa, - sc_str[0], m, k, &a ); + sc_str[1], m, k, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c ); + sc_str[0], m, m, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, m, &c_save ); + sc_str[0], m, m, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -324,7 +324,7 @@ void libblis_test_syrk_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trmm.c b/testsuite/src/test_trmm.c index 02505e630..a1decd37c 100644 --- a/testsuite/src/test_trmm.c +++ b/testsuite/src/test_trmm.c @@ -197,11 +197,11 @@ void libblis_test_trmm_experiment // Create test operands (vectors and/or matrices). 
bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, transa, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b ); + sc_str[0], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b_save ); + sc_str[0], m, n, &b_save ); // Set alpha and beta. if ( bli_obj_is_real( &b ) ) @@ -320,7 +320,7 @@ void libblis_test_trmm_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trmm3.c b/testsuite/src/test_trmm3.c index bd3937d3b..17ba2190b 100644 --- a/testsuite/src/test_trmm3.c +++ b/testsuite/src/test_trmm3.c @@ -204,13 +204,13 @@ void libblis_test_trmm3_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, transb, - sc_str[1], m, n, &b ); + sc_str[2], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c ); + sc_str[0], m, n, &c ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[2], m, n, &c_save ); + sc_str[0], m, n, &c_save ); // Set alpha and beta. if ( bli_obj_is_real( &c ) ) @@ -339,7 +339,7 @@ void libblis_test_trmm3_check // // is functioning correctly if // - // normf( v - z ) + // normfv( v - z ) // // is negligible, where // diff --git a/testsuite/src/test_trmv.c b/testsuite/src/test_trmv.c index 85ee1e802..71acc90ba 100644 --- a/testsuite/src/test_trmv.c +++ b/testsuite/src/test_trmv.c @@ -304,7 +304,7 @@ void libblis_test_trmv_check // // is functioning correctly if // - // normf( y - x ) + // normfv( y - x ) // // is negligible, where // diff --git a/testsuite/src/test_trsm.c b/testsuite/src/test_trsm.c index 1c90edef5..fa0d8e7c3 100644 --- a/testsuite/src/test_trsm.c +++ b/testsuite/src/test_trsm.c @@ -197,11 +197,11 @@ void libblis_test_trsm_experiment // Create test operands (vectors and/or matrices). bli_set_dim_with_side( side, m, n, &mn_side ); libblis_test_mobj_create( params, datatype, transa, - sc_str[0], mn_side, mn_side, &a ); + sc_str[1], mn_side, mn_side, &a ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b ); + sc_str[0], m, n, &b ); libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE, - sc_str[1], m, n, &b_save ); + sc_str[0], m, n, &b_save ); // Set alpha. 
diff --git a/testsuite/src/test_trsm.c b/testsuite/src/test_trsm.c
index 1c90edef5..fa0d8e7c3 100644
--- a/testsuite/src/test_trsm.c
+++ b/testsuite/src/test_trsm.c
@@ -197,11 +197,11 @@ void libblis_test_trsm_experiment
 	// Create test operands (vectors and/or matrices).
 	bli_set_dim_with_side( side, m, n, &mn_side );
 	libblis_test_mobj_create( params, datatype, transa,
-	                          sc_str[0], mn_side, mn_side, &a );
+	                          sc_str[1], mn_side, mn_side, &a );
 	libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
-	                          sc_str[1], m, n, &b );
+	                          sc_str[0], m, n, &b );
 	libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
-	                          sc_str[1], m, n, &b_save );
+	                          sc_str[0], m, n, &b_save );
 
 	// Set alpha.
 	if ( bli_obj_is_real( &b ) )
@@ -327,7 +327,7 @@ void libblis_test_trsm_check
 	//
 	// is functioning correctly if
 	//
-	//   normf( v - z )
+	//   normfv( v - z )
 	//
 	// is negligible, where
 	//
diff --git a/testsuite/src/test_trsm_ukr.c b/testsuite/src/test_trsm_ukr.c
index 1f81400c9..ee468dbd3 100644
--- a/testsuite/src/test_trsm_ukr.c
+++ b/testsuite/src/test_trsm_ukr.c
@@ -401,7 +401,7 @@ void libblis_test_trsm_ukr_check
 	//
 	// is functioning correctly if
 	//
-	//   normf( v - z )
+	//   normfv( v - z )
 	//
 	// is negligible, where
 	//
diff --git a/testsuite/src/test_trsv.c b/testsuite/src/test_trsv.c
index 15398b193..12543cd9a 100644
--- a/testsuite/src/test_trsv.c
+++ b/testsuite/src/test_trsv.c
@@ -305,7 +305,7 @@ void libblis_test_trsv_check
 	//
 	// is functioning correctly if
 	//
-	//   normf( y - x_orig )
+	//   normfv( y - x_orig )
 	//
 	// is negligible, where
 	//
diff --git a/testsuite/src/test_xpbym.c b/testsuite/src/test_xpbym.c
index b7acc654e..2340b4e11 100644
--- a/testsuite/src/test_xpbym.c
+++ b/testsuite/src/test_xpbym.c
@@ -288,7 +288,7 @@ void libblis_test_xpbym_check
 	//
 	// is functioning correctly if
 	//
-	//   normf( y - ( beta * y_orig + conjx(x) ) )
+	//   normfm( y - ( beta * y_orig + conjx(x) ) )
 	//
 	// is negligible.
 	//
diff --git a/testsuite/src/test_xpbyv.c b/testsuite/src/test_xpbyv.c
index 6b2f21734..197de86e7 100644
--- a/testsuite/src/test_xpbyv.c
+++ b/testsuite/src/test_xpbyv.c
@@ -283,7 +283,7 @@ void libblis_test_xpbyv_check
 	//
 	// is functioning correctly if
 	//
-	//   normf( y - ( beta * y_orig + conjx(x) ) )
+	//   normfv( y - ( beta * y_orig + conjx(x) ) )
 	//
 	// is negligible.
 	//
diff --git a/version b/version
index cd5ac039d..879b416e6 100644
--- a/version
+++ b/version
@@ -1 +1 @@
-2.0
+2.1
diff --git a/windows/build/bli_config.h b/windows/build/bli_config.h
deleted file mode 100644
index aced5d1b7..000000000
--- a/windows/build/bli_config.h
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
-
-   BLIS
-   An object-based framework for developing high-performance BLAS-like
-   libraries.
-
-   Copyright (C) 2014, The University of Texas at Austin
-
-   Redistribution and use in source and binary forms, with or without
-   modification, are permitted provided that the following conditions are
-   met:
-    - Redistributions of source code must retain the above copyright
-      notice, this list of conditions and the following disclaimer.
-    - Redistributions in binary form must reproduce the above copyright
-      notice, this list of conditions and the following disclaimer in the
-      documentation and/or other materials provided with the distribution.
-    - Neither the name(s) of the copyright holder(s) nor the names of its
-      contributors may be used to endorse or promote products derived
-      from this software without specific prior written permission.
-
-   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-*/
-#ifndef BLIS_CONFIG_H
-#define BLIS_CONFIG_H
-
-
-// -- OPERATING SYSTEM ---------------------------------------------------------
-
-
-
-// -- FLOATING-POINT PROPERTIES ------------------------------------------------
-
-#define BLIS_NUM_FP_TYPES   4
-#define BLIS_MAX_TYPE_SIZE  sizeof(dcomplex)
-
-// Enable use of built-in C99 "float complex" and "double complex" types and
-// associated overloaded operations and functions? Disabling results in
-// scomplex and dcomplex being defined in terms of simple structs.
-//#define BLIS_ENABLE_C99_COMPLEX
-
-
-
-// -- MULTITHREADING -----------------------------------------------------------
-
-// The maximum number of BLIS threads that will run concurrently.
-#define BLIS_MAX_NUM_THREADS  24
-
-
-
-// -- MEMORY ALLOCATION --------------------------------------------------------
-
-// -- Contiguous (static) memory allocator --
-
-// The number of MC x KC, KC x NC, and MC x NC blocks to reserve in the
-// contiguous memory pools.
-#define BLIS_NUM_MC_X_KC_BLOCKS  BLIS_MAX_NUM_THREADS
-#define BLIS_NUM_KC_X_NC_BLOCKS  1
-#define BLIS_NUM_MC_X_NC_BLOCKS  1
-
-// The maximum preload byte offset is used to pad the end of the contiguous
-// memory pools so that the micro-kernel, when computing with the end of the
-// last block, can exceed the bounds of the usable portion of the memory
-// region without causing a segmentation fault.
-#define BLIS_MAX_PRELOAD_BYTE_OFFSET  128
-
-// -- Memory alignment --
-
-// It is sometimes useful to define the various memory alignments in terms
-// of some other characteristics of the system, such as the cache line size
-// and the page size.
-#define BLIS_CACHE_LINE_SIZE  64
-#define BLIS_PAGE_SIZE        4096
-
-// Alignment size used to align local stack buffers within macro-kernel
-// functions.
-#define BLIS_STACK_BUF_ALIGN_SIZE  16
-
-// Alignment size used when allocating memory dynamically from the operating
-// system (eg: posix_memalign()). To disable heap alignment and just use
-// malloc() instead, set this to 1.
-#define BLIS_HEAP_ADDR_ALIGN_SIZE  16
-
-// Alignment size used when sizing leading dimensions of dynamically
-// allocated memory.
-#define BLIS_HEAP_STRIDE_ALIGN_SIZE  BLIS_CACHE_LINE_SIZE
-
-// Alignment size used when allocating entire blocks of contiguous memory
-// from the contiguous memory allocator.
-#define BLIS_CONTIG_ADDR_ALIGN_SIZE  BLIS_PAGE_SIZE
-
-
-
-// -- MIXED DATATYPE SUPPORT ---------------------------------------------------
-
-// Basic (homogeneous) datatype support always enabled.
-
-// Enable mixed domain operations?
-//#define BLIS_ENABLE_MIXED_DOMAIN_SUPPORT
-
-// Enable extra mixed precision operations?
-//#define BLIS_ENABLE_MIXED_PRECISION_SUPPORT
-
-
-
-// -- MISCELLANEOUS OPTIONS ----------------------------------------------------
-
-// Stay initialized after auto-initialization, unless and until the user
-// explicitly calls bli_finalize().
-#define BLIS_ENABLE_STAY_AUTO_INITIALIZED
-
-
-
-// -- BLAS-to-BLIS COMPATIBILITY LAYER -----------------------------------------
-
-// Enable the BLAS compatibility layer?
-#define BLIS_ENABLE_BLAS2BLIS
-
-// Enable 64-bit integers in the BLAS compatibility layer? If disabled,
-// these integers will be defined as 32-bit.
-#define BLIS_ENABLE_BLAS2BLIS_INT64
-
-// Fortran-77 name-mangling macros.
-#define PASTEF770(name)                        name ## _
-#define PASTEF77(ch1,name)              ch1 ## name ## _
-#define PASTEF772(ch1,ch2,name)  ch1 ## ch2 ## name ## _
-
-
-#endif
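Note: the Fortran-77 name-mangling macros near the end of the deleted header implement the usual trailing-underscore convention via token pasting. A tiny standalone demonstration of how they expand (the STR/STR2 stringizing helpers are added here purely so the results can be printed; they are not part of the header):

    #include <stdio.h>

    // Macros as defined in the deleted bli_config.h.
    #define PASTEF770(name)          name ## _
    #define PASTEF77(ch1,name)       ch1 ## name ## _
    #define PASTEF772(ch1,ch2,name)  ch1 ## ch2 ## name ## _

    // Stringizing helpers so the expansions can be printed.
    #define STR2(x) #x
    #define STR(x)  STR2(x)

    int main( void )
    {
        printf( "%s\n", STR( PASTEF770(xerbla) ) );   // xerbla_
        printf( "%s\n", STR( PASTEF77(d,gemm) ) );    // dgemm_
        printf( "%s\n", STR( PASTEF772(z,d,scal) ) ); // zdscal_
        return 0;
    }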