mirror of
https://github.com/amd/blis.git
synced 2026-05-11 01:30:00 +00:00
Details:
- Removed explicit reference to The University of Texas at Austin in the
third clause of the license comment blocks of all relevant files and
replaced it with a more all-encompassing "copyright holder(s)".
- Removed duplicate words ("derived") from a few kernels' license
comment blocks.
- Homogenized license comment block in kernels/zen/3/bli_gemm_small.c
with format of all other comment blocks.
1495 lines
38 KiB
C
1495 lines
38 KiB
C
/*
|
|
|
|
BLIS
|
|
An object-based framework for developing high-performance BLAS-like
|
|
libraries.
|
|
|
|
Copyright (C) 2014, The University of Texas at Austin
|
|
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are
|
|
met:
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
- Neither the name(s) of the copyright holder(s) nor the names of its
|
|
contributors may be used to endorse or promote products derived
|
|
from this software without specific prior written permission.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*/
|
|
|
|
#include "blis.h"
|
|
|
|
#define BLIS_ASM_SYNTAX_ATT
|
|
#include "bli_x86_asm_macros.h"
|
|
|
|
void bli_sgemm_penryn_asm_8x4
|
|
(
|
|
dim_t k0,
|
|
float* restrict alpha,
|
|
float* restrict a,
|
|
float* restrict b,
|
|
float* restrict beta,
|
|
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
|
auxinfo_t* restrict data,
|
|
cntx_t* restrict cntx
|
|
)
|
|
{
|
|
//void* a_next = bli_auxinfo_next_a( data );
|
|
void* b_next = bli_auxinfo_next_b( data );
|
|
|
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
|
// different size than is expected by load instructions.
|
|
uint64_t k_iter = k0 / 4;
|
|
uint64_t k_left = k0 % 4;
|
|
uint64_t rs_c = rs_c0;
|
|
uint64_t cs_c = cs_c0;
|
|
|
|
begin_asm()
|
|
|
|
|
|
mov(var(a), rax) // load address of a.
|
|
mov(var(b), rbx) // load address of b.
|
|
mov(var(b_next), r9) // load address of b_next.
|
|
|
|
sub(imm(0-8*16), rax) // increment pointers to allow byte
|
|
sub(imm(0-8*16), rbx) // offsets in the unrolled iterations.
|
|
|
|
movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements
|
|
movaps(mem(rax, -7*16), xmm1) // of a and b.
|
|
movaps(mem(rbx, -8*16), xmm2)
|
|
|
|
mov(var(c), rcx) // load address of c
|
|
mov(var(cs_c), rdi) // load cs_c
|
|
lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float)
|
|
mov(rdi, r12) // make a copy of cs_c (in bytes)
|
|
lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c;
|
|
|
|
prefetch(2, mem(r9, 0*4)) // prefetch b_next
|
|
|
|
xorps(xmm3, xmm3)
|
|
xorps(xmm4, xmm4)
|
|
xorps(xmm5, xmm5)
|
|
xorps(xmm6, xmm6)
|
|
|
|
prefetch(2, mem(rcx, 6*4)) // prefetch c + 0*cs_c
|
|
xorps(xmm8, xmm8)
|
|
xorps(xmm9, xmm9)
|
|
prefetch(2, mem(rcx, rdi, 1, 6*4)) // prefetch c + 1*cs_c
|
|
xorps(xmm10, xmm10)
|
|
xorps(xmm11, xmm11)
|
|
prefetch(2, mem(r10, 6*4)) // prefetch c + 2*cs_c
|
|
xorps(xmm12, xmm12)
|
|
xorps(xmm13, xmm13)
|
|
prefetch(2, mem(r10, rdi, 1, 6*4)) // prefetch c + 3*cs_c
|
|
xorps(xmm14, xmm14)
|
|
xorps(xmm15, xmm15)
|
|
|
|
|
|
|
|
mov(var(k_iter), rsi) // i = k_iter;
|
|
test(rsi, rsi) // check i via logical AND.
|
|
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
|
// contains the k_left loop.
|
|
|
|
|
|
label(.SLOOPKITER) // MAIN LOOP
|
|
|
|
prefetch(0, mem(rax, (4*35+1)*8))
|
|
|
|
addps(xmm6, xmm10) // iteration 0
|
|
addps(xmm3, xmm14)
|
|
movaps(xmm2, xmm3)
|
|
pshufd(imm(0x39), xmm2, xmm7)
|
|
mulps(xmm0, xmm2)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm4, xmm11)
|
|
addps(xmm5, xmm15)
|
|
movaps(xmm7, xmm5)
|
|
pshufd(imm(0x39), xmm7, xmm6)
|
|
mulps(xmm0, xmm7)
|
|
mulps(xmm1, xmm5)
|
|
|
|
addps(xmm2, xmm8)
|
|
movaps(mem(rbx, -7*16), xmm2)
|
|
addps(xmm3, xmm12)
|
|
movaps(xmm6, xmm3)
|
|
pshufd(imm(0x39), xmm6, xmm4)
|
|
mulps(xmm0, xmm6)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm7, xmm9)
|
|
addps(xmm5, xmm13)
|
|
movaps(xmm4, xmm5)
|
|
mulps(xmm0, xmm4)
|
|
movaps(mem(rax, -6*16), xmm0)
|
|
mulps(xmm1, xmm5)
|
|
movaps(mem(rax, -5*16), xmm1)
|
|
|
|
|
|
addps(xmm6, xmm10) // iteration 1
|
|
addps(xmm3, xmm14)
|
|
movaps(xmm2, xmm3)
|
|
pshufd(imm(0x39), xmm2, xmm7)
|
|
mulps(xmm0, xmm2)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm4, xmm11)
|
|
addps(xmm5, xmm15)
|
|
movaps(xmm7, xmm5)
|
|
pshufd(imm(0x39), xmm7, xmm6)
|
|
mulps(xmm0, xmm7)
|
|
mulps(xmm1, xmm5)
|
|
|
|
addps(xmm2, xmm8)
|
|
movaps(mem(rbx, -6*16), xmm2)
|
|
addps(xmm3, xmm12)
|
|
movaps(xmm6, xmm3)
|
|
pshufd(imm(0x39), xmm6, xmm4)
|
|
mulps(xmm0, xmm6)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm7, xmm9)
|
|
addps(xmm5, xmm13)
|
|
movaps(xmm4, xmm5)
|
|
mulps(xmm0, xmm4)
|
|
movaps(mem(rax, -4*16), xmm0)
|
|
mulps(xmm1, xmm5)
|
|
movaps(mem(rax, -3*16), xmm1)
|
|
|
|
|
|
addps(xmm6, xmm10) // iteration 2
|
|
addps(xmm3, xmm14)
|
|
movaps(xmm2, xmm3)
|
|
pshufd(imm(0x39), xmm2, xmm7)
|
|
mulps(xmm0, xmm2)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm4, xmm11)
|
|
addps(xmm5, xmm15)
|
|
movaps(xmm7, xmm5)
|
|
pshufd(imm(0x39), xmm7, xmm6)
|
|
mulps(xmm0, xmm7)
|
|
mulps(xmm1, xmm5)
|
|
|
|
addps(xmm2, xmm8)
|
|
movaps(mem(rbx, -5*16), xmm2)
|
|
addps(xmm3, xmm12)
|
|
movaps(xmm6, xmm3)
|
|
pshufd(imm(0x39), xmm6, xmm4)
|
|
mulps(xmm0, xmm6)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm7, xmm9)
|
|
addps(xmm5, xmm13)
|
|
movaps(xmm4, xmm5)
|
|
mulps(xmm0, xmm4)
|
|
movaps(mem(rax, -2*16), xmm0)
|
|
mulps(xmm1, xmm5)
|
|
movaps(mem(rax, -1*16), xmm1)
|
|
|
|
|
|
addps(xmm6, xmm10) // iteration 3
|
|
addps(xmm3, xmm14)
|
|
movaps(xmm2, xmm3)
|
|
pshufd(imm(0x39), xmm2, xmm7)
|
|
mulps(xmm0, xmm2)
|
|
mulps(xmm1, xmm3)
|
|
|
|
sub(imm(0-4*8*4), rax) // a += 4*8 (unroll x mr)
|
|
|
|
addps(xmm4, xmm11)
|
|
addps(xmm5, xmm15)
|
|
movaps(xmm7, xmm5)
|
|
pshufd(imm(0x39), xmm7, xmm6)
|
|
mulps(xmm0, xmm7)
|
|
mulps(xmm1, xmm5)
|
|
|
|
sub(imm(0-4*4*4), r9) // b_next += 4*4 (unroll x nr)
|
|
|
|
addps(xmm2, xmm8)
|
|
movaps(mem(rbx, -4*16), xmm2)
|
|
addps(xmm3, xmm12)
|
|
movaps(xmm6, xmm3)
|
|
pshufd(imm(0x39), xmm6, xmm4)
|
|
mulps(xmm0, xmm6)
|
|
mulps(xmm1, xmm3)
|
|
|
|
sub(imm(0-4*4*4), rbx) // b += 4*4 (unroll x nr)
|
|
|
|
addps(xmm7, xmm9)
|
|
addps(xmm5, xmm13)
|
|
movaps(xmm4, xmm5)
|
|
mulps(xmm0, xmm4)
|
|
movaps(mem(rax, -8*16), xmm0)
|
|
mulps(xmm1, xmm5)
|
|
movaps(mem(rax, -7*16), xmm1)
|
|
|
|
prefetch(2, mem(r9, 0*4)) // prefetch b_next[0]
|
|
prefetch(2, mem(r9, 16*4)) // prefetch b_next[16]
|
|
|
|
|
|
dec(rsi) // i -= 1;
|
|
jne(.SLOOPKITER) // iterate again if i != 0.
|
|
|
|
|
|
|
|
label(.SCONSIDKLEFT)
|
|
|
|
mov(var(k_left), rsi) // i = k_left;
|
|
test(rsi, rsi) // check i via logical AND.
|
|
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
|
// else, we prepare to enter k_left loop.
|
|
|
|
|
|
label(.SLOOPKLEFT) // EDGE LOOP
|
|
|
|
addps(xmm6, xmm10) // iteration 0
|
|
addps(xmm3, xmm14)
|
|
movaps(xmm2, xmm3)
|
|
pshufd(imm(0x39), xmm2, xmm7)
|
|
mulps(xmm0, xmm2)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm4, xmm11)
|
|
addps(xmm5, xmm15)
|
|
movaps(xmm7, xmm5)
|
|
pshufd(imm(0x39), xmm7, xmm6)
|
|
mulps(xmm0, xmm7)
|
|
mulps(xmm1, xmm5)
|
|
|
|
addps(xmm2, xmm8)
|
|
movaps(mem(rbx, -7*16), xmm2)
|
|
addps(xmm3, xmm12)
|
|
movaps(xmm6, xmm3)
|
|
pshufd(imm(0x39), xmm6, xmm4)
|
|
mulps(xmm0, xmm6)
|
|
mulps(xmm1, xmm3)
|
|
|
|
addps(xmm7, xmm9)
|
|
addps(xmm5, xmm13)
|
|
movaps(xmm4, xmm5)
|
|
mulps(xmm0, xmm4)
|
|
movaps(mem(rax, -6*16), xmm0)
|
|
mulps(xmm1, xmm5)
|
|
movaps(mem(rax, -5*16), xmm1)
|
|
|
|
sub(imm(0-1*8*4), rax) // a += 8 (1 x mr)
|
|
sub(imm(0-1*4*4), rbx) // b += 4 (1 x nr)
|
|
|
|
|
|
dec(rsi) // i -= 1;
|
|
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
|
|
|
|
|
|
|
label(.SPOSTACCUM)
|
|
|
|
addps(xmm6, xmm10)
|
|
addps(xmm3, xmm14)
|
|
addps(xmm4, xmm11)
|
|
addps(xmm5, xmm15)
|
|
|
|
|
|
mov(var(alpha), rax) // load address of alpha
|
|
mov(var(beta), rbx) // load address of beta
|
|
movss(mem(rax), xmm6) // load alpha to bottom 4 bytes of xmm6
|
|
movss(mem(rbx), xmm7) // load beta to bottom 4 bytes of xmm7
|
|
pshufd(imm(0x00), xmm6, xmm6) // populate xmm6 with four alphas
|
|
pshufd(imm(0x00), xmm7, xmm7) // populate xmm7 with four betas
|
|
|
|
|
|
mov(var(rs_c), rsi) // load rs_c
|
|
mov(rsi, r8) // make a copy of rs_c
|
|
|
|
lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float)
|
|
lea(mem(rsi, rsi, 2), r11) // r11 = 3*(rs_c * sizeof(float))
|
|
|
|
lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c;
|
|
|
|
// xmm8: xmm9: xmm10: xmm11:
|
|
// ( ab00 ( ab01 ( ab02 ( ab03
|
|
// ab11 ab12 ab13 ab10
|
|
// ab22 ab23 ab20 ab21
|
|
// ab33 ) ab30 ) ab31 ) ab32 )
|
|
//
|
|
// xmm12: xmm13: xmm14: xmm15:
|
|
// ( ab40 ( ab41 ( ab42 ( ab43
|
|
// ab51 ab52 ab53 ab50
|
|
// ab62 ab63 ab60 ab61
|
|
// ab73 ) ab70 ) ab71 ) ab72 )
|
|
movaps(xmm9, xmm4)
|
|
shufps(imm(0xd8), xmm8, xmm9)
|
|
shufps(imm(0xd8), xmm11, xmm8)
|
|
shufps(imm(0xd8), xmm10, xmm11)
|
|
shufps(imm(0xd8), xmm4, xmm10)
|
|
|
|
movaps(xmm8, xmm4)
|
|
shufps(imm(0xd8), xmm10, xmm8)
|
|
shufps(imm(0xd8), xmm4, xmm10)
|
|
movaps(xmm9, xmm5)
|
|
shufps(imm(0xd8), xmm11, xmm9)
|
|
shufps(imm(0xd8), xmm5, xmm11)
|
|
|
|
movaps(xmm13, xmm4)
|
|
shufps(imm(0xd8), xmm12, xmm13)
|
|
shufps(imm(0xd8), xmm15, xmm12)
|
|
shufps(imm(0xd8), xmm14, xmm15)
|
|
shufps(imm(0xd8), xmm4, xmm14)
|
|
|
|
movaps(xmm12, xmm4)
|
|
shufps(imm(0xd8), xmm14, xmm12)
|
|
shufps(imm(0xd8), xmm4, xmm14)
|
|
movaps(xmm13, xmm5)
|
|
shufps(imm(0xd8), xmm15, xmm13)
|
|
shufps(imm(0xd8), xmm5, xmm15)
|
|
// xmm8: xmm9: xmm10: xmm11:
|
|
// ( ab00 ( ab01 ( ab02 ( ab03
|
|
// ab10 ab11 ab12 ab13
|
|
// ab20 ab21 ab22 ab23
|
|
// ab30 ) ab31 ) ab32 ) ab33 )
|
|
//
|
|
// xmm12: xmm13: xmm14: xmm15:
|
|
// ( ab40 ( ab41 ( ab42 ( ab43
|
|
// ab50 ab51 ab52 ab53
|
|
// ab60 ab61 ab62 ab63
|
|
// ab70 ) ab71 ) ab72 ) ab73 )
|
|
|
|
|
|
|
|
// determine if
|
|
// c % 16 == 0, AND
|
|
// 8*cs_c % 16 == 0, AND
|
|
// rs_c == 1
|
|
// ie: aligned, ldim aligned, and
|
|
// column-stored
|
|
|
|
cmp(imm(1), r8) // set ZF if rs_c == 1.
|
|
sete(bl) // bl = ( ZF == 1 ? 1 : 0 );
|
|
test(imm(15), rcx) // set ZF if c & 16 is zero.
|
|
setz(bh) // bh = ( ZF == 1 ? 1 : 0 );
|
|
test(imm(15), r12) // set ZF if (4*cs_c) & 16 is zero.
|
|
setz(al) // al = ( ZF == 1 ? 1 : 0 );
|
|
// and(bl,bh) followed by
|
|
// and(bh,al) will reveal result
|
|
|
|
// now avoid loading C if beta == 0
|
|
|
|
xorpd(xmm0, xmm0) // set xmm0 to zero.
|
|
ucomisd(xmm0, xmm7) // check if beta == 0.
|
|
je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
|
|
|
|
|
// check if aligned/column-stored
|
|
and(bl, bh) // set ZF if bl & bh == 1.
|
|
and(bh, al) // set ZF if bh & al == 1.
|
|
jne(.SCOLSTORED) // jump to column storage case
|
|
|
|
|
|
|
|
label(.SGENSTORED)
|
|
|
|
movlps(mem(rcx), xmm0) // load c00 ~ c30
|
|
movhps(mem(rcx, rsi, 1), xmm0)
|
|
movlps(mem(rcx, rsi, 2), xmm1)
|
|
movhps(mem(rcx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm8) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm8, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
add(rdi, rcx)
|
|
|
|
|
|
movlps(mem(rdx), xmm0) // load c40 ~ c70
|
|
movhps(mem(rdx, rsi, 1), xmm0)
|
|
movlps(mem(rdx, rsi, 2), xmm1)
|
|
movhps(mem(rdx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm12) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm12, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
add(rdi, rdx)
|
|
|
|
|
|
movlps(mem(rcx), xmm0) // load c01 ~ c31
|
|
movhps(mem(rcx, rsi, 1), xmm0)
|
|
movlps(mem(rcx, rsi, 2), xmm1)
|
|
movhps(mem(rcx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm9) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm9, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
add(rdi, rcx)
|
|
|
|
|
|
movlps(mem(rdx), xmm0) // load c41 ~ c71
|
|
movhps(mem(rdx, rsi, 1), xmm0)
|
|
movlps(mem(rdx, rsi, 2), xmm1)
|
|
movhps(mem(rdx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm13) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm13, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
add(rdi, rdx)
|
|
|
|
|
|
movlps(mem(rcx), xmm0) // load c02 ~ c32
|
|
movhps(mem(rcx, rsi, 1), xmm0)
|
|
movlps(mem(rcx, rsi, 2), xmm1)
|
|
movhps(mem(rcx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm10) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm10, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
add(rdi, rcx)
|
|
|
|
|
|
movlps(mem(rdx), xmm0) // load c42 ~ c72
|
|
movhps(mem(rdx, rsi, 1), xmm0)
|
|
movlps(mem(rdx, rsi, 2), xmm1)
|
|
movhps(mem(rdx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm14) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm14, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
add(rdi, rdx)
|
|
|
|
|
|
movlps(mem(rcx), xmm0) // load c03 ~ c33
|
|
movhps(mem(rcx, rsi, 1), xmm0)
|
|
movlps(mem(rcx, rsi, 2), xmm1)
|
|
movhps(mem(rcx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm11) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm11, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
|
|
|
|
|
|
movlps(mem(rdx), xmm0) // load c43 ~ c73
|
|
movhps(mem(rdx, rsi, 1), xmm0)
|
|
movlps(mem(rdx, rsi, 2), xmm1)
|
|
movhps(mem(rdx, r11, 1), xmm1)
|
|
shufps(imm(0x88), xmm1, xmm0)
|
|
|
|
mulps(xmm6, xmm15) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm15, xmm0) // add the gemm result,
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
|
|
|
|
|
|
jmp(.SDONE) // jump to end.
|
|
|
|
|
|
|
|
label(.SCOLSTORED)
|
|
|
|
movaps(mem(rcx), xmm0) // load c00 ~ c30,
|
|
mulps(xmm6, xmm8) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm8, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
|
|
movaps(mem(rdx), xmm1) // load c40 ~ c70,
|
|
mulps(xmm6, xmm12) // scale by alpha,
|
|
mulps(xmm7, xmm1) // scale by beta,
|
|
addps(xmm12, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movaps(mem(rcx), xmm0) // load c01 ~ c31,
|
|
mulps(xmm6, xmm9) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm9, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
|
|
movaps(mem(rdx), xmm1) // load c41 ~ c71,
|
|
mulps(xmm6, xmm13) // scale by alpha,
|
|
mulps(xmm7, xmm1) // scale by beta,
|
|
addps(xmm13, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movaps(mem(rcx), xmm0) // load c02 ~ c32,
|
|
mulps(xmm6, xmm10) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm10, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
|
|
movaps(mem(rdx), xmm1) // load c42 ~ c72,
|
|
mulps(xmm6, xmm14) // scale by alpha,
|
|
mulps(xmm7, xmm1) // scale by beta,
|
|
addps(xmm14, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movaps(mem(rcx), xmm0) // load c03 ~ c33,
|
|
mulps(xmm6, xmm11) // scale by alpha,
|
|
mulps(xmm7, xmm0) // scale by beta,
|
|
addps(xmm11, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
|
|
|
|
movaps(mem(rdx), xmm1) // load c43 ~ c73,
|
|
mulps(xmm6, xmm15) // scale by alpha,
|
|
mulps(xmm7, xmm1) // scale by beta,
|
|
addps(xmm15, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
|
|
jmp(.SDONE) // jump to end.
|
|
|
|
|
|
|
|
|
|
label(.SBETAZERO)
|
|
// check if aligned/column-stored
|
|
and(bl, bh) // set ZF if bl & bh == 1.
|
|
and(bh, al) // set ZF if bh & al == 1.
|
|
jne(.SCOLSTORBZ) // jump to column storage case
|
|
|
|
|
|
|
|
label(.SGENSTORBZ)
|
|
|
|
mulps(xmm6, xmm8) // scale by alpha,
|
|
movaps(xmm8, xmm0)
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
add(rdi, rcx)
|
|
|
|
|
|
mulps(xmm6, xmm12) // scale by alpha,
|
|
movaps(xmm12, xmm0)
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
add(rdi, rdx)
|
|
|
|
|
|
mulps(xmm6, xmm9) // scale by alpha,
|
|
movaps(xmm9, xmm0)
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
add(rdi, rcx)
|
|
|
|
|
|
mulps(xmm6, xmm13) // scale by alpha,
|
|
movaps(xmm13, xmm0)
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
add(rdi, rdx)
|
|
|
|
|
|
mulps(xmm6, xmm10) // scale by alpha,
|
|
movaps(xmm10, xmm0)
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
add(rdi, rcx)
|
|
|
|
|
|
mulps(xmm6, xmm14) // scale by alpha,
|
|
movaps(xmm14, xmm0)
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
add(rdi, rdx)
|
|
|
|
|
|
mulps(xmm6, xmm11) // scale by alpha,
|
|
movaps(xmm11, xmm0)
|
|
|
|
movss(xmm0, mem(rcx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rcx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rcx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rcx, r11, 1))
|
|
|
|
|
|
|
|
|
|
mulps(xmm6, xmm15) // scale by alpha,
|
|
movaps(xmm15, xmm0)
|
|
|
|
movss(xmm0, mem(rdx)) // and store back to memory.
|
|
pshufd(imm(0x39), xmm0, xmm1)
|
|
movss(xmm1, mem(rdx, rsi, 1))
|
|
pshufd(imm(0x39), xmm1, xmm2)
|
|
movss(xmm2, mem(rdx, rsi, 2))
|
|
pshufd(imm(0x39), xmm2, xmm3)
|
|
movss(xmm3, mem(rdx, r11, 1))
|
|
|
|
|
|
|
|
|
|
jmp(.SDONE) // jump to end.
|
|
|
|
|
|
|
|
label(.SCOLSTORBZ)
|
|
|
|
// skip loading c00 ~ c30,
|
|
mulps(xmm6, xmm8) // scale by alpha,
|
|
movaps(xmm8, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
// skip loading c40 ~ c70,
|
|
mulps(xmm6, xmm12) // scale by alpha,
|
|
movaps(xmm12, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c01 ~ c31,
|
|
mulps(xmm6, xmm9) // scale by alpha,
|
|
movaps(xmm9, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
// skip loading c41 ~ c71,
|
|
mulps(xmm6, xmm13) // scale by alpha,
|
|
movaps(xmm13, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c02 ~ c32,
|
|
mulps(xmm6, xmm10) // scale by alpha,
|
|
movaps(xmm10, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
// skip loading c42 ~ c72,
|
|
mulps(xmm6, xmm14) // scale by alpha,
|
|
movaps(xmm14, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c03 ~ c33,
|
|
mulps(xmm6, xmm11) // scale by alpha,
|
|
movaps(xmm11, mem(rcx)) // and store back to memory.
|
|
|
|
// skip loading c43 ~ c73,
|
|
mulps(xmm6, xmm15) // scale by alpha,
|
|
movaps(xmm15, mem(rdx)) // and store back to memory.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
label(.SDONE)
|
|
|
|
|
|
end_asm(
|
|
: // output operands (none)
|
|
: // input operands
|
|
[k_iter] "m" (k_iter), // 0
|
|
[k_left] "m" (k_left), // 1
|
|
[a] "m" (a), // 2
|
|
[b] "m" (b), // 3
|
|
[alpha] "m" (alpha), // 4
|
|
[beta] "m" (beta), // 5
|
|
[c] "m" (c), // 6
|
|
[rs_c] "m" (rs_c), // 7
|
|
[cs_c] "m" (cs_c), // 8
|
|
[b_next] "m" (b_next)/*, // 9
|
|
[a_next] "m" (a_next)*/ // 10
|
|
: // register clobber list
|
|
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12",
|
|
"xmm0", "xmm1", "xmm2", "xmm3",
|
|
"xmm4", "xmm5", "xmm6", "xmm7",
|
|
"xmm8", "xmm9", "xmm10", "xmm11",
|
|
"xmm12", "xmm13", "xmm14", "xmm15",
|
|
"memory"
|
|
)
|
|
}
|
|
|
|
void bli_dgemm_penryn_asm_4x4
|
|
(
|
|
dim_t k0,
|
|
double* restrict alpha,
|
|
double* restrict a,
|
|
double* restrict b,
|
|
double* restrict beta,
|
|
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
|
auxinfo_t* restrict data,
|
|
cntx_t* restrict cntx
|
|
)
|
|
{
|
|
void* a_next = bli_auxinfo_next_a( data );
|
|
void* b_next = bli_auxinfo_next_b( data );
|
|
|
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
|
// different size than is expected by load instructions.
|
|
uint64_t k_iter = k0 / 4;
|
|
uint64_t k_left = k0 % 4;
|
|
uint64_t rs_c = rs_c0;
|
|
uint64_t cs_c = cs_c0;
|
|
|
|
begin_asm()
|
|
|
|
|
|
mov(var(a), rax) // load address of a.
|
|
mov(var(b), rbx) // load address of b.
|
|
mov(var(b_next), r9) // load address of b_next.
|
|
mov(var(a_next), r11) // load address of a_next.
|
|
|
|
sub(imm(0-8*16), rax) // increment pointers to allow byte
|
|
sub(imm(0-8*16), rbx) // offsets in the unrolled iterations.
|
|
|
|
movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements
|
|
movaps(mem(rax, -7*16), xmm1) // of a and b.
|
|
movaps(mem(rbx, -8*16), xmm2)
|
|
|
|
mov(var(c), rcx) // load address of c
|
|
mov(var(cs_c), rdi) // load cs_c
|
|
lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double)
|
|
mov(rdi, r12) // make a copy of cs_c (in bytes)
|
|
lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c;
|
|
|
|
prefetch(2, mem(r9, 0*8)) // prefetch b_next
|
|
|
|
xorpd(xmm3, xmm3)
|
|
xorpd(xmm4, xmm4)
|
|
xorpd(xmm5, xmm5)
|
|
xorpd(xmm6, xmm6)
|
|
|
|
prefetch(2, mem(rcx, 3*8)) // prefetch c + 0*cs_c
|
|
xorpd(xmm8, xmm8)
|
|
xorpd(xmm9, xmm9)
|
|
prefetch(2, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c
|
|
xorpd(xmm10, xmm10)
|
|
xorpd(xmm11, xmm11)
|
|
prefetch(2, mem(r10, 3*8)) // prefetch c + 2*cs_c
|
|
xorpd(xmm12, xmm12)
|
|
xorpd(xmm13, xmm13)
|
|
prefetch(2, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c
|
|
xorpd(xmm14, xmm14)
|
|
xorpd(xmm15, xmm15)
|
|
|
|
|
|
|
|
mov(var(k_iter), rsi) // i = k_iter;
|
|
test(rsi, rsi) // check i via logical AND.
|
|
je(.DCONSIDKLEFT) // if i == 0, jump to code that
|
|
// contains the k_left loop.
|
|
|
|
|
|
label(.DLOOPKITER) // MAIN LOOP
|
|
|
|
prefetch(0, mem(rax, (4*35+1)*8))
|
|
//prefetch(0, mem(rax, (8*97+4)*8))
|
|
|
|
//prefetch(0, mem(r11, 67*4*8)) // prefetch a_next[0]
|
|
|
|
addpd(xmm3, xmm11) // iteration 0
|
|
movaps(mem(rbx, -7*16), xmm3)
|
|
addpd(xmm4, xmm15)
|
|
movaps(xmm2, xmm4)
|
|
pshufd(imm(0x4e), xmm2, xmm7)
|
|
mulpd(xmm0, xmm2)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm5, xmm10)
|
|
addpd(xmm6, xmm14)
|
|
movaps(xmm7, xmm6)
|
|
mulpd(xmm0, xmm7)
|
|
mulpd(xmm1, xmm6)
|
|
|
|
addpd(xmm2, xmm9)
|
|
movaps(mem(rbx, -6*16), xmm2)
|
|
addpd(xmm4, xmm13)
|
|
movaps(xmm3, xmm4)
|
|
pshufd(imm(0x4e), xmm3, xmm5)
|
|
mulpd(xmm0, xmm3)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm7, xmm8)
|
|
addpd(xmm6, xmm12)
|
|
movaps(xmm5, xmm6)
|
|
mulpd(xmm0, xmm5)
|
|
movaps(mem(rax, -6*16), xmm0)
|
|
mulpd(xmm1, xmm6)
|
|
movaps(mem(rax, -5*16), xmm1)
|
|
|
|
|
|
|
|
addpd(xmm3, xmm11) // iteration 1
|
|
movaps(mem(rbx, -5*16), xmm3)
|
|
addpd(xmm4, xmm15)
|
|
movaps(xmm2, xmm4)
|
|
pshufd(imm(0x4e), xmm2, xmm7)
|
|
mulpd(xmm0, xmm2)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm5, xmm10)
|
|
addpd(xmm6, xmm14)
|
|
movaps(xmm7, xmm6)
|
|
mulpd(xmm0, xmm7)
|
|
mulpd(xmm1, xmm6)
|
|
|
|
addpd(xmm2, xmm9)
|
|
movaps(mem(rbx, -4*16), xmm2)
|
|
addpd(xmm4, xmm13)
|
|
movaps(xmm3, xmm4)
|
|
pshufd(imm(0x4e), xmm3, xmm5)
|
|
mulpd(xmm0, xmm3)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm7, xmm8)
|
|
addpd(xmm6, xmm12)
|
|
movaps(xmm5, xmm6)
|
|
mulpd(xmm0, xmm5)
|
|
movaps(mem(rax, -4*16), xmm0)
|
|
mulpd(xmm1, xmm6)
|
|
movaps(mem(rax, -3*16), xmm1)
|
|
|
|
|
|
prefetch(0, mem(rax, (4*37+1)*8))
|
|
//prefetch(0, mem(rax, (8*97+12)*8))
|
|
|
|
//prefetch(0, mem(r11, 69*4*8)) // prefetch a_next[8]
|
|
//sub(imm(-4*4*8), r11) // a_next += 4*4 (unroll x mr)
|
|
|
|
|
|
|
|
addpd(xmm3, xmm11) // iteration 2
|
|
movaps(mem(rbx, -3*16), xmm3)
|
|
addpd(xmm4, xmm15)
|
|
movaps(xmm2, xmm4)
|
|
pshufd(imm(0x4e), xmm2, xmm7)
|
|
mulpd(xmm0, xmm2)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm5, xmm10)
|
|
addpd(xmm6, xmm14)
|
|
movaps(xmm7, xmm6)
|
|
mulpd(xmm0, xmm7)
|
|
mulpd(xmm1, xmm6)
|
|
|
|
addpd(xmm2, xmm9)
|
|
movaps(mem(rbx, -2*16), xmm2)
|
|
addpd(xmm4, xmm13)
|
|
movaps(xmm3, xmm4)
|
|
pshufd(imm(0x4e), xmm3, xmm5)
|
|
mulpd(xmm0, xmm3)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
|
|
addpd(xmm7, xmm8)
|
|
addpd(xmm6, xmm12)
|
|
movaps(xmm5, xmm6)
|
|
mulpd(xmm0, xmm5)
|
|
movaps(mem(rax, -2*16), xmm0)
|
|
mulpd(xmm1, xmm6)
|
|
movaps(mem(rax, -1*16), xmm1)
|
|
|
|
|
|
|
|
addpd(xmm3, xmm11) // iteration 3
|
|
movaps(mem(rbx, -1*16), xmm3)
|
|
addpd(xmm4, xmm15)
|
|
movaps(xmm2, xmm4)
|
|
pshufd(imm(0x4e), xmm2, xmm7)
|
|
mulpd(xmm0, xmm2)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
sub(imm(0-4*4*8), rax) // a += 4*4 (unroll x mr)
|
|
|
|
addpd(xmm5, xmm10)
|
|
addpd(xmm6, xmm14)
|
|
movaps(xmm7, xmm6)
|
|
mulpd(xmm0, xmm7)
|
|
mulpd(xmm1, xmm6)
|
|
|
|
sub(imm(0-4*4*8), r9) // b_next += 4*4 (unroll x nr)
|
|
|
|
addpd(xmm2, xmm9)
|
|
movaps(mem(rbx, 0*16), xmm2)
|
|
addpd(xmm4, xmm13)
|
|
movaps(xmm3, xmm4)
|
|
pshufd(imm(0x4e), xmm3, xmm5)
|
|
mulpd(xmm0, xmm3)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
sub(imm(0-4*4*8), rbx) // b += 4*4 (unroll x nr)
|
|
|
|
addpd(xmm7, xmm8)
|
|
addpd(xmm6, xmm12)
|
|
movaps(xmm5, xmm6)
|
|
mulpd(xmm0, xmm5)
|
|
movaps(mem(rax, -8*16), xmm0)
|
|
mulpd(xmm1, xmm6)
|
|
movaps(mem(rax, -7*16), xmm1)
|
|
|
|
prefetch(2, mem(r9, 0*8)) // prefetch b_next[0]
|
|
prefetch(2, mem(r9, 8*8)) // prefetch b_next[8]
|
|
|
|
dec(rsi) // i -= 1;
|
|
jne(.DLOOPKITER) // iterate again if i != 0.
|
|
|
|
|
|
|
|
//prefetch(2, mem(r9, -8*8)) // prefetch b_next[-8]
|
|
|
|
|
|
|
|
label(.DCONSIDKLEFT)
|
|
|
|
mov(var(k_left), rsi) // i = k_left;
|
|
test(rsi, rsi) // check i via logical AND.
|
|
je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
|
|
// else, we prepare to enter k_left loop.
|
|
|
|
|
|
label(.DLOOPKLEFT) // EDGE LOOP
|
|
|
|
addpd(xmm3, xmm11) // iteration 0
|
|
movaps(mem(rbx, -7*16), xmm3)
|
|
addpd(xmm4, xmm15)
|
|
movaps(xmm2, xmm4)
|
|
pshufd(imm(0x4e), xmm2, xmm7)
|
|
mulpd(xmm0, xmm2)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm5, xmm10)
|
|
addpd(xmm6, xmm14)
|
|
movaps(xmm7, xmm6)
|
|
mulpd(xmm0, xmm7)
|
|
mulpd(xmm1, xmm6)
|
|
|
|
addpd(xmm2, xmm9)
|
|
movaps(mem(rbx, -6*16), xmm2)
|
|
addpd(xmm4, xmm13)
|
|
movaps(xmm3, xmm4)
|
|
pshufd(imm(0x4e), xmm3, xmm5)
|
|
mulpd(xmm0, xmm3)
|
|
mulpd(xmm1, xmm4)
|
|
|
|
addpd(xmm7, xmm8)
|
|
addpd(xmm6, xmm12)
|
|
movaps(xmm5, xmm6)
|
|
mulpd(xmm0, xmm5)
|
|
movaps(mem(rax, -6*16), xmm0)
|
|
mulpd(xmm1, xmm6)
|
|
movaps(mem(rax, -5*16), xmm1)
|
|
|
|
|
|
sub(imm(0-4*1*8), rax) // a += 4 (1 x mr)
|
|
sub(imm(0-4*1*8), rbx) // b += 4 (1 x nr)
|
|
|
|
|
|
dec(rsi) // i -= 1;
|
|
jne(.DLOOPKLEFT) // iterate again if i != 0.
|
|
|
|
|
|
|
|
label(.DPOSTACCUM)
|
|
|
|
addpd(xmm3, xmm11)
|
|
addpd(xmm4, xmm15)
|
|
addpd(xmm5, xmm10)
|
|
addpd(xmm6, xmm14)
|
|
|
|
|
|
mov(var(alpha), rax) // load address of alpha
|
|
mov(var(beta), rbx) // load address of beta
|
|
movddup(mem(rax), xmm6) // load alpha and duplicate
|
|
movddup(mem(rbx), xmm7) // load beta and duplicate
|
|
|
|
|
|
mov(var(rs_c), rsi) // load rs_c
|
|
mov(rsi, r8) // make a copy of rs_c
|
|
|
|
lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double)
|
|
|
|
lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c;
|
|
|
|
// xmm8: xmm9: xmm10: xmm11:
|
|
// ( ab01 ( ab00 ( ab03 ( ab02
|
|
// ab10 ) ab11 ) ab12 ) ab13 )
|
|
//
|
|
// xmm12: xmm13: xmm14: xmm15:
|
|
// ( ab21 ( ab20 ( ab23 ( ab22
|
|
// ab30 ) ab31 ) ab32 ) ab33 )
|
|
movaps(xmm8, xmm0)
|
|
movsd(xmm9, xmm8)
|
|
movsd(xmm0, xmm9)
|
|
|
|
movaps(xmm10, xmm0)
|
|
movsd(xmm11, xmm10)
|
|
movsd(xmm0, xmm11)
|
|
|
|
movaps(xmm12, xmm0)
|
|
movsd(xmm13, xmm12)
|
|
movsd(xmm0, xmm13)
|
|
|
|
movaps(xmm14, xmm0)
|
|
movsd(xmm15, xmm14)
|
|
movsd(xmm0, xmm15)
|
|
// xmm8: xmm9: xmm10: xmm11:
|
|
// ( ab00 ( ab01 ( ab02 ( ab03
|
|
// ab10 ) ab11 ) ab12 ) ab13 )
|
|
//
|
|
// xmm12: xmm13: xmm14: xmm15:
|
|
// ( ab20 ( ab21 ( ab22 ( ab23
|
|
// ab30 ) ab31 ) ab32 ) ab33 )
|
|
|
|
|
|
|
|
// determine if
|
|
// c % 16 == 0, AND
|
|
// 8*cs_c % 16 == 0, AND
|
|
// rs_c == 1
|
|
// ie: aligned, ldim aligned, and
|
|
// column-stored
|
|
|
|
cmp(imm(1), r8) // set ZF if rs_c == 1.
|
|
sete(bl) // bl = ( ZF == 1 ? 1 : 0 );
|
|
test(imm(15), rcx) // set ZF if c & 16 is zero.
|
|
setz(bh) // bh = ( ZF == 1 ? 1 : 0 );
|
|
test(imm(15), r12) // set ZF if (8*cs_c) & 16 is zero.
|
|
setz(al) // al = ( ZF == 1 ? 1 : 0 );
|
|
// and(bl,bh) followed by
|
|
// and(bh,al) will reveal result
|
|
|
|
// now avoid loading C if beta == 0
|
|
|
|
xorpd(xmm0, xmm0) // set xmm0 to zero.
|
|
ucomisd(xmm0, xmm7) // check if beta == 0.
|
|
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
|
|
|
|
|
|
// check if aligned/column-stored
|
|
and(bl, bh) // set ZF if bl & bh == 1.
|
|
and(bh, al) // set ZF if bh & al == 1.
|
|
jne(.DCOLSTORED) // jump to column storage case
|
|
|
|
|
|
|
|
label(.DGENSTORED)
|
|
|
|
movlpd(mem(rcx), xmm0) // load c00 and c10,
|
|
movhpd(mem(rcx, rsi, 1), xmm0)
|
|
mulpd(xmm6, xmm8) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm8, xmm0) // add the gemm result,
|
|
movlpd(xmm0, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm0, mem(rcx, rsi, 1))
|
|
add(rdi, rcx)
|
|
|
|
movlpd(mem(rdx), xmm1) // load c20 and c30,
|
|
movhpd(mem(rdx, rsi, 1), xmm1)
|
|
mulpd(xmm6, xmm12) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm12, xmm1) // add the gemm result,
|
|
movlpd(xmm1, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm1, mem(rdx, rsi, 1))
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movlpd(mem(rcx), xmm0) // load c01 and c11,
|
|
movhpd(mem(rcx, rsi, 1), xmm0)
|
|
mulpd(xmm6, xmm9) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm9, xmm0) // add the gemm result,
|
|
movlpd(xmm0, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm0, mem(rcx, rsi, 1))
|
|
add(rdi, rcx)
|
|
|
|
movlpd(mem(rdx), xmm1) // load c21 and c31,
|
|
movhpd(mem(rdx, rsi, 1), xmm1)
|
|
mulpd(xmm6, xmm13) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm13, xmm1) // add the gemm result,
|
|
movlpd(xmm1, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm1, mem(rdx, rsi, 1))
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movlpd(mem(rcx), xmm0) // load c02 and c12,
|
|
movhpd(mem(rcx, rsi, 1), xmm0)
|
|
mulpd(xmm6, xmm10) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm10, xmm0) // add the gemm result,
|
|
movlpd(xmm0, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm0, mem(rcx, rsi, 1))
|
|
add(rdi, rcx)
|
|
|
|
movlpd(mem(rdx), xmm1) // load c22 and c32,
|
|
movhpd(mem(rdx, rsi, 1), xmm1)
|
|
mulpd(xmm6, xmm14) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm14, xmm1) // add the gemm result,
|
|
movlpd(xmm1, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm1, mem(rdx, rsi, 1))
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movlpd(mem(rcx), xmm0) // load c03 and c13,
|
|
movhpd(mem(rcx, rsi, 1), xmm0)
|
|
mulpd(xmm6, xmm11) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm11, xmm0) // add the gemm result,
|
|
movlpd(xmm0, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm0, mem(rcx, rsi, 1))
|
|
|
|
|
|
movlpd(mem(rdx), xmm1) // load c23 and c33,
|
|
movhpd(mem(rdx, rsi, 1), xmm1)
|
|
mulpd(xmm6, xmm15) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm15, xmm1) // add the gemm result,
|
|
movlpd(xmm1, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm1, mem(rdx, rsi, 1))
|
|
|
|
jmp(.DDONE) // jump to end.
|
|
|
|
|
|
|
|
label(.DCOLSTORED)
|
|
|
|
movaps(mem(rcx), xmm0) // load c00 and c10,
|
|
mulpd(xmm6, xmm8) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm8, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
|
|
movaps(mem(rdx), xmm1) // load c20 and c30,
|
|
mulpd(xmm6, xmm12) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm12, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movaps(mem(rcx), xmm0) // load c01 and c11,
|
|
mulpd(xmm6, xmm9) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm9, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
|
|
movaps(mem(rdx), xmm1) // load c21 and c31,
|
|
mulpd(xmm6, xmm13) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm13, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movaps(mem(rcx), xmm0) // load c02 and c12,
|
|
mulpd(xmm6, xmm10) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm10, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
|
|
movaps(mem(rdx), xmm1) // load c22 and c32,
|
|
mulpd(xmm6, xmm14) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm14, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
|
|
movaps(mem(rcx), xmm0) // load c03 and c13,
|
|
mulpd(xmm6, xmm11) // scale by alpha,
|
|
mulpd(xmm7, xmm0) // scale by beta,
|
|
addpd(xmm11, xmm0) // add the gemm result,
|
|
movaps(xmm0, mem(rcx)) // and store back to memory.
|
|
|
|
|
|
movaps(mem(rdx), xmm1) // load c23 and c33,
|
|
mulpd(xmm6, xmm15) // scale by alpha,
|
|
mulpd(xmm7, xmm1) // scale by beta,
|
|
addpd(xmm15, xmm1) // add the gemm result,
|
|
movaps(xmm1, mem(rdx)) // and store back to memory.
|
|
|
|
jmp(.DDONE) // jump to end.
|
|
|
|
|
|
|
|
|
|
label(.DBETAZERO)
|
|
// check if aligned/column-stored
|
|
and(bl, bh) // set ZF if bl & bh == 1.
|
|
and(bh, al) // set ZF if bh & al == 1.
|
|
jne(.DCOLSTORBZ) // jump to column storage case
|
|
|
|
|
|
|
|
label(.DGENSTORBZ)
|
|
// skip loading c00 and c10,
|
|
mulpd(xmm6, xmm8) // scale by alpha,
|
|
movlpd(xmm8, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm8, mem(rcx, rsi, 1))
|
|
add(rdi, rcx)
|
|
// skip loading c20 and c30,
|
|
mulpd(xmm6, xmm12) // scale by alpha,
|
|
movlpd(xmm12, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm12, mem(rdx, rsi, 1))
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c01 and c11,
|
|
mulpd(xmm6, xmm9) // scale by alpha,
|
|
movlpd(xmm9, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm9, mem(rcx, rsi, 1))
|
|
add(rdi, rcx)
|
|
// skip loading c21 and c31,
|
|
mulpd(xmm6, xmm13) // scale by alpha,
|
|
movlpd(xmm13, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm13, mem(rdx, rsi, 1))
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c02 and c12,
|
|
mulpd(xmm6, xmm10) // scale by alpha,
|
|
movlpd(xmm10, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm10, mem(rcx, rsi, 1))
|
|
add(rdi, rcx)
|
|
// skip loading c22 and c32,
|
|
mulpd(xmm6, xmm14) // scale by alpha,
|
|
movlpd(xmm14, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm14, mem(rdx, rsi, 1))
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c03 and c13,
|
|
mulpd(xmm6, xmm11) // scale by alpha,
|
|
movlpd(xmm11, mem(rcx)) // and store back to memory.
|
|
movhpd(xmm11, mem(rcx, rsi, 1))
|
|
|
|
// skip loading c23 and c33,
|
|
mulpd(xmm6, xmm15) // scale by alpha,
|
|
movlpd(xmm15, mem(rdx)) // and store back to memory.
|
|
movhpd(xmm15, mem(rdx, rsi, 1))
|
|
|
|
jmp(.DDONE) // jump to end.
|
|
|
|
|
|
|
|
label(.DCOLSTORBZ)
|
|
|
|
// skip loading c00 and c10,
|
|
mulpd(xmm6, xmm8) // scale by alpha,
|
|
movaps(xmm8, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
// skip loading c20 and c30,
|
|
mulpd(xmm6, xmm12) // scale by alpha,
|
|
movaps(xmm12, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c01 and c11,
|
|
mulpd(xmm6, xmm9) // scale by alpha,
|
|
movaps(xmm9, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
// skip loading c21 and c31,
|
|
mulpd(xmm6, xmm13) // scale by alpha,
|
|
movaps(xmm13, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c02 and c12,
|
|
mulpd(xmm6, xmm10) // scale by alpha,
|
|
movaps(xmm10, mem(rcx)) // and store back to memory.
|
|
add(rdi, rcx)
|
|
// skip loading c22 and c32,
|
|
mulpd(xmm6, xmm14) // scale by alpha,
|
|
movaps(xmm14, mem(rdx)) // and store back to memory.
|
|
add(rdi, rdx)
|
|
|
|
|
|
// skip loading c03 and c13,
|
|
mulpd(xmm6, xmm11) // scale by alpha,
|
|
movaps(xmm11, mem(rcx)) // and store back to memory.
|
|
|
|
// skip loading c23 and c33,
|
|
mulpd(xmm6, xmm15) // scale by alpha,
|
|
movaps(xmm15, mem(rdx)) // and store back to memory.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
label(.DDONE)
|
|
|
|
|
|
end_asm(
|
|
: // output operands (none)
|
|
: // input operands
|
|
[k_iter] "m" (k_iter), // 0
|
|
[k_left] "m" (k_left), // 1
|
|
[a] "m" (a), // 2
|
|
[b] "m" (b), // 3
|
|
[alpha] "m" (alpha), // 4
|
|
[beta] "m" (beta), // 5
|
|
[c] "m" (c), // 6
|
|
[rs_c] "m" (rs_c), // 7
|
|
[cs_c] "m" (cs_c), // 8
|
|
[b_next] "m" (b_next), // 9
|
|
[a_next] "m" (a_next) // 10
|
|
: // register clobber list
|
|
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12",
|
|
"xmm0", "xmm1", "xmm2", "xmm3",
|
|
"xmm4", "xmm5", "xmm6", "xmm7",
|
|
"xmm8", "xmm9", "xmm10", "xmm11",
|
|
"xmm12", "xmm13", "xmm14", "xmm15",
|
|
"memory"
|
|
)
|
|
}
|
|
|
|
|