DGEMM Improvements

- Prefetch the next panel while packing the 8xk panel.
- Modified prefetch offsets for the dgemm native and
  dgemm_small kernels.

AMD-Internal: [CPUPL-2366]

Change-Id: Ife609e789c8b87169c73bb0a30d6f1af20fb30ed
This commit is contained in:
Harsh Dave
2022-08-10 03:59:00 -05:00
parent 88e44c64e3
commit e2e1dadee1
5 changed files with 261 additions and 236 deletions

View File

@@ -556,9 +556,10 @@ void dgemm_
#ifdef BLIS_ENABLE_SMALL_MATRIX
//if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2))
if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) ||
((n0 <= 10) && (k0 <=10)) )
if(((m0 == n0) && (m0 < 400) && (k0 < 1000)) ||
( (m0 != n0) && (( ((m0 + n0 -k0) < 1500) &&
((m0 + k0-n0) < 1500) && ((n0 + k0-m0) < 1500) ) ||
((n0 <= 100) && (k0 <=100)))))
{
err_t status = BLIS_FAILURE;
if (bli_is_notrans(blis_transa))

View File

@@ -101,6 +101,8 @@ void bli_dpackm_haswell_asm_8xk
// assembly region, this constraint should be lifted.
const bool unitk = bli_deq1( *kappa );
double* restrict a_next = a + cdim0;
// -------------------------------------------------------------------------
@@ -267,7 +269,7 @@ void bli_dpackm_haswell_asm_8xk
label(.DCOLUNIT)
lea(mem(r10, r10, 2), r13) // r13 = 3*lda
mov(var(a_next), rcx)
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.DCONKLEFTCOLU) // if i == 0, jump to code that
@@ -278,22 +280,27 @@ void bli_dpackm_haswell_asm_8xk
vmovupd(mem(rax, 0), ymm0)
vmovupd(mem(rax, 32), ymm1)
prefetch(0, mem(rcx,7*8))
vmovupd(ymm0, mem(rbx, 0*64+ 0))
vmovupd(ymm1, mem(rbx, 0*64+32))
vmovupd(mem(rax, r10, 1, 0), ymm2)
vmovupd(mem(rax, r10, 1, 32), ymm3)
prefetch(0, mem(rcx, r10, 1,7*8))
vmovupd(ymm2, mem(rbx, 1*64+ 0))
vmovupd(ymm3, mem(rbx, 1*64+32))
vmovupd(mem(rax, r10, 2, 0), ymm4)
vmovupd(mem(rax, r10, 2, 32), ymm5)
prefetch(0, mem(rcx, r10, 2,7*8))
vmovupd(ymm4, mem(rbx, 2*64+ 0))
vmovupd(ymm5, mem(rbx, 2*64+32))
vmovupd(mem(rax, r13, 1, 0), ymm6)
vmovupd(mem(rax, r13, 1, 32), ymm7)
prefetch(0, mem(rcx, r13, 1,7*8))
add(r14, rax) // a += 4*lda;
add(r14, rcx)
vmovupd(ymm6, mem(rbx, 3*64+ 0))
vmovupd(ymm7, mem(rbx, 3*64+32))
add(imm(4*8*8), rbx) // p += 4*ldp = 4*8;
@@ -315,7 +322,9 @@ void bli_dpackm_haswell_asm_8xk
vmovupd(mem(rax, 0), ymm0)
vmovupd(mem(rax, 32), ymm1)
prefetch(0, mem(rcx,7*8))
add(r10, rax) // a += lda;
add(r10, rcx)
vmovupd(ymm0, mem(rbx, 0*64+ 0))
vmovupd(ymm1, mem(rbx, 0*64+32))
add(imm(8*8), rbx) // p += ldp = 8;
@@ -343,7 +352,8 @@ void bli_dpackm_haswell_asm_8xk
[p] "m" (p),
[ldp] "m" (ldp),
[kappa] "m" (kappa),
[one] "m" (one)
[one] "m" (one),
[a_next] "m" (a_next)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15",

View File

@@ -102,7 +102,19 @@ void bli_sgemm_haswell_asm_6x16
begin_asm()
vzeroall() // zero all xmm/ymm registers.
//vzeroall() // zero all xmm/ymm registers.
vxorps( ymm4, ymm4, ymm4)
vmovaps( ymm4, ymm5)
vmovaps( ymm4, ymm6)
vmovaps( ymm4, ymm7)
vmovaps( ymm4, ymm8)
vmovaps( ymm4, ymm9)
vmovaps( ymm4, ymm10)
vmovaps( ymm4, ymm11)
vmovaps( ymm4, ymm12)
vmovaps( ymm4, ymm13)
vmovaps( ymm4, ymm14)
vmovaps( ymm4, ymm15)
mov(var(a), rax) // load address of a.
@@ -141,7 +153,7 @@ void bli_sgemm_haswell_asm_6x16
// iteration 0
prefetch(0, mem(rax, 64*4))
vbroadcastss(mem(rax, 0*4), ymm2)
vbroadcastss(mem(rax, 1*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
@@ -167,6 +179,8 @@ void bli_sgemm_haswell_asm_6x16
vmovaps(mem(rbx, -1*32), ymm1)
// iteration 1
prefetch(0, mem(rax, 72*4))
vbroadcastss(mem(rax, 6*4), ymm2)
vbroadcastss(mem(rax, 7*4), ymm3)
vfmadd231ps(ymm0, ymm2, ymm4)
@@ -192,7 +206,7 @@ void bli_sgemm_haswell_asm_6x16
vmovaps(mem(rbx, 1*32), ymm1)
// iteration 2
prefetch(0, mem(rax, 76*4))
prefetch(0, mem(rax, 80*4))
vbroadcastss(mem(rax, 12*4), ymm2)
vbroadcastss(mem(rax, 13*4), ymm3)
@@ -1010,76 +1024,78 @@ void bli_dgemm_haswell_asm_6x8
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 2*8), ymm2)
vbroadcastsd(mem(rax, 3*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 4*8), ymm2)
vbroadcastsd(mem(rax, 5*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, -2*32), ymm0)
vmovapd(mem(rbx, -1*32), ymm1)
// iteration 1
prefetch(0, mem(rax, 72*8))
vbroadcastsd(mem(rax, 6*8), ymm2)
vbroadcastsd(mem(rax, 7*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 8*8), ymm2)
vbroadcastsd(mem(rax, 9*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 10*8), ymm2)
vbroadcastsd(mem(rax, 11*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, 0*32), ymm0)
vmovapd(mem(rbx, 1*32), ymm1)
// iteration 2
prefetch(0, mem(rax, 76*8))
prefetch(0, mem(rax, 80*8))
vbroadcastsd(mem(rax, 12*8), ymm2)
vbroadcastsd(mem(rax, 13*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 14*8), ymm2)
vbroadcastsd(mem(rax, 15*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 16*8), ymm2)
vbroadcastsd(mem(rax, 17*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, 2*32), ymm0)
vmovapd(mem(rbx, 3*32), ymm1)
// iteration 3
vbroadcastsd(mem(rax, 18*8), ymm2)
vbroadcastsd(mem(rax, 19*8), ymm3)
@@ -1087,91 +1103,91 @@ void bli_dgemm_haswell_asm_6x8
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 20*8), ymm2)
vbroadcastsd(mem(rax, 21*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 22*8), ymm2)
vbroadcastsd(mem(rax, 23*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
add(imm(4*6*8), rax) // a += 4*6 (unroll x mr)
add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr)
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.DLOOPKITER) // iterate again if i != 0.
label(.DCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.DLOOPKLEFT) // EDGE LOOP
prefetch(0, mem(rax, 64*8))
vbroadcastsd(mem(rax, 0*8), ymm2)
vbroadcastsd(mem(rax, 1*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 2*8), ymm2)
vbroadcastsd(mem(rax, 3*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 4*8), ymm2)
vbroadcastsd(mem(rax, 5*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
add(imm(1*6*8), rax) // a += 1*6 (unroll x mr)
add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr)
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.DLOOPKLEFT) // iterate again if i != 0.
label(.DPOSTACCUM)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
vmulpd(ymm0, ymm4, ymm4) // scale by alpha
vmulpd(ymm0, ymm5, ymm5)
vmulpd(ymm0, ymm6, ymm6)
@@ -1184,179 +1200,179 @@ void bli_dgemm_haswell_asm_6x8
vmulpd(ymm0, ymm13, ymm13)
vmulpd(ymm0, ymm14, ymm14)
vmulpd(ymm0, ymm15, ymm15)
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c;
lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c;
//lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c;
//lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c;
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vucomisd(xmm0, xmm3) // set ZF if beta == 0.
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
jz(.DROWSTORED) // jump to row storage case
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
jz(.DCOLSTORED) // jump to column storage case
label(.DGENSTORED)
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm4, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm6, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm8, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm10, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm12, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm14, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
mov(rdx, rcx) // rcx = c + 4*cs_c
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm5, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm7, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm9, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm11, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm13, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
DGEMM_INPUT_GS_BETA_NZ
vfmadd213pd(ymm15, ymm3, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
vfmadd231pd(mem(rcx), ymm3, ymm4)
vmovupd(ymm4, mem(rcx))
add(rdi, rcx)
vfmadd231pd(mem(rdx), ymm3, ymm5)
vmovupd(ymm5, mem(rdx))
add(rdi, rdx)
vfmadd231pd(mem(rcx), ymm3, ymm6)
vmovupd(ymm6, mem(rcx))
add(rdi, rcx)
vfmadd231pd(mem(rdx), ymm3, ymm7)
vmovupd(ymm7, mem(rdx))
add(rdi, rdx)
vfmadd231pd(mem(rcx), ymm3, ymm8)
vmovupd(ymm8, mem(rcx))
add(rdi, rcx)
vfmadd231pd(mem(rdx), ymm3, ymm9)
vmovupd(ymm9, mem(rdx))
add(rdi, rdx)
vfmadd231pd(mem(rcx), ymm3, ymm10)
vmovupd(ymm10, mem(rcx))
add(rdi, rcx)
vfmadd231pd(mem(rdx), ymm3, ymm11)
vmovupd(ymm11, mem(rdx))
add(rdi, rdx)
vfmadd231pd(mem(rcx), ymm3, ymm12)
vmovupd(ymm12, mem(rcx))
add(rdi, rcx)
vfmadd231pd(mem(rdx), ymm3, ymm13)
vmovupd(ymm13, mem(rdx))
add(rdi, rdx)
vfmadd231pd(mem(rcx), ymm3, ymm14)
vmovupd(ymm14, mem(rcx))
//add(rdi, rcx)
vfmadd231pd(mem(rdx), ymm3, ymm15)
vmovupd(ymm15, mem(rdx))
//add(rdi, rdx)
jmp(.DDONE) // jump to end.
label(.DCOLSTORED)
vunpcklpd(ymm6, ymm4, ymm0)
vunpckhpd(ymm6, ymm4, ymm1)
vunpcklpd(ymm10, ymm8, ymm2)
@@ -1365,9 +1381,9 @@ void bli_dgemm_haswell_asm_6x8
vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
vbroadcastsd(mem(rbx), ymm3)
vfmadd231pd(mem(rcx), ymm3, ymm4)
vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6)
vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8)
@@ -1376,14 +1392,14 @@ void bli_dgemm_haswell_asm_6x8
vmovupd(ymm6, mem(rcx, rsi, 1))
vmovupd(ymm8, mem(rcx, rsi, 2))
vmovupd(ymm10, mem(rcx, r13, 1))
lea(mem(rcx, rsi, 4), rcx)
vunpcklpd(ymm14, ymm12, ymm0)
vunpckhpd(ymm14, ymm12, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm4)
vfmadd231pd(mem(r14), xmm3, xmm0)
vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1)
vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2)
@@ -1392,10 +1408,10 @@ void bli_dgemm_haswell_asm_6x8
vmovupd(xmm1, mem(r14, rsi, 1))
vmovupd(xmm2, mem(r14, rsi, 2))
vmovupd(xmm4, mem(r14, r13, 1))
lea(mem(r14, rsi, 4), r14)
vunpcklpd(ymm7, ymm5, ymm0)
vunpckhpd(ymm7, ymm5, ymm1)
vunpcklpd(ymm11, ymm9, ymm2)
@@ -1404,9 +1420,9 @@ void bli_dgemm_haswell_asm_6x8
vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
vbroadcastsd(mem(rbx), ymm3)
vfmadd231pd(mem(rcx), ymm3, ymm5)
vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7)
vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9)
@@ -1415,14 +1431,14 @@ void bli_dgemm_haswell_asm_6x8
vmovupd(ymm7, mem(rcx, rsi, 1))
vmovupd(ymm9, mem(rcx, rsi, 2))
vmovupd(ymm11, mem(rcx, r13, 1))
//lea(mem(rcx, rsi, 4), rcx)
vunpcklpd(ymm15, ymm13, ymm0)
vunpckhpd(ymm15, ymm13, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm4)
vfmadd231pd(mem(r14), xmm3, xmm0)
vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1)
vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2)
@@ -1431,139 +1447,139 @@ void bli_dgemm_haswell_asm_6x8
vmovupd(xmm1, mem(r14, rsi, 1))
vmovupd(xmm2, mem(r14, rsi, 2))
vmovupd(xmm4, mem(r14, r13, 1))
//lea(mem(r14, rsi, 4), r14)
jmp(.DDONE) // jump to end.
label(.DBETAZERO)
cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
jz(.DROWSTORBZ) // jump to row storage case
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
jz(.DCOLSTORBZ) // jump to column storage case
label(.DGENSTORBZ)
vmovapd(ymm4, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm6, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm8, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm10, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm12, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm14, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
mov(rdx, rcx) // rcx = c + 4*cs_c
vmovapd(ymm5, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm7, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm9, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm11, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm13, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
add(rdi, rcx) // c += rs_c;
vmovapd(ymm15, ymm0)
DGEMM_OUTPUT_GS_BETA_NZ
jmp(.DDONE) // jump to end.
label(.DROWSTORBZ)
vmovupd(ymm4, mem(rcx))
add(rdi, rcx)
vmovupd(ymm5, mem(rdx))
add(rdi, rdx)
vmovupd(ymm6, mem(rcx))
add(rdi, rcx)
vmovupd(ymm7, mem(rdx))
add(rdi, rdx)
vmovupd(ymm8, mem(rcx))
add(rdi, rcx)
vmovupd(ymm9, mem(rdx))
add(rdi, rdx)
vmovupd(ymm10, mem(rcx))
add(rdi, rcx)
vmovupd(ymm11, mem(rdx))
add(rdi, rdx)
vmovupd(ymm12, mem(rcx))
add(rdi, rcx)
vmovupd(ymm13, mem(rdx))
add(rdi, rdx)
vmovupd(ymm14, mem(rcx))
//add(rdi, rcx)
vmovupd(ymm15, mem(rdx))
//add(rdi, rdx)
jmp(.DDONE) // jump to end.
label(.DCOLSTORBZ)
vunpcklpd(ymm6, ymm4, ymm0)
vunpckhpd(ymm6, ymm4, ymm1)
vunpcklpd(ymm10, ymm8, ymm2)
@@ -1572,27 +1588,27 @@ void bli_dgemm_haswell_asm_6x8
vinsertf128(imm(0x1), xmm3, ymm1, ymm6)
vperm2f128(imm(0x31), ymm2, ymm0, ymm8)
vperm2f128(imm(0x31), ymm3, ymm1, ymm10)
vmovupd(ymm4, mem(rcx))
vmovupd(ymm6, mem(rcx, rsi, 1))
vmovupd(ymm8, mem(rcx, rsi, 2))
vmovupd(ymm10, mem(rcx, r13, 1))
lea(mem(rcx, rsi, 4), rcx)
vunpcklpd(ymm14, ymm12, ymm0)
vunpckhpd(ymm14, ymm12, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm4)
vmovupd(xmm0, mem(r14))
vmovupd(xmm1, mem(r14, rsi, 1))
vmovupd(xmm2, mem(r14, rsi, 2))
vmovupd(xmm4, mem(r14, r13, 1))
lea(mem(r14, rsi, 4), r14)
vunpcklpd(ymm7, ymm5, ymm0)
vunpckhpd(ymm7, ymm5, ymm1)
vunpcklpd(ymm11, ymm9, ymm2)
@@ -1601,32 +1617,31 @@ void bli_dgemm_haswell_asm_6x8
vinsertf128(imm(0x1), xmm3, ymm1, ymm7)
vperm2f128(imm(0x31), ymm2, ymm0, ymm9)
vperm2f128(imm(0x31), ymm3, ymm1, ymm11)
vmovupd(ymm5, mem(rcx))
vmovupd(ymm7, mem(rcx, rsi, 1))
vmovupd(ymm9, mem(rcx, rsi, 2))
vmovupd(ymm11, mem(rcx, r13, 1))
//lea(mem(rcx, rsi, 4), rcx)
vunpcklpd(ymm15, ymm13, ymm0)
vunpckhpd(ymm15, ymm13, ymm1)
vextractf128(imm(0x1), ymm0, xmm2)
vextractf128(imm(0x1), ymm1, xmm4)
vmovupd(xmm0, mem(r14))
vmovupd(xmm1, mem(r14, rsi, 1))
vmovupd(xmm2, mem(r14, rsi, 2))
vmovupd(xmm4, mem(r14, r13, 1))
//lea(mem(r14, rsi, 4), r14)
label(.DDONE)
//lea(mem(r14, rsi, 4), r14)
label(.DDONE)
vzeroupper()
end_asm(
: // output operands (none)

View File

@@ -867,8 +867,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m
label(.DRETURN)
vzeroupper()
end_asm(
: // output operands (none)

View File

@@ -1951,12 +1951,12 @@ static err_t bli_sgemm_small
tA_packed = D_A_pack;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 15), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
@@ -2111,12 +2111,12 @@ static err_t bli_sgemm_small
tA = tA_packed + row_idx_packed;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 15), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();
@@ -4513,12 +4513,12 @@ err_t bli_dgemm_small_At
tA = tA_packed + row_idx_packed;
#ifdef BLIS_ENABLE_PREFETCH
_mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 15), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0);
_mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0);
#endif
// clear scratch registers.
ymm4 = _mm256_setzero_pd();