From e2e1dadee11e3ca1c2b9821192bfd23dcab6578e Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 10 Aug 2022 03:59:00 -0500 Subject: [PATCH] DGEMM Improvements - Prefetch the next panel while packing the 8xk panel. - Modified prefetch offsets for the dgemm native and dgemm_small kernels. - Retuned the size thresholds that route dgemm_ to the small-matrix path. - Replaced vzeroall() with explicit zeroing of ymm4-ymm15 in the sgemm 6x16 kernel. AMD-Internal: [CPUPL-2366] Change-Id: Ife609e789c8b87169c73bb0a30d6f1af20fb30ed --- frame/compat/bla_gemm_amd.c | 7 +- .../haswell/1m/bli_packm_haswell_asm_d8xk.c | 14 +- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 437 +++++++++--------- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 3 +- kernels/zen/3/bli_gemm_small.c | 36 +- 5 files changed, 261 insertions(+), 236 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index d46a69f0f..942d94f34 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -556,9 +556,10 @@ void dgemm_ #ifdef BLIS_ENABLE_SMALL_MATRIX - //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) - if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) + if(((m0 == n0) && (m0 < 400) && (k0 < 1000)) || + ( (m0 != n0) && (( ((m0 + n0 -k0) < 1500) && + ((m0 + k0-n0) < 1500) && ((n0 + k0-m0) < 1500) ) || + ((n0 <= 100) && (k0 <=100))))) { err_t status = BLIS_FAILURE; if (bli_is_notrans(blis_transa)) diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c index 9deb564ce..3b03d38fb 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c @@ -101,6 +101,8 @@ void bli_dpackm_haswell_asm_8xk // assembly region, this constraint should be lifted.
const bool unitk = bli_deq1( *kappa ); + double* restrict a_next = a + cdim0; + // ------------------------------------------------------------------------- @@ -267,7 +269,7 @@ void bli_dpackm_haswell_asm_8xk label(.DCOLUNIT) lea(mem(r10, r10, 2), r13) // r13 = 3*lda - + mov(var(a_next), rcx) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONKLEFTCOLU) // if i == 0, jump to code that @@ -278,22 +280,27 @@ void bli_dpackm_haswell_asm_8xk vmovupd(mem(rax, 0), ymm0) vmovupd(mem(rax, 32), ymm1) + prefetch(0, mem(rcx,7*8)) vmovupd(ymm0, mem(rbx, 0*64+ 0)) vmovupd(ymm1, mem(rbx, 0*64+32)) vmovupd(mem(rax, r10, 1, 0), ymm2) vmovupd(mem(rax, r10, 1, 32), ymm3) + prefetch(0, mem(rcx, r10, 1,7*8)) vmovupd(ymm2, mem(rbx, 1*64+ 0)) vmovupd(ymm3, mem(rbx, 1*64+32)) vmovupd(mem(rax, r10, 2, 0), ymm4) vmovupd(mem(rax, r10, 2, 32), ymm5) + prefetch(0, mem(rcx, r10, 2,7*8)) vmovupd(ymm4, mem(rbx, 2*64+ 0)) vmovupd(ymm5, mem(rbx, 2*64+32)) vmovupd(mem(rax, r13, 1, 0), ymm6) vmovupd(mem(rax, r13, 1, 32), ymm7) + prefetch(0, mem(rcx, r13, 1,7*8)) add(r14, rax) // a += 4*lda; + add(r14, rcx) vmovupd(ymm6, mem(rbx, 3*64+ 0)) vmovupd(ymm7, mem(rbx, 3*64+32)) add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; @@ -315,7 +322,9 @@ void bli_dpackm_haswell_asm_8xk vmovupd(mem(rax, 0), ymm0) vmovupd(mem(rax, 32), ymm1) + prefetch(0, mem(rcx,7*8)) add(r10, rax) // a += lda; + add(r10, rcx) vmovupd(ymm0, mem(rbx, 0*64+ 0)) vmovupd(ymm1, mem(rbx, 0*64+32)) add(imm(8*8), rbx) // p += ldp = 8; @@ -343,7 +352,8 @@ void bli_dpackm_haswell_asm_8xk [p] "m" (p), [ldp] "m" (ldp), [kappa] "m" (kappa), - [one] "m" (one) + [one] "m" (one), + [a_next] "m" (a_next) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index e6d47268f..5187d0bcb 100644 --- 
a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -102,7 +102,19 @@ void bli_sgemm_haswell_asm_6x16 begin_asm() - vzeroall() // zero all xmm/ymm registers. + //vzeroall() // zero all xmm/ymm registers. + vxorps( ymm4, ymm4, ymm4) + vmovaps( ymm4, ymm5) + vmovaps( ymm4, ymm6) + vmovaps( ymm4, ymm7) + vmovaps( ymm4, ymm8) + vmovaps( ymm4, ymm9) + vmovaps( ymm4, ymm10) + vmovaps( ymm4, ymm11) + vmovaps( ymm4, ymm12) + vmovaps( ymm4, ymm13) + vmovaps( ymm4, ymm14) + vmovaps( ymm4, ymm15) mov(var(a), rax) // load address of a. @@ -141,7 +153,7 @@ void bli_sgemm_haswell_asm_6x16 // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -167,6 +179,8 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, -1*32), ymm1) // iteration 1 + prefetch(0, mem(rax, 72*4)) + vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -192,7 +206,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 1*32), ymm1) // iteration 2 - prefetch(0, mem(rax, 76*4)) + prefetch(0, mem(rax, 80*4)) vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) @@ -1010,76 +1024,78 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) 
vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 - prefetch(0, mem(rax, 76*8)) - + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -1087,91 +1103,91 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x 
nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. 
- - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -1184,179 +1200,179 @@ void bli_dgemm_haswell_asm_6x8 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; - + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORED) // jump to row storage case - - + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - - + + + label(.DGENSTORED) - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm4, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm6, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm8, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm10, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm12, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm14, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c + 4*cs_c - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm5, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm7, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm9, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm11, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm13, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm15, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.DDONE) // jump to end. 
- - - + + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm5) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm6) vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm7) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm8) vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm9) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm10) vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm11) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm12) vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm13) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm14) vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm15) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DCOLSTORED) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1365,9 +1381,9 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vbroadcastsd(mem(rbx), ymm3) - + vfmadd231pd(mem(rcx), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) @@ -1376,14 +1392,14 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vfmadd231pd(mem(r14), xmm3, xmm0) vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) @@ -1392,10 +1408,10 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, 
mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1404,9 +1420,9 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vbroadcastsd(mem(rbx), ymm3) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) @@ -1415,14 +1431,14 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vfmadd231pd(mem(r14), xmm3, xmm0) vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) @@ -1431,139 +1447,139 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DBETAZERO) - + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORBZ) // jump to row storage case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - - + + + label(.DGENSTORBZ) - - + + vmovapd(ymm4, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm6, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm8, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm10, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm12, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm14, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c + 4*cs_c - - + + vmovapd(ymm5, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm7, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm9, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm11, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm13, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm15, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.DDONE) // jump to end. 
- - - + + + label(.DCOLSTORBZ) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1572,27 +1588,27 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1601,32 +1617,31 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) - - - - label(.DDONE) + //lea(mem(r14, rsi, 4), r14) + + + + label(.DDONE) vzeroupper() - + end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 41da73f36..107917d07 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -867,8 +867,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DRETURN) - - + vzeroupper() end_asm( : // output 
operands (none) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 0cf5c8c5c..e232e28e5 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -1951,12 +1951,12 @@ static err_t bli_sgemm_small tA_packed = D_A_pack; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. ymm4 = _mm256_setzero_pd(); @@ -2111,12 +2111,12 @@ static err_t bli_sgemm_small tA = tA_packed + row_idx_packed; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. 
ymm4 = _mm256_setzero_pd(); @@ -4513,12 +4513,12 @@ err_t bli_dgemm_small_At tA = tA_packed + row_idx_packed; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. ymm4 = _mm256_setzero_pd();