blis/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c
Field G. Van Zee 7038bbaa05 Optionally disable trsm diagonal pre-inversion.
Details:
- Implemented a configure-time option, --disable-trsm-preinversion, that
  optionally disables the pre-inversion of diagonal elements of the
  triangular matrix in the trsm operation and instead uses division
  instructions within the gemmtrsm microkernels. Pre-inversion is
  enabled by default. When it is disabled, performance may suffer
  slightly, but numerical robustness should improve for certain
  pathological cases involving denormal (subnormal) numbers that would
  otherwise result in overflow in the pre-inverted value. Thanks to
  Bhaskar Nallani for reporting this issue via #461.
- Added preprocessor macro guards to bli_trsm_cntl.c as well as the
  gemmtrsm microkernels for 'haswell' and 'penryn' kernel sets pursuant
  to the aforementioned feature.
- Added macros to frame/include/bli_x86_asm_macros.h related to division
  instructions.
2020-12-04 16:08:15 -06:00
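Note: the failure mode motivating this option can be seen in a minimal
scalar sketch (hypothetical values, not taken from #461): pre-inverting a
subnormal diagonal element overflows to infinity, while dividing by it
directly still yields a usable result.

    #include <stdio.h>

    int main( void )
    {
        double alpha11 = 1.0e-310;        // subnormal diagonal element
        double beta1   = 2.0e-310;        // subnormal right-hand side entry

        double inv     = 1.0 / alpha11;   // pre-inverted value overflows to inf
        double x_pre   = beta1 * inv;     // inf propagates into the solution
        double x_div   = beta1 / alpha11; // direct division: approximately 2.0

        printf( "preinverted: %g  divided: %g\n", x_pre, x_div );
        return 0;
    }

In the kernels below, this choice appears as #ifdef BLIS_ENABLE_TRSM_PREINVERSION
guards that select vmulps/vmulpd (pre-inverted diagonal) or vdivps/vdivpd
(original diagonal).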


/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
#define SGEMM_OUTPUT_GS_BETA_NZ \
vextractf128(imm(1), ymm0, xmm2) \
vmovss(xmm0, mem(rcx)) \
vpermilps(imm(0x39), xmm0, xmm1) \
vmovss(xmm1, mem(rcx, rsi, 1)) \
vpermilps(imm(0x39), xmm1, xmm0) \
vmovss(xmm0, mem(rcx, rsi, 2)) \
vpermilps(imm(0x39), xmm0, xmm1) \
vmovss(xmm1, mem(rcx, r13, 1)) \
vmovss(xmm2, mem(rcx, rsi, 4)) \
vpermilps(imm(0x39), xmm2, xmm1) \
vmovss(xmm1, mem(rcx, r15, 1)) \
vpermilps(imm(0x39), xmm1, xmm2) \
vmovss(xmm2, mem(rcx, r13, 2)) \
vpermilps(imm(0x39), xmm2, xmm1) \
vmovss(xmm1, mem(rcx, r10, 1))
void bli_sgemmtrsm_u_haswell_asm_6x16
(
dim_t k0,
float* restrict alpha,
float* restrict a10,
float* restrict a11,
float* restrict b01,
float* restrict b11,
float* restrict c11, inc_t rs_c0, inc_t cs_c0,
auxinfo_t* restrict data,
cntx_t* restrict cntx
)
{
//void* a_next = bli_auxinfo_next_a( data );
//void* b_next = bli_auxinfo_next_b( data );
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k_iter = k0 / 4;
uint64_t k_left = k0 % 4;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
float* beta = bli_sm1;
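// bli_sm1 points to BLIS's global scalar constant -1.0f; it is passed to
// the assembly region below as the [beta] input operand.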
begin_asm()
vzeroall() // zero all xmm/ymm registers.
mov(var(a10), rax) // load address of a.
mov(var(b01), rbx) // load address of b.
add(imm(32*4), rbx)
// initialize loop by pre-loading
vmovaps(mem(rbx, -4*32), ymm0)
vmovaps(mem(rbx, -3*32), ymm1)
mov(var(b11), rcx) // load address of b11
mov(imm(16), rdi) // set rs_b = PACKNR = 16
lea(mem(, rdi, 4), rdi) // rs_b *= sizeof(float)
// NOTE: c11, rs_c, and cs_c aren't
// needed for a while, but we load
// them now to avoid stalling later.
mov(var(c11), r8) // load address of c11
mov(var(rs_c), r9) // load rs_c
lea(mem(, r9 , 4), r9) // rs_c *= sizeof(float)
mov(var(cs_c), r10) // load cs_c
lea(mem(, r10, 4), r10) // cs_c *= sizeof(float)
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
// iteration 0
prefetch(0, mem(rax, 64*4))
vbroadcastss(mem(rax, 0*4), ymm2)
vbroadcastss(mem(rax, 1*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
vfmadd231ps(ymm1, ymm2, ymm5)
vfmadd231ps(ymm0, ymm3, ymm6)
vfmadd231ps(ymm1, ymm3, ymm7)
vbroadcastss(mem(rax, 2*4), ymm2)
vbroadcastss(mem(rax, 3*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm8)
vfmadd231ps(ymm1, ymm2, ymm9)
vfmadd231ps(ymm0, ymm3, ymm10)
vfmadd231ps(ymm1, ymm3, ymm11)
vbroadcastss(mem(rax, 4*4), ymm2)
vbroadcastss(mem(rax, 5*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm12)
vfmadd231ps(ymm1, ymm2, ymm13)
vfmadd231ps(ymm0, ymm3, ymm14)
vfmadd231ps(ymm1, ymm3, ymm15)
vmovaps(mem(rbx, -2*32), ymm0)
vmovaps(mem(rbx, -1*32), ymm1)
// iteration 1
vbroadcastss(mem(rax, 6*4), ymm2)
vbroadcastss(mem(rax, 7*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
vfmadd231ps(ymm1, ymm2, ymm5)
vfmadd231ps(ymm0, ymm3, ymm6)
vfmadd231ps(ymm1, ymm3, ymm7)
vbroadcastss(mem(rax, 8*4), ymm2)
vbroadcastss(mem(rax, 9*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm8)
vfmadd231ps(ymm1, ymm2, ymm9)
vfmadd231ps(ymm0, ymm3, ymm10)
vfmadd231ps(ymm1, ymm3, ymm11)
vbroadcastss(mem(rax, 10*4), ymm2)
vbroadcastss(mem(rax, 11*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm12)
vfmadd231ps(ymm1, ymm2, ymm13)
vfmadd231ps(ymm0, ymm3, ymm14)
vfmadd231ps(ymm1, ymm3, ymm15)
vmovaps(mem(rbx, 0*32), ymm0)
vmovaps(mem(rbx, 1*32), ymm1)
// iteration 2
prefetch(0, mem(rax, 76*4))
vbroadcastss(mem(rax, 12*4), ymm2)
vbroadcastss(mem(rax, 13*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
vfmadd231ps(ymm1, ymm2, ymm5)
vfmadd231ps(ymm0, ymm3, ymm6)
vfmadd231ps(ymm1, ymm3, ymm7)
vbroadcastss(mem(rax, 14*4), ymm2)
vbroadcastss(mem(rax, 15*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm8)
vfmadd231ps(ymm1, ymm2, ymm9)
vfmadd231ps(ymm0, ymm3, ymm10)
vfmadd231ps(ymm1, ymm3, ymm11)
vbroadcastss(mem(rax, 16*4), ymm2)
vbroadcastss(mem(rax, 17*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm12)
vfmadd231ps(ymm1, ymm2, ymm13)
vfmadd231ps(ymm0, ymm3, ymm14)
vfmadd231ps(ymm1, ymm3, ymm15)
vmovaps(mem(rbx, 2*32), ymm0)
vmovaps(mem(rbx, 3*32), ymm1)
// iteration 3
vbroadcastss(mem(rax, 18*4), ymm2)
vbroadcastss(mem(rax, 19*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
vfmadd231ps(ymm1, ymm2, ymm5)
vfmadd231ps(ymm0, ymm3, ymm6)
vfmadd231ps(ymm1, ymm3, ymm7)
vbroadcastss(mem(rax, 20*4), ymm2)
vbroadcastss(mem(rax, 21*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm8)
vfmadd231ps(ymm1, ymm2, ymm9)
vfmadd231ps(ymm0, ymm3, ymm10)
vfmadd231ps(ymm1, ymm3, ymm11)
vbroadcastss(mem(rax, 22*4), ymm2)
vbroadcastss(mem(rax, 23*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm12)
vfmadd231ps(ymm1, ymm2, ymm13)
vfmadd231ps(ymm0, ymm3, ymm14)
vfmadd231ps(ymm1, ymm3, ymm15)
add(imm(4*6*4), rax) // a += 4*6 (unroll x mr)
add(imm(4*16*4), rbx) // b += 4*16 (unroll x nr)
vmovaps(mem(rbx, -4*32), ymm0)
vmovaps(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
label(.SCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
prefetch(0, mem(rax, 64*4))
vbroadcastss(mem(rax, 0*4), ymm2)
vbroadcastss(mem(rax, 1*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
vfmadd231ps(ymm1, ymm2, ymm5)
vfmadd231ps(ymm0, ymm3, ymm6)
vfmadd231ps(ymm1, ymm3, ymm7)
vbroadcastss(mem(rax, 2*4), ymm2)
vbroadcastss(mem(rax, 3*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm8)
vfmadd231ps(ymm1, ymm2, ymm9)
vfmadd231ps(ymm0, ymm3, ymm10)
vfmadd231ps(ymm1, ymm3, ymm11)
vbroadcastss(mem(rax, 4*4), ymm2)
vbroadcastss(mem(rax, 5*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm12)
vfmadd231ps(ymm1, ymm2, ymm13)
vfmadd231ps(ymm0, ymm3, ymm14)
vfmadd231ps(ymm1, ymm3, ymm15)
add(imm(1*6*4), rax) // a += 1*6 (unroll x mr)
add(imm(1*16*4), rbx) // b += 1*16 (unroll x nr)
vmovaps(mem(rbx, -4*32), ymm0)
vmovaps(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
label(.SPOSTACCUM)
// ymm4..ymm15 = -a10 * b01
mov(var(alpha), rbx) // load address of alpha
vbroadcastss(mem(rbx), ymm3) // load alpha and duplicate
mov(imm(1), rsi) // set cs_b = 1
lea(mem(, rsi, 4), rsi) // cs_b *= sizeof(float)
lea(mem(rcx, rsi, 8), rdx) // load address of b11 + 8*cs_b
mov(rcx, r11) // save rcx = b11 for later
mov(rdx, r14) // save rdx = b11+8*cs_b for later
// b11 := alpha * b11 - a10 * b01
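// Note: with AT&T operand order, vfmsub231ps(mem(rcx), ymm3, ymm4) computes
// ymm4 = ymm3 * mem(rcx) - ymm4, i.e. alpha * b11 - a10 * b01.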
vfmsub231ps(mem(rcx), ymm3, ymm4)
add(rdi, rcx)
vfmsub231ps(mem(rdx), ymm3, ymm5)
add(rdi, rdx)
vfmsub231ps(mem(rcx), ymm3, ymm6)
add(rdi, rcx)
vfmsub231ps(mem(rdx), ymm3, ymm7)
add(rdi, rdx)
vfmsub231ps(mem(rcx), ymm3, ymm8)
add(rdi, rcx)
vfmsub231ps(mem(rdx), ymm3, ymm9)
add(rdi, rdx)
vfmsub231ps(mem(rcx), ymm3, ymm10)
add(rdi, rcx)
vfmsub231ps(mem(rdx), ymm3, ymm11)
add(rdi, rdx)
vfmsub231ps(mem(rcx), ymm3, ymm12)
add(rdi, rcx)
vfmsub231ps(mem(rdx), ymm3, ymm13)
add(rdi, rdx)
vfmsub231ps(mem(rcx), ymm3, ymm14)
//add(rdi, rcx)
vfmsub231ps(mem(rdx), ymm3, ymm15)
//add(rdi, rdx)
// prefetch c11
#if 0
mov(r8, rcx) // load address of c11 from r8
// Note: r9 = rs_c * sizeof(float)
lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c;
lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c;
prefetch(0, mem(rcx, 0*8)) // prefetch c11 + 0*rs_c
prefetch(0, mem(rcx, r9, 1, 0*8)) // prefetch c11 + 1*rs_c
prefetch(0, mem(rcx, r9 , 2, 0*8)) // prefetch c11 + 2*rs_c
prefetch(0, mem(rdx, 0*8)) // prefetch c11 + 3*rs_c
prefetch(0, mem(rdx, r9, 1, 0*8)) // prefetch c11 + 4*rs_c
prefetch(0, mem(rdx, r9 , 2, 0*8)) // prefetch c11 + 5*rs_c
#endif
// trsm computation begins here
// Note: contents of b11 are stored as
// ymm4 ymm5 = ( beta00..07 ) ( beta08..0F )
// ymm6 ymm7 = ( beta10..17 ) ( beta18..1F )
// ymm8 ymm9 = ( beta20..27 ) ( beta28..2F )
// ymm10 ymm11 = ( beta30..37 ) ( beta38..3F )
// ymm12 ymm13 = ( beta40..47 ) ( beta48..4F )
// ymm14 ymm15 = ( beta50..57 ) ( beta58..5F )
mov(var(a11), rax) // load address of a11
mov(r11, rcx) // recall address of b11
mov(r14, rdx) // recall address of b11+8*cs_b
lea(mem(rcx, rdi, 4), rcx) // rcx = b11 + (6-1)*rs_b
lea(mem(rcx, rdi, 1), rcx)
lea(mem(rdx, rdi, 4), rdx) // rdx = b11 + (6-1)*rs_b + 8*cs_b
lea(mem(rdx, rdi, 1), rdx)
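// Each of the six iterations below performs one backward-substitution
// step, starting with the last row (row 5) and working upward: accumulate
// the contributions of the already-solved rows, subtract them from the
// current row of b11, then scale by the diagonal element. When
// BLIS_ENABLE_TRSM_PREINVERSION is defined (the default), the packed a11
// holds the reciprocals 1/alpha_ii and a multiply suffices; otherwise a11
// holds alpha_ii and vdivps is used instead.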
// iteration 0 -------------
vbroadcastss(mem(5+5*6)*4(rax), ymm0) // ymm0 = (1/alpha55)
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulps(ymm0, ymm14, ymm14) // ymm14 *= (1/alpha55)
vmulps(ymm0, ymm15, ymm15) // ymm15 *= (1/alpha55)
#else
vdivps(ymm0, ymm14, ymm14) // ymm14 /= alpha55
vdivps(ymm0, ymm15, ymm15) // ymm15 /= alpha55
#endif
vmovups(ymm14, mem(rcx)) // store ( beta50..beta57 ) = ymm14
vmovups(ymm15, mem(rdx)) // store ( beta58..beta5F ) = ymm15
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 1 -------------
vbroadcastss(mem(4+5*6)*4(rax), ymm0) // ymm0 = alpha45
vbroadcastss(mem(4+4*6)*4(rax), ymm1) // ymm1 = (1/alpha44)
vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha45 * ymm14
vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha45 * ymm15
vsubps(ymm2, ymm12, ymm12) // ymm12 -= ymm2
vsubps(ymm3, ymm13, ymm13) // ymm13 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulps(ymm1, ymm12, ymm12) // ymm12 *= (1/alpha44)
vmulps(ymm1, ymm13, ymm13) // ymm13 *= (1/alpha44)
#else
vdivps(ymm1, ymm12, ymm12) // ymm12 /= alpha44
vdivps(ymm1, ymm13, ymm13) // ymm13 /= alpha44
#endif
vmovups(ymm12, mem(rcx)) // store ( beta40..beta47 ) = ymm12
vmovups(ymm13, mem(rdx)) // store ( beta48..beta4F ) = ymm13
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 2 -------------
vbroadcastss(mem(3+5*6)*4(rax), ymm0) // ymm0 = alpha35
vbroadcastss(mem(3+4*6)*4(rax), ymm1) // ymm1 = alpha34
vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha35 * ymm14
vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha35 * ymm15
vbroadcastss(mem(3+3*6)*4(rax), ymm0) // ymm0 = (1/alpha33)
vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha34 * ymm12
vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha34 * ymm13
vsubps(ymm2, ymm10, ymm10) // ymm10 -= ymm2
vsubps(ymm3, ymm11, ymm11) // ymm11 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulps(ymm0, ymm10, ymm10) // ymm10 *= (1/alpha33)
vmulps(ymm0, ymm11, ymm11) // ymm11 *= (1/alpha33)
#else
vdivps(ymm0, ymm10, ymm10) // ymm10 /= alpha33
vdivps(ymm0, ymm11, ymm11) // ymm11 /= alpha33
#endif
vmovups(ymm10, mem(rcx)) // store ( beta30..beta37 ) = ymm10
vmovups(ymm11, mem(rdx)) // store ( beta38..beta3F ) = ymm11
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 3 -------------
vbroadcastss(mem(2+5*6)*4(rax), ymm0) // ymm0 = alpha25
vbroadcastss(mem(2+4*6)*4(rax), ymm1) // ymm1 = alpha24
vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha25 * ymm14
vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha25 * ymm15
vbroadcastss(mem(2+3*6)*4(rax), ymm0) // ymm0 = alpha23
vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha24 * ymm12
vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha24 * ymm13
vbroadcastss(mem(2+2*6)*4(rax), ymm1) // ymm1 = (1/alpha22)
vfmadd231ps(ymm0, ymm10, ymm2) // ymm2 += alpha23 * ymm10
vfmadd231ps(ymm0, ymm11, ymm3) // ymm3 += alpha23 * ymm11
vsubps(ymm2, ymm8, ymm8) // ymm8 -= ymm2
vsubps(ymm3, ymm9, ymm9) // ymm9 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulps(ymm1, ymm8, ymm8) // ymm8 *= (1/alpha22)
vmulps(ymm1, ymm9, ymm9) // ymm9 *= (1/alpha22)
#else
vdivps(ymm1, ymm8, ymm8) // ymm8 /= alpha22
vdivps(ymm1, ymm9, ymm9) // ymm9 /= alpha22
#endif
vmovups(ymm8, mem(rcx)) // store ( beta20..beta27 ) = ymm8
vmovups(ymm9, mem(rdx)) // store ( beta28..beta2F ) = ymm9
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 4 -------------
vbroadcastss(mem(1+5*6)*4(rax), ymm0) // ymm0 = alpha15
vbroadcastss(mem(1+4*6)*4(rax), ymm1) // ymm1 = alpha14
vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha15 * ymm14
vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha15 * ymm15
vbroadcastss(mem(1+3*6)*4(rax), ymm0) // ymm0 = alpha13
vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha14 * ymm12
vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha14 * ymm13
vbroadcastss(mem(1+2*6)*4(rax), ymm1) // ymm1 = alpha12
vfmadd231ps(ymm0, ymm10, ymm2) // ymm2 += alpha13 * ymm10
vfmadd231ps(ymm0, ymm11, ymm3) // ymm3 += alpha13 * ymm11
vbroadcastss(mem(1+1*6)*4(rax), ymm0) // ymm0 = (1/alpha11)
vfmadd231ps(ymm1, ymm8, ymm2) // ymm2 += alpha12 * ymm8
vfmadd231ps(ymm1, ymm9, ymm3) // ymm3 += alpha12 * ymm9
vsubps(ymm2, ymm6, ymm6) // ymm6 -= ymm2
vsubps(ymm3, ymm7, ymm7) // ymm7 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulps(ymm0, ymm6, ymm6) // ymm6 *= (1/alpha11)
vmulps(ymm0, ymm7, ymm7) // ymm7 *= (1/alpha11)
#else
vdivps(ymm0, ymm6, ymm6) // ymm6 /= alpha11
vdivps(ymm0, ymm7, ymm7) // ymm7 /= alpha11
#endif
vmovups(ymm6, mem(rcx)) // store ( beta10..beta17 ) = ymm6
vmovups(ymm7, mem(rdx)) // store ( beta18..beta1F ) = ymm7
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 5 -------------
vbroadcastss(mem(0+5*6)*4(rax), ymm0) // ymm0 = alpha05
vbroadcastss(mem(0+4*6)*4(rax), ymm1) // ymm1 = alpha04
vmulps(ymm0, ymm14, ymm2) // ymm2 = alpha05 * ymm14
vmulps(ymm0, ymm15, ymm3) // ymm3 = alpha05 * ymm15
vbroadcastss(mem(0+3*6)*4(rax), ymm0) // ymm0 = alpha03
vfmadd231ps(ymm1, ymm12, ymm2) // ymm2 += alpha04 * ymm12
vfmadd231ps(ymm1, ymm13, ymm3) // ymm3 += alpha04 * ymm13
vbroadcastss(mem(0+2*6)*4(rax), ymm1) // ymm1 = alpha02
vfmadd231ps(ymm0, ymm10, ymm2) // ymm2 += alpha03 * ymm10
vfmadd231ps(ymm0, ymm11, ymm3) // ymm3 += alpha03 * ymm11
vbroadcastss(mem(0+1*6)*4(rax), ymm0) // ymm0 = alpha01
vfmadd231ps(ymm1, ymm8, ymm2) // ymm2 += alpha02 * ymm8
vfmadd231ps(ymm1, ymm9, ymm3) // ymm3 += alpha02 * ymm9
vbroadcastss(mem(0+0*6)*4(rax), ymm1) // ymm1 = (1/alpha00)
vfmadd231ps(ymm0, ymm6, ymm2) // ymm2 += alpha01 * ymm6
vfmadd231ps(ymm0, ymm7, ymm3) // ymm3 += alpha01 * ymm7
vsubps(ymm2, ymm4, ymm4) // ymm4 -= ymm2
vsubps(ymm3, ymm5, ymm5) // ymm5 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulps(ymm1, ymm4, ymm4) // ymm4 *= (1/alpha00)
vmulps(ymm1, ymm5, ymm5) // ymm5 *= (1/alpha00)
#else
vdivps(ymm1, ymm4, ymm4) // ymm4 /= alpha00
vdivps(ymm1, ymm5, ymm5) // ymm5 /= alpha00
#endif
vmovups(ymm4, mem(rcx)) // store ( beta00..beta07 ) = ymm4
vmovups(ymm5, mem(rdx)) // store ( beta08..beta0F ) = ymm5
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
mov(r8, rcx) // load address of c11 from r8
mov(r9, rdi) // load rs_c (in bytes) from r9
mov(r10, rsi) // load cs_c (in bytes) from r10
lea(mem(rcx, rsi, 8), rdx) // load address of c11 + 8*cs_c;
lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c;
// These are used in the macros below.
lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c;
lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c;
lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c;
cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4.
jz(.SROWSTORED) // jump to row storage case
cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4.
jz(.SCOLSTORED) // jump to column storage case
// if neither row- or column-
// stored, use general case.
label(.SGENSTORED)
vmovaps(ymm4, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm6, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm8, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm10, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm12, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm14, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
mov(rdx, rcx) // rcx = c11 + 8*cs_c
vmovaps(ymm5, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm7, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm9, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm11, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm13, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovaps(ymm15, ymm0)
SGEMM_OUTPUT_GS_BETA_NZ
jmp(.SDONE)
label(.SROWSTORED)
vmovups(ymm4, mem(rcx))
add(rdi, rcx)
vmovups(ymm5, mem(rdx))
add(rdi, rdx)
vmovups(ymm6, mem(rcx))
add(rdi, rcx)
vmovups(ymm7, mem(rdx))
add(rdi, rdx)
vmovups(ymm8, mem(rcx))
add(rdi, rcx)
vmovups(ymm9, mem(rdx))
add(rdi, rdx)
vmovups(ymm10, mem(rcx))
add(rdi, rcx)
vmovups(ymm11, mem(rdx))
add(rdi, rdx)
vmovups(ymm12, mem(rcx))
add(rdi, rcx)
vmovups(ymm13, mem(rdx))
add(rdi, rdx)
vmovups(ymm14, mem(rcx))
//add(rdi, rcx)
vmovups(ymm15, mem(rdx))
//add(rdi, rdx)
jmp(.SDONE)
label(.SCOLSTORED)
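// In-register transpose of the 6x16 microtile for column storage: the
// unpck/shuf/blend sequences turn rows 0-3 (ymm4/6/8/10, then ymm5/7/9/11)
// into four-element columns stored with vmovups, while rows 4-5
// (ymm12/14, then ymm13/15) are interleaved pairwise and stored two floats
// at a time with vmovlpd/vmovhpd (used here as 64-bit moves).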
vunpcklps(ymm6, ymm4, ymm0)
vunpcklps(ymm10, ymm8, ymm1)
vshufps(imm(0x4e), ymm1, ymm0, ymm2)
vblendps(imm(0xcc), ymm2, ymm0, ymm0)
vblendps(imm(0x33), ymm2, ymm1, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 )
vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 )
vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 )
vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma05..gamma35 )
vunpckhps(ymm6, ymm4, ymm0)
vunpckhps(ymm10, ymm8, ymm1)
vshufps(imm(0x4e), ymm1, ymm0, ymm2)
vblendps(imm(0xcc), ymm2, ymm0, ymm0)
vblendps(imm(0x33), ymm2, ymm1, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 )
vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 )
vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 )
vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma07..gamma37 )
lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c
vunpcklps(ymm14, ymm12, ymm0)
vunpckhps(ymm14, ymm12, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 )
vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 )
vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma42..gamma52 )
vmovhpd(xmm1, mem(r14, r13, 1)) // store ( gamma43..gamma53 )
vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 )
vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 )
vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma46..gamma56 )
vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma47..gamma57 )
lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c
vunpcklps(ymm7, ymm5, ymm0)
vunpcklps(ymm11, ymm9, ymm1)
vshufps(imm(0x4e), ymm1, ymm0, ymm2)
vblendps(imm(0xcc), ymm2, ymm0, ymm0)
vblendps(imm(0x33), ymm2, ymm1, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovups(xmm0, mem(rcx)) // store ( gamma08..gamma38 )
vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma09..gamma39 )
vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma0C..gamma3C )
vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma0D..gamma3D )
vunpckhps(ymm7, ymm5, ymm0)
vunpckhps(ymm11, ymm9, ymm1)
vshufps(imm(0x4e), ymm1, ymm0, ymm2)
vblendps(imm(0xcc), ymm2, ymm0, ymm0)
vblendps(imm(0x33), ymm2, ymm1, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma0A..gamma3A )
vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma0B..gamma3B )
vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma0E..gamma3E )
vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma0F..gamma3F )
//lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c
vunpcklps(ymm15, ymm13, ymm0)
vunpckhps(ymm15, ymm13, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovlpd(xmm0, mem(r14)) // store ( gamma48..gamma58 )
vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma49..gamma59 )
vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma4A..gamma5A )
vmovhpd(xmm1, mem(r14, r13, 1)) // store ( gamma4B..gamma5B )
vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma4C..gamma5C )
vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma4D..gamma5D )
vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma4E..gamma5E )
vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma4F..gamma5F )
//lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c
label(.SDONE)
vzeroupper()
end_asm(
: // output operands (none)
: // input operands
[k_iter] "m" (k_iter), // 0
[k_left] "m" (k_left), // 1
[a10] "m" (a10), // 2
[b01] "m" (b01), // 3
[beta] "m" (beta), // 4
[alpha] "m" (alpha), // 5
[a11] "m" (a11), // 6
[b11] "m" (b11), // 7
[c11] "m" (c11), // 8
[rs_c] "m" (rs_c), // 9
[cs_c] "m" (cs_c) // 10
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
"xmm8", "xmm9", "xmm10", "xmm11",
"xmm12", "xmm13", "xmm14", "xmm15",
"memory"
)
}
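// As in the single-precision variant above, the macro below scatters the
// four double-precision elements of ymm0 to general-stride storage using
// rcx (current row of c11), rsi (cs_c in bytes), and r13 (3*cs_c). The
// commented-out tail would store a second vector (ymm2) and is unused here.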
#define DGEMM_OUTPUT_GS_BETA_NZ \
vextractf128(imm(1), ymm0, xmm1) \
vmovlpd(xmm0, mem(rcx)) \
vmovhpd(xmm0, mem(rcx, rsi, 1)) \
vmovlpd(xmm1, mem(rcx, rsi, 2)) \
vmovhpd(xmm1, mem(rcx, r13, 1)) /*\
vextractf128(imm(1), ymm2, xmm1) \
vmovlpd(xmm2, mem(rcx, rsi, 4)) \
vmovhpd(xmm2, mem(rcx, r15, 1)) \
vmovlpd(xmm1, mem(rcx, r13, 2)) \
vmovhpd(xmm1, mem(rcx, r10, 1))*/
void bli_dgemmtrsm_u_haswell_asm_6x8
(
dim_t k0,
double* restrict alpha,
double* restrict a10,
double* restrict a11,
double* restrict b01,
double* restrict b11,
double* restrict c11, inc_t rs_c0, inc_t cs_c0,
auxinfo_t* restrict data,
cntx_t* restrict cntx
)
{
//void* a_next = bli_auxinfo_next_a( data );
//void* b_next = bli_auxinfo_next_b( data );
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k_iter = k0 / 4;
uint64_t k_left = k0 % 4;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
double* beta = bli_dm1;
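// bli_dm1 points to BLIS's global scalar constant -1.0; it is passed to
// the assembly region below as the [beta] input operand.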
begin_asm()
vzeroall() // zero all xmm/ymm registers.
mov(var(a10), rax) // load address of a.
mov(var(b01), rbx) // load address of b.
add(imm(32*4), rbx)
// initialize loop by pre-loading
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
mov(var(b11), rcx) // load address of b11
mov(imm(8), rdi) // set rs_b = PACKNR = 8
lea(mem(, rdi, 8), rdi) // rs_b *= sizeof(double)
// NOTE: c11, rs_c, and cs_c aren't
// needed for a while, but we load
// them now to avoid stalling later.
mov(var(c11), r8) // load address of c11
mov(var(rs_c), r9) // load rs_c
lea(mem(, r9 , 8), r9) // rs_c *= sizeof(double)
mov(var(cs_c), r10) // load cs_c
lea(mem(, r10, 8), r10) // cs_c *= sizeof(double)
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.DCONSIDKLEFT) // if i == 0, jump to code that
// contains the k_left loop.
label(.DLOOPKITER) // MAIN LOOP
// iteration 0
prefetch(0, mem(rax, 64*8))
vbroadcastsd(mem(rax, 0*8), ymm2)
vbroadcastsd(mem(rax, 1*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 2*8), ymm2)
vbroadcastsd(mem(rax, 3*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 4*8), ymm2)
vbroadcastsd(mem(rax, 5*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, -2*32), ymm0)
vmovapd(mem(rbx, -1*32), ymm1)
// iteration 1
vbroadcastsd(mem(rax, 6*8), ymm2)
vbroadcastsd(mem(rax, 7*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 8*8), ymm2)
vbroadcastsd(mem(rax, 9*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 10*8), ymm2)
vbroadcastsd(mem(rax, 11*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, 0*32), ymm0)
vmovapd(mem(rbx, 1*32), ymm1)
// iteration 2
prefetch(0, mem(rax, 76*8))
vbroadcastsd(mem(rax, 12*8), ymm2)
vbroadcastsd(mem(rax, 13*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 14*8), ymm2)
vbroadcastsd(mem(rax, 15*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 16*8), ymm2)
vbroadcastsd(mem(rax, 17*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, 2*32), ymm0)
vmovapd(mem(rbx, 3*32), ymm1)
// iteration 3
vbroadcastsd(mem(rax, 18*8), ymm2)
vbroadcastsd(mem(rax, 19*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 20*8), ymm2)
vbroadcastsd(mem(rax, 21*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 22*8), ymm2)
vbroadcastsd(mem(rax, 23*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
add(imm(4*6*8), rax) // a += 4*6 (unroll x mr)
add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr)
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.DLOOPKITER) // iterate again if i != 0.
label(.DCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.DLOOPKLEFT) // EDGE LOOP
prefetch(0, mem(rax, 64*8))
vbroadcastsd(mem(rax, 0*8), ymm2)
vbroadcastsd(mem(rax, 1*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 2*8), ymm2)
vbroadcastsd(mem(rax, 3*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 4*8), ymm2)
vbroadcastsd(mem(rax, 5*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
add(imm(1*6*8), rax) // a += 1*6 (unroll x mr)
add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr)
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.DLOOPKLEFT) // iterate again if i != 0.
label(.DPOSTACCUM)
// ymm4..ymm15 = -a10 * b01
mov(var(alpha), rbx) // load address of alpha
vbroadcastsd(mem(rbx), ymm3) // load alpha and duplicate
mov(imm(1), rsi) // set cs_b = 1
lea(mem(, rsi, 8), rsi) // cs_b *= sizeof(double)
lea(mem(rcx, rsi, 4), rdx) // load address of b11 + 4*cs_b
mov(rcx, r11) // save rcx = b11 for later
mov(rdx, r14) // save rdx = b11+4*cs_b for later
// b11 := alpha * b11 - a10 * b01
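// Note: as in the single-precision kernel, vfmsub231pd(mem(rcx), ymm3, ymm4)
// computes ymm4 = ymm3 * mem(rcx) - ymm4, i.e. alpha * b11 - a10 * b01.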
vfmsub231pd(mem(rcx), ymm3, ymm4)
add(rdi, rcx)
vfmsub231pd(mem(rdx), ymm3, ymm5)
add(rdi, rdx)
vfmsub231pd(mem(rcx), ymm3, ymm6)
add(rdi, rcx)
vfmsub231pd(mem(rdx), ymm3, ymm7)
add(rdi, rdx)
vfmsub231pd(mem(rcx), ymm3, ymm8)
add(rdi, rcx)
vfmsub231pd(mem(rdx), ymm3, ymm9)
add(rdi, rdx)
vfmsub231pd(mem(rcx), ymm3, ymm10)
add(rdi, rcx)
vfmsub231pd(mem(rdx), ymm3, ymm11)
add(rdi, rdx)
vfmsub231pd(mem(rcx), ymm3, ymm12)
add(rdi, rcx)
vfmsub231pd(mem(rdx), ymm3, ymm13)
add(rdi, rdx)
vfmsub231pd(mem(rcx), ymm3, ymm14)
//add(rdi, rcx)
vfmsub231pd(mem(rdx), ymm3, ymm15)
//add(rdi, rdx)
// prefetch c11
#if 0
mov(r8, rcx) // load address of c11 from r8
// Note: r9 = rs_c * sizeof(double)
lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c;
lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c;
prefetch(0, mem(rcx, 7*8)) // prefetch c11 + 0*rs_c
prefetch(0, mem(rcx, r9, 1, 7*8)) // prefetch c11 + 1*rs_c
prefetch(0, mem(rcx, r9 , 2, 7*8)) // prefetch c11 + 2*rs_c
prefetch(0, mem(rdx, 7*8)) // prefetch c11 + 3*rs_c
prefetch(0, mem(rdx, r9, 1, 7*8)) // prefetch c11 + 4*rs_c
prefetch(0, mem(rdx, r9 , 2, 7*8)) // prefetch c11 + 5*rs_c
#endif
// trsm computation begins here
// Note: contents of b11 are stored as
// ymm4 ymm5 = ( beta00..03 ) ( beta04..07 )
// ymm6 ymm7 = ( beta10..13 ) ( beta14..17 )
// ymm8 ymm9 = ( beta20..23 ) ( beta24..27 )
// ymm10 ymm11 = ( beta30..33 ) ( beta34..37 )
// ymm12 ymm13 = ( beta40..43 ) ( beta44..47 )
// ymm14 ymm15 = ( beta50..53 ) ( beta54..57 )
mov(var(a11), rax) // load address of a11
mov(r11, rcx) // recall address of b11
mov(r14, rdx) // recall address of b11+4*cs_b
lea(mem(rcx, rdi, 4), rcx) // rcx = b11 + (6-1)*rs_b
lea(mem(rcx, rdi, 1), rcx)
lea(mem(rdx, rdi, 4), rdx) // rdx = b11 + (6-1)*rs_b + 4*cs_b
lea(mem(rdx, rdi, 1), rdx)
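// As in the single-precision kernel, each iteration below performs one
// backward-substitution step; the BLIS_ENABLE_TRSM_PREINVERSION guard
// selects between multiplying by a pre-inverted diagonal (the default)
// and dividing by the original diagonal with vdivpd.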
// iteration 0 -------------
vbroadcastsd(mem(5+5*6)*8(rax), ymm0) // ymm0 = (1/alpha55)
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulpd(ymm0, ymm14, ymm14) // ymm14 *= (1/alpha55)
vmulpd(ymm0, ymm15, ymm15) // ymm15 *= (1/alpha55)
#else
vdivpd(ymm0, ymm14, ymm14) // ymm14 /= alpha55
vdivpd(ymm0, ymm15, ymm15) // ymm15 /= alpha55
#endif
vmovupd(ymm14, mem(rcx)) // store ( beta50..beta53 ) = ymm14
vmovupd(ymm15, mem(rdx)) // store ( beta54..beta57 ) = ymm15
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 1 -------------
vbroadcastsd(mem(4+5*6)*8(rax), ymm0) // ymm0 = alpha45
vbroadcastsd(mem(4+4*6)*8(rax), ymm1) // ymm1 = (1/alpha44)
vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha45 * ymm14
vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha45 * ymm15
vsubpd(ymm2, ymm12, ymm12) // ymm12 -= ymm2
vsubpd(ymm3, ymm13, ymm13) // ymm13 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulpd(ymm1, ymm12, ymm12) // ymm12 *= (1/alpha44)
vmulpd(ymm1, ymm13, ymm13) // ymm13 *= (1/alpha44)
#else
vdivpd(ymm1, ymm12, ymm12) // ymm12 /= alpha44
vdivpd(ymm1, ymm13, ymm13) // ymm13 /= alpha44
#endif
vmovupd(ymm12, mem(rcx)) // store ( beta40..beta43 ) = ymm12
vmovupd(ymm13, mem(rdx)) // store ( beta44..beta47 ) = ymm13
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 2 -------------
vbroadcastsd(mem(3+5*6)*8(rax), ymm0) // ymm0 = alpha35
vbroadcastsd(mem(3+4*6)*8(rax), ymm1) // ymm1 = alpha34
vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha35 * ymm14
vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha35 * ymm15
vbroadcastsd(mem(3+3*6)*8(rax), ymm0) // ymm0 = (1/alpha33)
vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha34 * ymm12
vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha34 * ymm13
vsubpd(ymm2, ymm10, ymm10) // ymm10 -= ymm2
vsubpd(ymm3, ymm11, ymm11) // ymm11 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulpd(ymm0, ymm10, ymm10) // ymm10 *= (1/alpha33)
vmulpd(ymm0, ymm11, ymm11) // ymm11 *= (1/alpha33)
#else
vdivpd(ymm0, ymm10, ymm10) // ymm10 /= alpha33
vdivpd(ymm0, ymm11, ymm11) // ymm11 /= alpha33
#endif
vmovupd(ymm10, mem(rcx)) // store ( beta30..beta33 ) = ymm10
vmovupd(ymm11, mem(rdx)) // store ( beta34..beta37 ) = ymm11
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 3 -------------
vbroadcastsd(mem(2+5*6)*8(rax), ymm0) // ymm0 = alpha25
vbroadcastsd(mem(2+4*6)*8(rax), ymm1) // ymm1 = alpha24
vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha25 * ymm14
vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha25 * ymm15
vbroadcastsd(mem(2+3*6)*8(rax), ymm0) // ymm0 = alpha23
vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha24 * ymm12
vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha24 * ymm13
vbroadcastsd(mem(2+2*6)*8(rax), ymm1) // ymm1 = (1/alpha22)
vfmadd231pd(ymm0, ymm10, ymm2) // ymm2 += alpha23 * ymm10
vfmadd231pd(ymm0, ymm11, ymm3) // ymm3 += alpha23 * ymm11
vsubpd(ymm2, ymm8, ymm8) // ymm8 -= ymm2
vsubpd(ymm3, ymm9, ymm9) // ymm9 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulpd(ymm1, ymm8, ymm8) // ymm8 *= (1/alpha22)
vmulpd(ymm1, ymm9, ymm9) // ymm9 *= (1/alpha22)
#else
vdivpd(ymm1, ymm8, ymm8) // ymm8 /= alpha22
vdivpd(ymm1, ymm9, ymm9) // ymm9 /= alpha22
#endif
vmovupd(ymm8, mem(rcx)) // store ( beta20..beta23 ) = ymm8
vmovupd(ymm9, mem(rdx)) // store ( beta24..beta27 ) = ymm9
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 4 -------------
vbroadcastsd(mem(1+5*6)*8(rax), ymm0) // ymm0 = alpha15
vbroadcastsd(mem(1+4*6)*8(rax), ymm1) // ymm1 = alpha14
vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha15 * ymm14
vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha15 * ymm15
vbroadcastsd(mem(1+3*6)*8(rax), ymm0) // ymm0 = alpha13
vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha14 * ymm12
vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha14 * ymm13
vbroadcastsd(mem(1+2*6)*8(rax), ymm1) // ymm1 = alpha12
vfmadd231pd(ymm0, ymm10, ymm2) // ymm2 += alpha13 * ymm10
vfmadd231pd(ymm0, ymm11, ymm3) // ymm3 += alpha13 * ymm11
vbroadcastsd(mem(1+1*6)*8(rax), ymm0) // ymm0 = (1/alpha11)
vfmadd231pd(ymm1, ymm8, ymm2) // ymm2 += alpha12 * ymm8
vfmadd231pd(ymm1, ymm9, ymm3) // ymm3 += alpha12 * ymm9
vsubpd(ymm2, ymm6, ymm6) // ymm6 -= ymm2
vsubpd(ymm3, ymm7, ymm7) // ymm7 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulpd(ymm0, ymm6, ymm6) // ymm6 *= (1/alpha11)
vmulpd(ymm0, ymm7, ymm7) // ymm7 *= (1/alpha11)
#else
vdivpd(ymm0, ymm6, ymm6) // ymm6 /= alpha11
vdivpd(ymm0, ymm7, ymm7) // ymm7 /= alpha11
#endif
vmovupd(ymm6, mem(rcx)) // store ( beta10..beta13 ) = ymm6
vmovupd(ymm7, mem(rdx)) // store ( beta14..beta17 ) = ymm7
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
// iteration 5 -------------
vbroadcastsd(mem(0+5*6)*8(rax), ymm0) // ymm0 = alpha05
vbroadcastsd(mem(0+4*6)*8(rax), ymm1) // ymm1 = alpha04
vmulpd(ymm0, ymm14, ymm2) // ymm2 = alpha05 * ymm14
vmulpd(ymm0, ymm15, ymm3) // ymm3 = alpha05 * ymm15
vbroadcastsd(mem(0+3*6)*8(rax), ymm0) // ymm0 = alpha03
vfmadd231pd(ymm1, ymm12, ymm2) // ymm2 += alpha04 * ymm12
vfmadd231pd(ymm1, ymm13, ymm3) // ymm3 += alpha04 * ymm13
vbroadcastsd(mem(0+2*6)*8(rax), ymm1) // ymm1 = alpha02
vfmadd231pd(ymm0, ymm10, ymm2) // ymm2 += alpha03 * ymm10
vfmadd231pd(ymm0, ymm11, ymm3) // ymm3 += alpha03 * ymm11
vbroadcastsd(mem(0+1*6)*8(rax), ymm0) // ymm0 = alpha01
vfmadd231pd(ymm1, ymm8, ymm2) // ymm2 += alpha02 * ymm8
vfmadd231pd(ymm1, ymm9, ymm3) // ymm3 += alpha02 * ymm9
vbroadcastsd(mem(0+0*6)*8(rax), ymm1) // ymm1 = (1/alpha00)
vfmadd231pd(ymm0, ymm6, ymm2) // ymm2 += alpha01 * ymm6
vfmadd231pd(ymm0, ymm7, ymm3) // ymm3 += alpha01 * ymm7
vsubpd(ymm2, ymm4, ymm4) // ymm4 -= ymm2
vsubpd(ymm3, ymm5, ymm5) // ymm5 -= ymm3
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
vmulpd(ymm1, ymm4, ymm4) // ymm4 *= (1/alpha00)
vmulpd(ymm1, ymm5, ymm5) // ymm5 *= (1/alpha00)
#else
vdivpd(ymm1, ymm4, ymm4) // ymm4 /= alpha00
vdivpd(ymm1, ymm5, ymm5) // ymm5 /= alpha00
#endif
vmovupd(ymm4, mem(rcx)) // store ( beta00..beta03 ) = ymm4
vmovupd(ymm5, mem(rdx)) // store ( beta04..beta07 ) = ymm5
sub(rdi, rcx) // rcx -= rs_b
sub(rdi, rdx) // rdx -= rs_b
mov(r8, rcx) // load address of c11 from r8
mov(r9, rdi) // load rs_c (in bytes) from r9
mov(r10, rsi) // load cs_c (in bytes) from r10
lea(mem(rcx, rsi, 4), rdx) // load address of c11 + 4*cs_c;
lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c;
// These are used in the macros below.
lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c;
//lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c;
//lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c;
cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
jz(.DROWSTORED) // jump to row storage case
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
jz(.DCOLSTORED) // jump to column storage case
// if neither row- or column-
// stored, use general case.
label(.DGENSTORED)
vmovapd(ymm4, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm6, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm8, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm10, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm12, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm14, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
mov(rdx, rcx) // rcx = c11 + 4*cs_c
vmovapd(ymm5, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm7, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm9, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm11, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm13, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c11 += rs_c;
vmovapd(ymm15, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
jmp(.DDONE)
label(.DROWSTORED)
vmovupd(ymm4, mem(rcx))
add(rdi, rcx)
vmovupd(ymm5, mem(rdx))
add(rdi, rdx)
vmovupd(ymm6, mem(rcx))
add(rdi, rcx)
vmovupd(ymm7, mem(rdx))
add(rdi, rdx)
vmovupd(ymm8, mem(rcx))
add(rdi, rcx)
vmovupd(ymm9, mem(rdx))
add(rdi, rdx)
vmovupd(ymm10, mem(rcx))
add(rdi, rcx)
vmovupd(ymm11, mem(rdx))
add(rdi, rdx)
vmovupd(ymm12, mem(rcx))
add(rdi, rcx)
vmovupd(ymm13, mem(rdx))
add(rdi, rdx)
vmovupd(ymm14, mem(rcx))
//add(rdi, rcx)
vmovupd(ymm15, mem(rdx))
//add(rdi, rdx)
jmp(.DDONE)
label(.DCOLSTORED)
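// In-register 4x4 transposes for column storage: unpck/insertf128/perm2f128
// rearrange rows 0-3 (ymm4/6/8/10, then ymm5/7/9/11) into full columns
// stored with vmovupd, while rows 4-5 (ymm12/14, then ymm13/15) are
// interleaved and stored as two-double xmm pairs.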
vunpcklpd(ymm6, ymm4, ymm0)
vunpckhpd(ymm6, ymm4, ymm1)
vunpcklpd(ymm10, ymm8, ymm2)
vunpckhpd(ymm10, ymm8, ymm3)
vinsertf128(imm(0x1), xmm2, ymm0, ymm4)
vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
vmovupd(ymm4, mem(rcx))
vmovupd(ymm6, mem(rcx, rsi, 1))
vmovupd(ymm8, mem(rcx, rsi, 2))
vmovupd(ymm10, mem(rcx, r13, 1))
lea(mem(rcx, rsi, 4), rcx)
vunpcklpd(ymm14, ymm12, ymm0)
vunpckhpd(ymm14, ymm12, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovupd(xmm0, mem(r14))
vmovupd(xmm1, mem(r14, rsi, 1))
vmovupd(xmm2, mem(r14, rsi, 2))
vmovupd(xmm3, mem(r14, r13, 1))
lea(mem(r14, rsi, 4), r14)
vunpcklpd(ymm7, ymm5, ymm0)
vunpckhpd(ymm7, ymm5, ymm1)
vunpcklpd(ymm11, ymm9, ymm2)
vunpckhpd(ymm11, ymm9, ymm3)
vinsertf128(imm(0x1), xmm2, ymm0, ymm5)
vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
vmovupd(ymm5, mem(rcx))
vmovupd(ymm7, mem(rcx, rsi, 1))
vmovupd(ymm9, mem(rcx, rsi, 2))
vmovupd(ymm11, mem(rcx, r13, 1))
//lea(mem(rcx, rsi, 4), rcx)
vunpcklpd(ymm15, ymm13, ymm0)
vunpckhpd(ymm15, ymm13, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm3)
vmovupd(xmm0, mem(r14))
vmovupd(xmm1, mem(r14, rsi, 1))
vmovupd(xmm2, mem(r14, rsi, 2))
vmovupd(xmm3, mem(r14, r13, 1))
//lea(mem(r14, rsi, 4), r14)
label(.DDONE)
vzeroupper()
end_asm(
: // output operands (none)
: // input operands
[k_iter] "m" (k_iter), // 0
[k_left] "m" (k_left), // 1
[a10] "m" (a10), // 2
[b01] "m" (b01), // 3
[beta] "m" (beta), // 4
[alpha] "m" (alpha), // 5
[a11] "m" (a11), // 6
[b11] "m" (b11), // 7
[c11] "m" (c11), // 8
[rs_c] "m" (rs_c), // 9
[cs_c] "m" (cs_c) // 10
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
"xmm8", "xmm9", "xmm10", "xmm11",
"xmm12", "xmm13", "xmm14", "xmm15",
"memory"
)
}