mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
DGEMM Improvements
- We prefetch next panel while packing 8xk panel. - Modified prefetch offsets for dgemm native and dgemm_small kernel. AMD-Internal: [CPUPL-2366] Change-Id: Ife609e789c8b87169c73bb0a30d6f1af20fb30ed
This commit is contained in:
@@ -556,9 +556,10 @@ void dgemm_
|
||||
|
||||
#ifdef BLIS_ENABLE_SMALL_MATRIX
|
||||
|
||||
//if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2))
|
||||
if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) ||
|
||||
((n0 <= 10) && (k0 <=10)) )
|
||||
if(((m0 == n0) && (m0 < 400) && (k0 < 1000)) ||
|
||||
( (m0 != n0) && (( ((m0 + n0 -k0) < 1500) &&
|
||||
((m0 + k0-n0) < 1500) && ((n0 + k0-m0) < 1500) ) ||
|
||||
((n0 <= 100) && (k0 <=100)))))
|
||||
{
|
||||
err_t status = BLIS_FAILURE;
|
||||
if (bli_is_notrans(blis_transa))
|
||||
|
||||
@@ -101,6 +101,8 @@ void bli_dpackm_haswell_asm_8xk
|
||||
// assembly region, this constraint should be lifted.
|
||||
const bool unitk = bli_deq1( *kappa );
|
||||
|
||||
double* restrict a_next = a + cdim0;
|
||||
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
@@ -267,7 +269,7 @@ void bli_dpackm_haswell_asm_8xk
|
||||
label(.DCOLUNIT)
|
||||
|
||||
lea(mem(r10, r10, 2), r13) // r13 = 3*lda
|
||||
|
||||
mov(var(a_next), rcx)
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.DCONKLEFTCOLU) // if i == 0, jump to code that
|
||||
@@ -278,22 +280,27 @@ void bli_dpackm_haswell_asm_8xk
|
||||
|
||||
vmovupd(mem(rax, 0), ymm0)
|
||||
vmovupd(mem(rax, 32), ymm1)
|
||||
prefetch(0, mem(rcx,7*8))
|
||||
vmovupd(ymm0, mem(rbx, 0*64+ 0))
|
||||
vmovupd(ymm1, mem(rbx, 0*64+32))
|
||||
|
||||
vmovupd(mem(rax, r10, 1, 0), ymm2)
|
||||
vmovupd(mem(rax, r10, 1, 32), ymm3)
|
||||
prefetch(0, mem(rcx, r10, 1,7*8))
|
||||
vmovupd(ymm2, mem(rbx, 1*64+ 0))
|
||||
vmovupd(ymm3, mem(rbx, 1*64+32))
|
||||
|
||||
vmovupd(mem(rax, r10, 2, 0), ymm4)
|
||||
vmovupd(mem(rax, r10, 2, 32), ymm5)
|
||||
prefetch(0, mem(rcx, r10, 2,7*8))
|
||||
vmovupd(ymm4, mem(rbx, 2*64+ 0))
|
||||
vmovupd(ymm5, mem(rbx, 2*64+32))
|
||||
|
||||
vmovupd(mem(rax, r13, 1, 0), ymm6)
|
||||
vmovupd(mem(rax, r13, 1, 32), ymm7)
|
||||
prefetch(0, mem(rcx, r13, 1,7*8))
|
||||
add(r14, rax) // a += 4*lda;
|
||||
add(r14, rcx)
|
||||
vmovupd(ymm6, mem(rbx, 3*64+ 0))
|
||||
vmovupd(ymm7, mem(rbx, 3*64+32))
|
||||
add(imm(4*8*8), rbx) // p += 4*ldp = 4*8;
|
||||
@@ -315,7 +322,9 @@ void bli_dpackm_haswell_asm_8xk
|
||||
|
||||
vmovupd(mem(rax, 0), ymm0)
|
||||
vmovupd(mem(rax, 32), ymm1)
|
||||
prefetch(0, mem(rcx,7*8))
|
||||
add(r10, rax) // a += lda;
|
||||
add(r10, rcx)
|
||||
vmovupd(ymm0, mem(rbx, 0*64+ 0))
|
||||
vmovupd(ymm1, mem(rbx, 0*64+32))
|
||||
add(imm(8*8), rbx) // p += ldp = 8;
|
||||
@@ -343,7 +352,8 @@ void bli_dpackm_haswell_asm_8xk
|
||||
[p] "m" (p),
|
||||
[ldp] "m" (ldp),
|
||||
[kappa] "m" (kappa),
|
||||
[one] "m" (one)
|
||||
[one] "m" (one),
|
||||
[a_next] "m" (a_next)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -102,7 +102,19 @@ void bli_sgemm_haswell_asm_6x16
|
||||
|
||||
begin_asm()
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
//vzeroall() // zero all xmm/ymm registers.
|
||||
vxorps( ymm4, ymm4, ymm4)
|
||||
vmovaps( ymm4, ymm5)
|
||||
vmovaps( ymm4, ymm6)
|
||||
vmovaps( ymm4, ymm7)
|
||||
vmovaps( ymm4, ymm8)
|
||||
vmovaps( ymm4, ymm9)
|
||||
vmovaps( ymm4, ymm10)
|
||||
vmovaps( ymm4, ymm11)
|
||||
vmovaps( ymm4, ymm12)
|
||||
vmovaps( ymm4, ymm13)
|
||||
vmovaps( ymm4, ymm14)
|
||||
vmovaps( ymm4, ymm15)
|
||||
|
||||
|
||||
mov(var(a), rax) // load address of a.
|
||||
@@ -141,7 +153,7 @@ void bli_sgemm_haswell_asm_6x16
|
||||
|
||||
// iteration 0
|
||||
prefetch(0, mem(rax, 64*4))
|
||||
|
||||
|
||||
vbroadcastss(mem(rax, 0*4), ymm2)
|
||||
vbroadcastss(mem(rax, 1*4), ymm3)
|
||||
vfmadd231ps(ymm0, ymm2, ymm4)
|
||||
@@ -167,6 +179,8 @@ void bli_sgemm_haswell_asm_6x16
|
||||
vmovaps(mem(rbx, -1*32), ymm1)
|
||||
|
||||
// iteration 1
|
||||
prefetch(0, mem(rax, 72*4))
|
||||
|
||||
vbroadcastss(mem(rax, 6*4), ymm2)
|
||||
vbroadcastss(mem(rax, 7*4), ymm3)
|
||||
vfmadd231ps(ymm0, ymm2, ymm4)
|
||||
@@ -192,7 +206,7 @@ void bli_sgemm_haswell_asm_6x16
|
||||
vmovaps(mem(rbx, 1*32), ymm1)
|
||||
|
||||
// iteration 2
|
||||
prefetch(0, mem(rax, 76*4))
|
||||
prefetch(0, mem(rax, 80*4))
|
||||
|
||||
vbroadcastss(mem(rax, 12*4), ymm2)
|
||||
vbroadcastss(mem(rax, 13*4), ymm3)
|
||||
@@ -1010,76 +1024,78 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 2*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 3*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 4*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 5*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, -2*32), ymm0)
|
||||
vmovapd(mem(rbx, -1*32), ymm1)
|
||||
|
||||
|
||||
// iteration 1
|
||||
prefetch(0, mem(rax, 72*8))
|
||||
|
||||
vbroadcastsd(mem(rax, 6*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 7*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm4)
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 8*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 9*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 10*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 11*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, 0*32), ymm0)
|
||||
vmovapd(mem(rbx, 1*32), ymm1)
|
||||
|
||||
|
||||
// iteration 2
|
||||
prefetch(0, mem(rax, 76*8))
|
||||
|
||||
prefetch(0, mem(rax, 80*8))
|
||||
|
||||
vbroadcastsd(mem(rax, 12*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 13*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm4)
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 14*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 15*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 16*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 17*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, 2*32), ymm0)
|
||||
vmovapd(mem(rbx, 3*32), ymm1)
|
||||
|
||||
|
||||
// iteration 3
|
||||
vbroadcastsd(mem(rax, 18*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 19*8), ymm3)
|
||||
@@ -1087,91 +1103,91 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 20*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 21*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 22*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 23*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
add(imm(4*6*8), rax) // a += 4*6 (unroll x mr)
|
||||
add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, -4*32), ymm0)
|
||||
vmovapd(mem(rbx, -3*32), ymm1)
|
||||
|
||||
|
||||
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.DLOOPKITER) // iterate again if i != 0.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DCONSIDKLEFT)
|
||||
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DLOOPKLEFT) // EDGE LOOP
|
||||
|
||||
|
||||
prefetch(0, mem(rax, 64*8))
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 0*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 1*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm4)
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 2*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 3*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 4*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 5*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
add(imm(1*6*8), rax) // a += 1*6 (unroll x mr)
|
||||
add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, -4*32), ymm0)
|
||||
vmovapd(mem(rbx, -3*32), ymm1)
|
||||
|
||||
|
||||
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.DLOOPKLEFT) // iterate again if i != 0.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
|
||||
vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
|
||||
|
||||
|
||||
vmulpd(ymm0, ymm4, ymm4) // scale by alpha
|
||||
vmulpd(ymm0, ymm5, ymm5)
|
||||
vmulpd(ymm0, ymm6, ymm6)
|
||||
@@ -1184,179 +1200,179 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vmulpd(ymm0, ymm13, ymm13)
|
||||
vmulpd(ymm0, ymm14, ymm14)
|
||||
vmulpd(ymm0, ymm15, ymm15)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
|
||||
|
||||
|
||||
lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c;
|
||||
lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c;
|
||||
|
||||
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c;
|
||||
//lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c;
|
||||
//lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
|
||||
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vucomisd(xmm0, xmm3) // set ZF if beta == 0.
|
||||
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
|
||||
|
||||
|
||||
cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
|
||||
jz(.DROWSTORED) // jump to row storage case
|
||||
|
||||
|
||||
|
||||
|
||||
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
|
||||
jz(.DCOLSTORED) // jump to column storage case
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DGENSTORED)
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm4, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm6, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm8, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm10, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm12, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm14, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
|
||||
|
||||
|
||||
|
||||
mov(rdx, rcx) // rcx = c + 4*cs_c
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm5, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm7, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm9, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm11, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm13, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
DGEMM_INPUT_GS_BETA_NZ
|
||||
vfmadd213pd(ymm15, ymm3, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DROWSTORED)
|
||||
|
||||
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm4)
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vfmadd231pd(mem(rdx), ymm3, ymm5)
|
||||
vmovupd(ymm5, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm6)
|
||||
vmovupd(ymm6, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vfmadd231pd(mem(rdx), ymm3, ymm7)
|
||||
vmovupd(ymm7, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm8)
|
||||
vmovupd(ymm8, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vfmadd231pd(mem(rdx), ymm3, ymm9)
|
||||
vmovupd(ymm9, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm10)
|
||||
vmovupd(ymm10, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vfmadd231pd(mem(rdx), ymm3, ymm11)
|
||||
vmovupd(ymm11, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm12)
|
||||
vmovupd(ymm12, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vfmadd231pd(mem(rdx), ymm3, ymm13)
|
||||
vmovupd(ymm13, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm14)
|
||||
vmovupd(ymm14, mem(rcx))
|
||||
//add(rdi, rcx)
|
||||
vfmadd231pd(mem(rdx), ymm3, ymm15)
|
||||
vmovupd(ymm15, mem(rdx))
|
||||
//add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DCOLSTORED)
|
||||
|
||||
|
||||
|
||||
|
||||
vunpcklpd(ymm6, ymm4, ymm0)
|
||||
vunpckhpd(ymm6, ymm4, ymm1)
|
||||
vunpcklpd(ymm10, ymm8, ymm2)
|
||||
@@ -1365,9 +1381,9 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
|
||||
vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
|
||||
vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rbx), ymm3)
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm4)
|
||||
vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6)
|
||||
vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8)
|
||||
@@ -1376,14 +1392,14 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vmovupd(ymm6, mem(rcx, rsi, 1))
|
||||
vmovupd(ymm8, mem(rcx, rsi, 2))
|
||||
vmovupd(ymm10, mem(rcx, r13, 1))
|
||||
|
||||
|
||||
lea(mem(rcx, rsi, 4), rcx)
|
||||
|
||||
|
||||
vunpcklpd(ymm14, ymm12, ymm0)
|
||||
vunpckhpd(ymm14, ymm12, ymm1)
|
||||
vextractf128(imm(0x1), ymm0, xmm2)
|
||||
vextractf128(imm(0x1), ymm1, xmm4)
|
||||
|
||||
|
||||
vfmadd231pd(mem(r14), xmm3, xmm0)
|
||||
vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1)
|
||||
vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2)
|
||||
@@ -1392,10 +1408,10 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vmovupd(xmm1, mem(r14, rsi, 1))
|
||||
vmovupd(xmm2, mem(r14, rsi, 2))
|
||||
vmovupd(xmm4, mem(r14, r13, 1))
|
||||
|
||||
|
||||
lea(mem(r14, rsi, 4), r14)
|
||||
|
||||
|
||||
|
||||
|
||||
vunpcklpd(ymm7, ymm5, ymm0)
|
||||
vunpckhpd(ymm7, ymm5, ymm1)
|
||||
vunpcklpd(ymm11, ymm9, ymm2)
|
||||
@@ -1404,9 +1420,9 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
|
||||
vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
|
||||
vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rbx), ymm3)
|
||||
|
||||
|
||||
vfmadd231pd(mem(rcx), ymm3, ymm5)
|
||||
vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7)
|
||||
vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9)
|
||||
@@ -1415,14 +1431,14 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vmovupd(ymm7, mem(rcx, rsi, 1))
|
||||
vmovupd(ymm9, mem(rcx, rsi, 2))
|
||||
vmovupd(ymm11, mem(rcx, r13, 1))
|
||||
|
||||
|
||||
//lea(mem(rcx, rsi, 4), rcx)
|
||||
|
||||
|
||||
vunpcklpd(ymm15, ymm13, ymm0)
|
||||
vunpckhpd(ymm15, ymm13, ymm1)
|
||||
vextractf128(imm(0x1), ymm0, xmm2)
|
||||
vextractf128(imm(0x1), ymm1, xmm4)
|
||||
|
||||
|
||||
vfmadd231pd(mem(r14), xmm3, xmm0)
|
||||
vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1)
|
||||
vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2)
|
||||
@@ -1431,139 +1447,139 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vmovupd(xmm1, mem(r14, rsi, 1))
|
||||
vmovupd(xmm2, mem(r14, rsi, 2))
|
||||
vmovupd(xmm4, mem(r14, r13, 1))
|
||||
|
||||
|
||||
//lea(mem(r14, rsi, 4), r14)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DBETAZERO)
|
||||
|
||||
|
||||
cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
|
||||
jz(.DROWSTORBZ) // jump to row storage case
|
||||
|
||||
|
||||
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
|
||||
jz(.DCOLSTORBZ) // jump to column storage case
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DGENSTORBZ)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm4, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm6, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm8, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm10, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm12, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm14, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
|
||||
|
||||
|
||||
|
||||
mov(rdx, rcx) // rcx = c + 4*cs_c
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm5, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm7, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm9, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm11, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm13, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
add(rdi, rcx) // c += rs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm15, ymm0)
|
||||
DGEMM_OUTPUT_GS_BETA_NZ
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vmovupd(ymm5, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
vmovupd(ymm6, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vmovupd(ymm7, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovupd(ymm8, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vmovupd(ymm9, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovupd(ymm10, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vmovupd(ymm11, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovupd(ymm12, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
vmovupd(ymm13, mem(rdx))
|
||||
add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovupd(ymm14, mem(rcx))
|
||||
//add(rdi, rcx)
|
||||
vmovupd(ymm15, mem(rdx))
|
||||
//add(rdi, rdx)
|
||||
|
||||
|
||||
|
||||
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DCOLSTORBZ)
|
||||
|
||||
|
||||
|
||||
|
||||
vunpcklpd(ymm6, ymm4, ymm0)
|
||||
vunpckhpd(ymm6, ymm4, ymm1)
|
||||
vunpcklpd(ymm10, ymm8, ymm2)
|
||||
@@ -1572,27 +1588,27 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
|
||||
vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
|
||||
vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
|
||||
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
vmovupd(ymm6, mem(rcx, rsi, 1))
|
||||
vmovupd(ymm8, mem(rcx, rsi, 2))
|
||||
vmovupd(ymm10, mem(rcx, r13, 1))
|
||||
|
||||
|
||||
lea(mem(rcx, rsi, 4), rcx)
|
||||
|
||||
|
||||
vunpcklpd(ymm14, ymm12, ymm0)
|
||||
vunpckhpd(ymm14, ymm12, ymm1)
|
||||
vextractf128(imm(0x1), ymm0, xmm2)
|
||||
vextractf128(imm(0x1), ymm1, xmm4)
|
||||
|
||||
|
||||
vmovupd(xmm0, mem(r14))
|
||||
vmovupd(xmm1, mem(r14, rsi, 1))
|
||||
vmovupd(xmm2, mem(r14, rsi, 2))
|
||||
vmovupd(xmm4, mem(r14, r13, 1))
|
||||
|
||||
|
||||
lea(mem(r14, rsi, 4), r14)
|
||||
|
||||
|
||||
|
||||
|
||||
vunpcklpd(ymm7, ymm5, ymm0)
|
||||
vunpckhpd(ymm7, ymm5, ymm1)
|
||||
vunpcklpd(ymm11, ymm9, ymm2)
|
||||
@@ -1601,32 +1617,31 @@ void bli_dgemm_haswell_asm_6x8
|
||||
vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
|
||||
vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
|
||||
vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
|
||||
|
||||
|
||||
vmovupd(ymm5, mem(rcx))
|
||||
vmovupd(ymm7, mem(rcx, rsi, 1))
|
||||
vmovupd(ymm9, mem(rcx, rsi, 2))
|
||||
vmovupd(ymm11, mem(rcx, r13, 1))
|
||||
|
||||
|
||||
//lea(mem(rcx, rsi, 4), rcx)
|
||||
|
||||
|
||||
vunpcklpd(ymm15, ymm13, ymm0)
|
||||
vunpckhpd(ymm15, ymm13, ymm1)
|
||||
vextractf128(imm(0x1), ymm0, xmm2)
|
||||
vextractf128(imm(0x1), ymm1, xmm4)
|
||||
|
||||
|
||||
vmovupd(xmm0, mem(r14))
|
||||
vmovupd(xmm1, mem(r14, rsi, 1))
|
||||
vmovupd(xmm2, mem(r14, rsi, 2))
|
||||
vmovupd(xmm4, mem(r14, r13, 1))
|
||||
|
||||
//lea(mem(r14, rsi, 4), r14)
|
||||
|
||||
|
||||
|
||||
label(.DDONE)
|
||||
|
||||
//lea(mem(r14, rsi, 4), r14)
|
||||
|
||||
|
||||
|
||||
label(.DDONE)
|
||||
vzeroupper()
|
||||
|
||||
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
|
||||
@@ -867,8 +867,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m
|
||||
|
||||
|
||||
label(.DRETURN)
|
||||
|
||||
|
||||
vzeroupper()
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
|
||||
@@ -1951,12 +1951,12 @@ static err_t bli_sgemm_small
|
||||
tA_packed = D_A_pack;
|
||||
|
||||
#ifdef BLIS_ENABLE_PREFETCH
|
||||
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 15), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0);
|
||||
#endif
|
||||
// clear scratch registers.
|
||||
ymm4 = _mm256_setzero_pd();
|
||||
@@ -2111,12 +2111,12 @@ static err_t bli_sgemm_small
|
||||
tA = tA_packed + row_idx_packed;
|
||||
|
||||
#ifdef BLIS_ENABLE_PREFETCH
|
||||
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 15), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0);
|
||||
#endif
|
||||
// clear scratch registers.
|
||||
ymm4 = _mm256_setzero_pd();
|
||||
@@ -4513,12 +4513,12 @@ err_t bli_dgemm_small_At
|
||||
tA = tA_packed + row_idx_packed;
|
||||
|
||||
#ifdef BLIS_ENABLE_PREFETCH
|
||||
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 15), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0);
|
||||
_mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0);
|
||||
#endif
|
||||
// clear scratch registers.
|
||||
ymm4 = _mm256_setzero_pd();
|
||||
|
||||
Reference in New Issue
Block a user