From 1c719c91a3ef0be29a918097652beef35647d4b2 Mon Sep 17 00:00:00 2001
From: "Field G. Van Zee"
Date: Thu, 4 Jun 2020 17:21:08 -0500
Subject: [PATCH] Bugfixes, cleanup of sup dgemm ukernels.

Details:
- Fixed a few not-really-bugs:
  - Previously, the d6x8m kernels were still prefetching the next upanel
    of A using MR*rs_a instead of ps_a (and likewise the d6x8n kernels were
    prefetching the next upanel of B using NR*cs_b instead of ps_b). Since
    the upanels might be packed, ps_a (or ps_b) is the correct way to
    compute the prefetch address.
  - Fixed an obscure bug in the rd_d6x8m kernel that, by dumb luck, executed
    as intended even though it was based on faulty pointer management.
    Specifically, the pointer for B (stored in rdx) was loaded only once,
    outside of the jj loop, and in the second iteration its new position was
    computed by incrementing rdx by the *absolute* offset (four columns),
    which happened to equal the relative offset (also four columns) that was
    actually needed. It worked only because that loop executed exactly twice.
    A similar issue was fixed in the rd_d6x8n kernels.
- Various cleanups and additions, including:
  - Factored out the loading of rs_c into rdi in the rd_d6x8[mn] kernels so
    that it is loaded only once outside of the loops rather than multiple
    times inside the loops.
  - Changed the outer loop in the rd kernels so that the jump/comparison and
    loop bounds more closely mimic what you'd see in higher-level source
    code. That is, something like:
      for( i = 0; i < 6; i+=3 )
    rather than something like:
      for( i = 0; i <= 3; i+=3 )
  - Switched row-based IO to use byte offsets instead of byte column strides
    (e.g. via the rsi register), which were known to be 8 anyway since
    otherwise that conditional branch would not have executed.
  - Cleaned up and homogenized prefetching a bit.
  - Updated the comments that show the before and after of the in-register
    transpositions.
  - Added comments to the column-based IO cases to indicate which columns
    are being accessed/updated.
  - Added the rbp register to the clobber lists.
  - Removed some dead (commented-out) code.
  - Fixed some copy-paste typos in comments in the rv_6x8n kernels.
  - Cleaned up whitespace (including leading whitespace -> tabs).
- Moved edge-case (non-millikernel) kernels to their own directory, d6x8, and
  split them into separate files based on the "NR" value of the kernels
  (Mx8, Mx4, Mx2, etc.).
- Moved config-specific reference Mx1 kernels into their own file (e.g.
  bli_gemmsup_r_haswell_ref_dMx1.c) inside the d6x8 directory.
- Added rd_dMx1 assembly kernels, which seem marginally faster than the
  corresponding reference kernels.
- Updated comments in ref_kernels/bli_cntx_ref.c and changed to use the
  row-oriented reference kernels for all storage combos.
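Note: a minimal C sketch (not BLIS source; the function and its arguments are
hypothetical, named after the variables above) of why the panel stride ps_a,
rather than MR*rs_a, locates the next upanel of A for prefetching:

    #include <stddef.h>

    // Advance to the next upanel of A. When A is unpacked, the next upanel
    // begins MR rows below the current one, so MR*rs_a happens to give the
    // same address as ps_a. When A is packed, the upanels are stored back to
    // back with a panel stride chosen by the packing routine, and only ps_a
    // yields the correct address.
    static inline const double* next_upanel_a
         (
           const double* a,    // current upanel of A
           size_t        ps_a  // panel stride of A, in units of doubles
         )
    {
        return a + ps_a;

        // Buggy pre-fix form, valid only for unpacked A:
        //   return a + (size_t)MR * rs_a;
    }

The assembly kernels below do the equivalent with ps_a8 (ps_a pre-scaled to
bytes), e.g. mov(var(ps_a8), rdx) followed by lea(mem(rax, rdx, 1), rdx).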
--- .../3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c | 266 +- .../3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c | 310 +- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 813 ++-- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c | 909 ++--- .../sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c | 158 + .../d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c | 1698 +++++++++ .../d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c | 1794 +++++++++ .../d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c | 1450 ++++++++ .../d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c | 1617 ++++++++ .../d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c | 2496 +++++++++++++ .../d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c | 2600 +++++++++++++ .../d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c | 3095 ++++++++++++++++ .../d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c | 3260 +++++++++++++++++ .../old}/bli_gemmsup_rd_haswell_asm_d6x8.c | 19 +- .../old}/bli_gemmsup_rv_haswell_asm_d6x8.c | 1 + kernels/haswell/bli_kernels_haswell.h | 23 +- ref_kernels/bli_cntx_ref.c | 43 +- 17 files changed, 19158 insertions(+), 1394 deletions(-) create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c create mode 100644 kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c rename kernels/haswell/3/sup/{ => d6x8/old}/bli_gemmsup_rd_haswell_asm_d6x8.c (99%) rename kernels/haswell/3/sup/{ => d6x8/old}/bli_gemmsup_rv_haswell_asm_d6x8.c (99%) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c index e2476d8d1..1820277d5 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -65,23 +65,6 @@ // Prototype reference microkernels. GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) -#if 0 -// Define parameters and variables for edge case kernel map. 
-#define NUM_MR 4 -#define NUM_NR 4 -#define FUNCPTR_T dgemmsup_ker_ft - -static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; -static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; -static FUNCPTR_T kmap[NUM_MR][NUM_NR] = -{ /* 8 4 2 1 */ -/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 }, -/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref_3x1 }, -/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 }, -/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 } -}; -#endif - void bli_dgemmsup_rd_haswell_asm_6x8m ( @@ -135,7 +118,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m } if ( 1 == n_left ) { -#if 0 + #if 0 const dim_t nr_cur = 1; bli_dgemmsup_r_haswell_ref @@ -144,14 +127,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_NO_TRANSPOSE, conjb, m0, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, beta, cij, rs_c0, cntx, NULL ); -#endif + #endif } return; } @@ -193,7 +176,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - mov(var(b), rdx) // load address of b. + //mov(var(b), rdx) // load address of b. //mov(var(rs_b), r10) // load rs_b mov(var(cs_b), r11) // load cs_b //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) @@ -204,8 +187,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m //mov(var(c), r12) // load address of c - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -222,10 +205,11 @@ void bli_dgemmsup_rd_haswell_asm_6x8m mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b mov(var(c), r12) // load address of c lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; - imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; @@ -266,14 +250,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m lea(mem(rdx), rbx) // rbx = b_jj; -#if 0 - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif - lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a @@ -290,15 +274,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m // ---------------------------------- iteration 0 #if 1 - prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a - prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a - prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) 
// a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -327,7 +311,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -354,15 +338,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m // ---------------------------------- iteration 2 #if 1 - prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a - prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a - prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -391,7 +375,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -436,15 +420,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m label(.DLOOPKITER4) // EDGE LOOP (ymm) #if 1 - prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a - prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a - prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -493,7 +477,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -527,8 +511,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 @@ -536,14 +518,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vaddpd( xmm0, xmm1, xmm0 ) vhaddpd( ymm13, ymm10, ymm2 ) vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) vhaddpd( ymm8, ymm5, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -554,6 +537,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) vhaddpd( ymm9, ymm6, ymm0 ) @@ -565,15 +550,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // 
xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) - // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) - // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta @@ -661,8 +645,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m add(imm(4), r15) // jj += 4; - cmp(imm(4), r15) // compare jj to 4 - jle(.DLOOP3X4J) // if jj <= 4, jump to beginning + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning // of jj loop; otherwise, loop ends. @@ -692,7 +676,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -803,8 +787,8 @@ void bli_dgemmsup_rd_haswell_asm_6x4m mov(var(c), r12) // load address of c - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -813,13 +797,13 @@ void bli_dgemmsup_rd_haswell_asm_6x4m // rdx = rbx = b // r9 = m dim index ii // r15 = n dim index jj - // r10 = unused mov(var(m_iter), r9) // ii = m_iter; label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + #if 0 vzeroall() // zero all xmm/ymm registers. #else @@ -846,14 +830,14 @@ void bli_dgemmsup_rd_haswell_asm_6x4m lea(mem(rdx), rbx) // rbx = b; -#if 0 - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif - lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a @@ -870,15 +854,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m // ---------------------------------- iteration 0 #if 1 - prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a - prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a - prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -907,7 +891,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -934,15 +918,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m // ---------------------------------- iteration 2 #if 1 - prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a - prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a - prefetch(0, mem(rax, rdi, 1, 
0*8)) // prefetch rax + 5*cs_a + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -971,7 +955,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1014,11 +998,17 @@ void bli_dgemmsup_rd_haswell_asm_6x4m label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 1 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1067,7 +1057,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1101,8 +1091,6 @@ void bli_dgemmsup_rd_haswell_asm_6x4m label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 @@ -1110,14 +1098,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vaddpd( xmm0, xmm1, xmm0 ) vhaddpd( ymm13, ymm10, ymm2 ) vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) vhaddpd( ymm8, ymm5, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -1128,7 +1117,8 @@ void bli_dgemmsup_rd_haswell_asm_6x4m vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) - + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) vhaddpd( ymm9, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -1139,15 +1129,14 @@ void bli_dgemmsup_rd_haswell_asm_6x4m vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9) + // ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15) - // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) - // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta @@ -1259,7 +1248,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", 
"xmm6", "xmm7", @@ -1366,6 +1355,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a mov(var(c), r12) // load address of c @@ -1375,41 +1365,44 @@ void bli_dgemmsup_rd_haswell_asm_6x2m // r12 = rcx = c - // r14 = rax = a - // rdx = rbx = b - // r9 = m dim index ii + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii mov(var(m_iter), r9) // ii = m_iter; - label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] #if 0 - vzeroall() // zero all xmm/ymm registers. + vzeroall() // zero all xmm/ymm registers. #else - // skylake can execute 3 vxorpd ipc with - // a latency of 1 cycle, while vzeroall - // has a latency of 12 cycles. - vxorpd(ymm4, ymm4, ymm4) - vxorpd(ymm5, ymm5, ymm5) - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm7, ymm7, ymm7) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm9, ymm9, ymm9) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm11, ymm11, ymm11) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm13, ymm13, ymm13) - vxorpd(ymm14, ymm14, ymm14) - vxorpd(ymm15, ymm15, ymm15) + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) #endif - lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c; - lea(mem(r14), rax) // rax = a + 6*ii*rs_a; - lea(mem(rdx), rbx) // rbx = b; + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) lea(mem(rcx, rdi, 2), r10) // lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c @@ -1418,6 +1411,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c +#endif @@ -1433,6 +1427,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m // ---------------------------------- iteration 0 +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + vmovupd(mem(rbx ), ymm0) vmovupd(mem(rbx, r11, 1), ymm1) add(imm(4*8), rbx) // b += 4*rs_b = 4*8; @@ -1497,6 +1497,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m // ---------------------------------- iteration 2 +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif + vmovupd(mem(rbx ), ymm0) vmovupd(mem(rbx, r11, 1), ymm1) add(imm(4*8), rbx) // b += 4*rs_b = 4*8; @@ -1578,6 +1584,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a +#endif vmovupd(mem(rbx ), ymm0) vmovupd(mem(rbx, r11, 1), ymm1) @@ -1703,15 +1715,18 @@ void 
bli_dgemmsup_rd_haswell_asm_6x2m vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm14 ) - // xmm4 = sum(ymm4) sum(ymm5) - // xmm6 = sum(ymm6) sum(ymm7) - // xmm8 = sum(ymm8) sum(ymm9) - // xmm10 = sum(ymm10) sum(ymm11) - // xmm12 = sum(ymm12) sum(ymm13) - // xmm14 = sum(ymm14) sum(ymm15) + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + // xmm10[0:1] = sum(ymm10) sum(ymm11) + // xmm12[0:1] = sum(ymm12) sum(ymm13) + // xmm14[0:1] = sum(ymm14) sum(ymm15) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate @@ -1810,13 +1825,13 @@ void bli_dgemmsup_rd_haswell_asm_6x2m lea(mem(r12, rdi, 4), r12) // - lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c - lea(mem(r14, r8, 4), r14) // - lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + lea(mem(r14, r8, 4), r14) // + lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a - dec(r9) // ii -= 1; - jne(.DLOOP3X4I) // iterate again if ii != 0. + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. @@ -1846,7 +1861,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1872,6 +1887,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m const dim_t mr_cur = 3; bli_dgemmsup_rd_haswell_asm_3x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -1884,6 +1900,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m const dim_t mr_cur = 2; bli_dgemmsup_rd_haswell_asm_2x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -1896,6 +1913,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m const dim_t mr_cur = 1; bli_dgemmsup_rd_haswell_asm_1x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c index 53cad668e..4ccc2855e 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -65,23 +65,6 @@ // Prototype reference microkernels. GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) -#if 0 -// Define parameters and variables for edge case kernel map. 
-#define NUM_MR 4 -#define NUM_NR 4 -#define FUNCPTR_T dgemmsup_ker_ft - -static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 }; -static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; -static FUNCPTR_T kmap[NUM_MR][NUM_NR] = -{ /* 8 4 2 1 */ -/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8n, bli_dgemmsup_rd_haswell_asm_6x4n, bli_dgemmsup_rd_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref_6x1 }, -/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8n, bli_dgemmsup_rd_haswell_asm_3x4n, bli_dgemmsup_rd_haswell_asm_3x2n, bli_dgemmsup_r_haswell_ref_3x1 }, -/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8n, bli_dgemmsup_rd_haswell_asm_2x4n, bli_dgemmsup_rd_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref_2x1 }, -/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8n, bli_dgemmsup_rd_haswell_asm_1x4n, bli_dgemmsup_rd_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref_1x1 } -}; -#endif - void bli_dgemmsup_rd_haswell_asm_6x8n ( @@ -161,6 +144,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n const dim_t mr_cur = 3; bli_dgemmsup_rd_haswell_asm_3x8n + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, n0, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -173,6 +157,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n const dim_t mr_cur = 2; bli_dgemmsup_rd_haswell_asm_2x8n + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, n0, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -182,23 +167,24 @@ void bli_dgemmsup_rd_haswell_asm_6x8n } if ( 1 == m_left ) { -#if 0 + #if 1 const dim_t mr_cur = 1; bli_dgemmsup_rd_haswell_asm_1x8n + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, n0, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_TRANSPOSE, conja, k0, n0, alpha, bj, rs_b0, cs_b0, ai, cs_a0, beta, cij, cs_c0, cntx, NULL ); -#endif + #endif } return; } @@ -231,7 +217,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n //vzeroall() // zero all xmm/ymm registers. - mov(var(a), rdx) // load address of a. + //mov(var(a), rdx) // load address of a. 
mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) @@ -247,11 +233,12 @@ void bli_dgemmsup_rd_haswell_asm_6x8n lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a //mov(var(c), r12) // load address of c - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -267,12 +254,10 @@ void bli_dgemmsup_rd_haswell_asm_6x8n + mov(var(a), rdx) // load address of a mov(var(b), r14) // load address of b mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii; imul(rdi, rsi) // rsi *= rs_c lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c; @@ -309,19 +294,20 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vxorpd(ymm15, ymm15, ymm15) #endif + lea(mem(r12), rcx) // rcx = c_iijj; lea(mem(rdx), rax) // rax = a_ii; lea(mem(r14), rbx) // rbx = b_jj; #if 1 - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif - lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b @@ -338,13 +324,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n // ---------------------------------- iteration 0 -#if 0 - prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b - prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b - prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b - prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b - add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; -#else +#if 1 prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b #endif @@ -352,7 +332,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -386,7 +366,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -420,7 +400,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -455,7 +435,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -502,7 +482,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), 
rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -551,7 +531,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -585,8 +565,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8n label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 @@ -594,14 +572,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vaddpd( xmm0, xmm1, xmm0 ) vhaddpd( ymm13, ymm10, ymm2 ) vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) vhaddpd( ymm8, ymm5, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -612,7 +591,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) - + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) vhaddpd( ymm9, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -623,15 +603,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8n vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9) + // ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15) - // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) - // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta @@ -716,8 +695,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8n add(imm(3), r9) // ii += 3; - cmp(imm(3), r9) // compare ii to 3 - jle(.DLOOP3X4I) // if ii <= 3, jump to beginning + cmp(imm(6), r9) // compare ii to 6 + jl(.DLOOP3X4I) // if ii < 6, jump to beginning // of ii loop; otherwise, loop ends. @@ -748,7 +727,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -759,7 +738,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. + // Handle edge cases in the n dimension, if they exist. 
if ( n_left ) { const dim_t mr_cur = 6; @@ -774,6 +753,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n const dim_t nr_cur = 2; bli_dgemmsup_rd_haswell_asm_6x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -783,24 +763,24 @@ void bli_dgemmsup_rd_haswell_asm_6x8n } if ( 1 == n_left ) { -#if 0 + #if 1 const dim_t nr_cur = 1; - //bli_dgemmsup_rd_haswell_asm_6x1n - bli_dgemmsup_r_haswell_ref + bli_dgemmsup_rd_haswell_asm_6x1 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, beta, cij, rs_c0, cntx, NULL ); -#endif + #endif } } } @@ -865,11 +845,12 @@ void bli_dgemmsup_rd_haswell_asm_3x8n lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a mov(var(c), r12) // load address of c - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -905,19 +886,20 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vxorpd(ymm15, ymm15, ymm15) #endif + lea(mem(r12), rcx) // rcx = c_iijj; lea(mem(rdx), rax) // rax = a_ii; lea(mem(r14), rbx) // rbx = b_jj; #if 1 - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif - lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b @@ -934,13 +916,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n // ---------------------------------- iteration 0 -#if 0 - prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b - prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b - prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b - prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b - add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; -#else +#if 1 prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b #endif @@ -948,7 +924,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -982,7 +958,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1016,7 +992,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1051,7 +1027,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + 
add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1098,7 +1074,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1147,7 +1123,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1181,8 +1157,6 @@ void bli_dgemmsup_rd_haswell_asm_3x8n label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 @@ -1190,14 +1164,15 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vaddpd( xmm0, xmm1, xmm0 ) vhaddpd( ymm13, ymm10, ymm2 ) vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) vhaddpd( ymm8, ymm5, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -1208,7 +1183,8 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) - + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) vhaddpd( ymm9, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -1219,15 +1195,14 @@ void bli_dgemmsup_rd_haswell_asm_3x8n vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) - - // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) - // ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + // ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9) + // ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15) - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta @@ -1312,6 +1287,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n + label(.DRETURN) @@ -1337,7 +1313,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1348,7 +1324,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. + // Handle edge cases in the n dimension, if they exist. 
if ( n_left ) { const dim_t mr_cur = 3; @@ -1363,6 +1339,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n const dim_t nr_cur = 2; bli_dgemmsup_rd_haswell_asm_3x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -1372,23 +1349,25 @@ void bli_dgemmsup_rd_haswell_asm_3x8n } if ( 1 == n_left ) { -#if 0 + #if 1 const dim_t nr_cur = 1; - bli_dgemmsup_r_haswell_ref_3x1 + bli_dgemmsup_rd_haswell_asm_3x1 + //bli_dgemmsup_r_haswell_ref_3x1 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, beta, cij, rs_c0, cntx, NULL ); -#endif + #endif } } } @@ -1453,11 +1432,12 @@ void bli_dgemmsup_rd_haswell_asm_2x8n lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a mov(var(c), r12) // load address of c - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -1489,18 +1469,19 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vxorpd(ymm14, ymm14, ymm14) #endif + lea(mem(r12), rcx) // rcx = c_iijj; lea(mem(rdx), rax) // rax = a_ii; lea(mem(r14), rbx) // rbx = b_jj; #if 1 - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c #endif - lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b @@ -1517,20 +1498,14 @@ void bli_dgemmsup_rd_haswell_asm_2x8n // ---------------------------------- iteration 0 -#if 0 - prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b - prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b - prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b - prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b - add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; -#else +#if 1 prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b #endif vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1559,7 +1534,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1588,7 +1563,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1618,7 +1593,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1660,7 +1635,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 
4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1704,7 +1679,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1734,23 +1709,21 @@ void bli_dgemmsup_rd_haswell_asm_2x8n label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 - // ymm6 ymm9 ymm12 ymm15 vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vaddpd( xmm0, xmm1, xmm0 ) vhaddpd( ymm13, ymm10, ymm2 ) vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) vhaddpd( ymm8, ymm5, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) @@ -1761,14 +1734,14 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8) + // ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14) - // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - // ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta @@ -1777,7 +1750,6 @@ void bli_dgemmsup_rd_haswell_asm_2x8n vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) - vmulpd(ymm0, ymm6, ymm6) @@ -1846,6 +1818,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n + label(.DRETURN) @@ -1871,7 +1844,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1882,7 +1855,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. + // Handle edge cases in the n dimension, if they exist. 
if ( n_left ) { const dim_t mr_cur = 2; @@ -1897,6 +1870,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n const dim_t nr_cur = 2; bli_dgemmsup_rd_haswell_asm_2x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -1906,23 +1880,24 @@ void bli_dgemmsup_rd_haswell_asm_2x8n } if ( 1 == n_left ) { -#if 0 + #if 1 const dim_t nr_cur = 1; - bli_dgemmsup_r_haswell_ref_2x1 + bli_dgemmsup_rd_haswell_asm_2x1 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, beta, cij, rs_c0, cntx, NULL ); -#endif + #endif } } } @@ -1987,11 +1962,12 @@ void bli_dgemmsup_rd_haswell_asm_1x8n lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a mov(var(c), r12) // load address of c - //mov(var(rs_c), rdi) // load rs_c - //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -2019,18 +1995,18 @@ void bli_dgemmsup_rd_haswell_asm_1x8n vxorpd(ymm13, ymm13, ymm13) #endif + lea(mem(r12), rcx) // rcx = c_iijj; lea(mem(rdx), rax) // rax = a_ii; lea(mem(r14), rbx) // rbx = b_jj; #if 1 - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c #endif - lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b + //lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b @@ -2047,19 +2023,13 @@ void bli_dgemmsup_rd_haswell_asm_1x8n // ---------------------------------- iteration 0 -#if 0 - prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b - prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b - prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b - prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b - add(imm(8*8), r10) // r10 += 8*rs_b = 8*8; -#else +#if 1 prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b #endif vmovupd(mem(rax ), ymm0) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2083,7 +2053,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n #endif vmovupd(mem(rax ), ymm0) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2107,7 +2077,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n #endif vmovupd(mem(rax ), ymm0) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2132,7 +2102,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n #endif vmovupd(mem(rax ), ymm0) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2169,7 +2139,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n label(.DLOOPKITER4) // EDGE LOOP (ymm) vmovupd(mem(rax ), ymm0) - add(imm(4*8), rax) // a += 4*cs_b = 4*8; + add(imm(4*8), rax) // a += 4*cs_a = 4*8; vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ 
-2208,7 +2178,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n // which would destory intermediate results. vmovsd(mem(rax ), xmm0) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2234,29 +2204,26 @@ void bli_dgemmsup_rd_haswell_asm_1x8n label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 - // ymm5 ymm8 ymm11 ymm14 - // ymm6 ymm9 ymm12 ymm15 vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7) + vaddpd( xmm0, xmm1, xmm0 ) vhaddpd( ymm13, ymm10, ymm2 ) vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13) + vaddpd( xmm2, xmm1, xmm2 ) vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7) + // ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13) - // ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta @@ -2325,6 +2292,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n + label(.DRETURN) @@ -2350,7 +2318,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2361,7 +2329,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. + // Handle edge cases in the n dimension, if they exist. if ( n_left ) { const dim_t mr_cur = 1; @@ -2376,6 +2344,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n const dim_t nr_cur = 2; bli_dgemmsup_rd_haswell_asm_1x2 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, @@ -2385,23 +2354,24 @@ void bli_dgemmsup_rd_haswell_asm_1x8n } if ( 1 == n_left ) { -#if 0 + #if 1 const dim_t nr_cur = 1; - bli_dgemmsup_r_haswell_ref_1x1 + bli_dgemmsup_rd_haswell_asm_1x1 + //bli_dgemmsup_r_haswell_ref ( conja, conjb, mr_cur, nr_cur, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_ddotxv_ex ( conja, conjb, k0, alpha, ai, cs_a0, bj, rs_b0, beta, cij, cntx, NULL ); -#endif + #endif } } } diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 954eb1e28..1637e9766 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -80,23 +80,6 @@ // Prototype reference microkernels. GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) -#if 0 -// Define parameters and variables for edge case kernel map. 
-#define NUM_MR 4 -#define NUM_NR 4 -#define FUNCPTR_T dgemmsup_ker_ft - -static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; -static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; -static FUNCPTR_T kmap[NUM_MR][NUM_NR] = -{ /* 8 4 2 1 */ -/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8m, bli_dgemmsup_rv_haswell_asm_6x4m, bli_dgemmsup_rv_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 }, -/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8m, bli_dgemmsup_rv_haswell_asm_4x4m, bli_dgemmsup_rv_haswell_asm_4x2m, bli_dgemmsup_r_haswell_ref_4x1 }, -/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8m, bli_dgemmsup_rv_haswell_asm_2x4m, bli_dgemmsup_rv_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 }, -/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8m, bli_dgemmsup_rv_haswell_asm_1x4m, bli_dgemmsup_rv_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 }, -}; -#endif - void bli_dgemmsup_rv_haswell_asm_6x8m ( @@ -315,10 +298,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c @@ -329,10 +312,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8m lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -343,16 +326,15 @@ void bli_dgemmsup_rv_haswell_asm_6x8m #if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines lea(mem(rdx, r8, 2), rdx) // from next upanel of a. -#else - lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. 
- lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; - //mov(r9, rsi) // rsi = cs_a; - //sal(imm(4), rsi) // rsi = 16*cs_a; - //lea(mem(rax, rsi, 1), rdx) // rdx = a + 16*cs_a; + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif @@ -369,9 +351,8 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // ---------------------------------- iteration 0 -#if 1 +#if 0 prefetch(0, mem(rdx, 5*8)) - //prefetch(0, mem(rax, 5*8)) #else prefetch(0, mem(rdx, 5*8)) #endif @@ -405,10 +386,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // ---------------------------------- iteration 1 -#if 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) - //prefetch(0, mem(rax, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else + prefetch(0, mem(rdx, r9, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -440,13 +421,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // ---------------------------------- iteration 2 -#if 1 - prefetch(0, mem(rdx, r9, 2, 5*8)) - //prefetch(0, mem(rax, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, r9, 2, 5*8)) - //prefetch(0, mem(rdx, r9, 2)) - //lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -478,11 +456,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // ---------------------------------- iteration 3 -#if 1 - prefetch(0, mem(rdx, rcx, 1, 5*8)) - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - //prefetch(0, mem(rax, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else + prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif @@ -531,6 +508,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -618,51 +600,51 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) - vmovupd(ymm13, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm15) - vmovupd(ymm15, mem(rcx, rsi, 4)) + 
vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) @@ -672,7 +654,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -684,11 +666,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -700,18 +682,18 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -723,11 +705,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx ), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -739,11 +721,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -767,33 +749,33 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx, rsi, 4)) + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm14, mem(rcx)) - vmovupd(ymm15, mem(rcx, rsi, 4)) + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) @@ -803,7 +785,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -813,7 +795,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) + 
vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -825,14 +807,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -842,7 +824,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -854,7 +836,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -908,7 +890,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -978,7 +960,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m } #endif -#if 1 dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -999,67 +980,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m ); return; -#else - if ( 5 <= m_left ) - { - const dim_t mr_cur = 5; - - bli_dgemmsup_rv_haswell_asm_5x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 4 <= m_left ) - { - const dim_t mr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_4x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 3 <= m_left ) - { - const dim_t mr_cur = 3; - - bli_dgemmsup_rv_haswell_asm_3x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 2 <= m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_2x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { - const dim_t mr_cur = 1; - - bli_dgemmsup_rv_haswell_asm_1x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - } -#endif } } @@ -1180,10 +1100,10 @@ void bli_dgemmsup_rv_haswell_asm_6x6m lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c prefetch(0, mem(rdx, rdi, 2, 5*8)) // 
prefetch c + 5*rs_c @@ -1194,10 +1114,10 @@ void bli_dgemmsup_rv_haswell_asm_6x6m lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c @@ -1205,16 +1125,15 @@ void bli_dgemmsup_rv_haswell_asm_6x6m #if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines lea(mem(rdx, r8, 2), rdx) // from next upanel of a. -#else - lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. - lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; - //mov(r9, rsi) // rsi = cs_a; - //sal(imm(4), rsi) // rsi = 16*cs_a; - //lea(mem(rax, rsi, 1), rdx) // rdx = a + 16*cs_a; + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif @@ -1231,9 +1150,8 @@ void bli_dgemmsup_rv_haswell_asm_6x6m // ---------------------------------- iteration 0 -#if 1 +#if 0 prefetch(0, mem(rdx, 5*8)) - //prefetch(0, mem(rax, 5*8)) #else prefetch(0, mem(rdx, 5*8)) #endif @@ -1267,10 +1185,10 @@ void bli_dgemmsup_rv_haswell_asm_6x6m // ---------------------------------- iteration 1 -#if 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) - //prefetch(0, mem(rax, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else + prefetch(0, mem(rdx, r9, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -1302,13 +1220,10 @@ void bli_dgemmsup_rv_haswell_asm_6x6m // ---------------------------------- iteration 2 -#if 1 - prefetch(0, mem(rdx, r9, 2, 5*8)) - //prefetch(0, mem(rax, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, r9, 2, 5*8)) - //prefetch(0, mem(rdx, r9, 2)) - //lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -1340,11 +1255,10 @@ void bli_dgemmsup_rv_haswell_asm_6x6m // ---------------------------------- iteration 3 -#if 1 - prefetch(0, mem(rdx, rcx, 1, 5*8)) - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - //prefetch(0, mem(rax, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else + prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif @@ -1394,6 +1308,11 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DLOOPKLEFT) // EDGE LOOP +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; @@ -1480,51 +1399,51 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm5) - vmovupd(xmm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm7) - vmovupd(xmm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), 
ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm9) - vmovupd(xmm9, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm11) - vmovupd(xmm11, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm13) - vmovupd(xmm13, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), xmm3, xmm15) - vmovupd(xmm15, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) @@ -1534,7 +1453,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1546,11 +1465,11 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -1562,53 +1481,41 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-5 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) vunpckhpd(ymm11, ymm9, ymm3) vinsertf128(imm(0x1), xmm2, ymm0, ymm5) vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx ), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) - //vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) - //vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) - //vmovupd(ymm9, mem(rcx, rsi, 2)) - //vmovupd(ymm11, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) - //vextractf128(imm(0x1), ymm0, xmm2) - //vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - //vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) - //vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) - //vmovupd(xmm2, mem(rdx, rsi, 
2)) - //vmovupd(xmm4, mem(rdx, rax, 1)) //lea(mem(rdx, rsi, 4), rdx) @@ -1629,33 +1536,33 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(xmm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(xmm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(xmm9, mem(rcx, rsi, 4)) + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(xmm11, mem(rcx, rsi, 4)) + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(xmm13, mem(rcx, rsi, 4)) + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm14, mem(rcx)) - vmovupd(xmm15, mem(rcx, rsi, 4)) + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) @@ -1665,7 +1572,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1675,7 +1582,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -1687,39 +1594,31 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-5 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) vunpckhpd(ymm11, ymm9, ymm3) vinsertf128(imm(0x1), xmm2, ymm0, ymm5) vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) - //vmovupd(ymm9, mem(rcx, rsi, 2)) - //vmovupd(ymm11, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) - //vextractf128(imm(0x1), ymm0, xmm2) - //vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) - //vmovupd(xmm2, mem(rdx, rsi, 2)) - //vmovupd(xmm4, mem(rdx, rax, 1)) //lea(mem(rdx, rsi, 4), rdx) @@ -1770,7 +1669,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1840,7 +1739,6 @@ void bli_dgemmsup_rv_haswell_asm_6x6m } #endif -#if 1 dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -1861,67 +1759,6 @@ void bli_dgemmsup_rv_haswell_asm_6x6m ); return; -#else - if ( 5 <= m_left ) - { - const dim_t mr_cur = 5; - - bli_dgemmsup_rv_haswell_asm_5x6 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 4 <= m_left ) - { - const dim_t mr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_4x6 - ( - conja, conjb, mr_cur, nr_cur, 
k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 3 <= m_left ) - { - const dim_t mr_cur = 3; - - bli_dgemmsup_rv_haswell_asm_3x6 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 2 <= m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_2x6 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { - const dim_t mr_cur = 1; - - bli_dgemmsup_rv_haswell_asm_1x6 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - } -#endif } } @@ -2008,6 +1845,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DLOOP6X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + #if 0 vzeroall() // zero all xmm/ymm registers. #else @@ -2027,26 +1865,17 @@ void bli_dgemmsup_rv_haswell_asm_6x4m mov(r14, rax) -#if 0 - lea(mem(rcx, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c -#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLPFETCH) // jump to column storage case label(.DROWPFETCH) // row-stored prefetching on c lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c @@ -2057,24 +1886,24 @@ void bli_dgemmsup_rv_haswell_asm_6x4m lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c -#endif - #if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - - //lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. 
- //lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif @@ -2091,7 +1920,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // ---------------------------------- iteration 0 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, 5*8)) #endif @@ -2117,7 +1948,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // ---------------------------------- iteration 1 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, r9, 1, 5*8)) #endif @@ -2143,7 +1976,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // ---------------------------------- iteration 2 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif @@ -2169,9 +2004,11 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // ---------------------------------- iteration 3 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, rcx, 1, 5*8)) - lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2215,7 +2052,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #if 1 prefetch(0, mem(rdx, 5*8)) - add(r9, rdx) // rdx += cs_a; + add(r9, rdx) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2291,33 +2128,33 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2327,7 +2164,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -2339,11 +2176,11 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -2355,11 +2192,11 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -2383,24 +2220,27 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) + + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - 
vmovupd(ymm8, mem(rcx)) + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) + - vmovupd(ymm10, mem(rcx)) + vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) + vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) + - vmovupd(ymm14, mem(rcx)) + vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2410,7 +2250,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -2420,7 +2260,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -2432,7 +2272,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -2486,7 +2326,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2556,7 +2396,6 @@ void bli_dgemmsup_rv_haswell_asm_6x4m } #endif -#if 1 dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -2577,67 +2416,6 @@ void bli_dgemmsup_rv_haswell_asm_6x4m ); return; -#else - if ( 5 <= m_left ) - { - const dim_t mr_cur = 5; - - bli_dgemmsup_rv_haswell_asm_5x4 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 4 <= m_left ) - { - const dim_t mr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_4x4 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 3 <= m_left ) - { - const dim_t mr_cur = 3; - - bli_dgemmsup_rv_haswell_asm_3x4 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 2 <= m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_2x4 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { - const dim_t mr_cur = 1; - - bli_dgemmsup_rv_haswell_asm_1x4 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - } -#endif } } @@ -2724,6 +2502,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DLOOP6X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + #if 0 vzeroall() // zero all xmm/ymm registers. 
#else @@ -2743,26 +2522,17 @@ void bli_dgemmsup_rv_haswell_asm_6x2m mov(r14, rax) -#if 0 - lea(mem(rcx, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c -#else + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLPFETCH) // jump to column storage case label(.DROWPFETCH) // row-stored prefetching on c lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 1*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 1*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c @@ -2771,25 +2541,27 @@ void bli_dgemmsup_rv_haswell_asm_6x2m mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c label(.DPOSTPFETCH) // done prefetching c -#endif #if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - - //lea(mem(rax, r9, 8), rdx) // use rdx for prefetching a. - //lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. 
je(.DCONSIDKLEFT) // if i == 0, jump to code that @@ -2801,7 +2573,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // ---------------------------------- iteration 0 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, 5*8)) #endif @@ -2827,8 +2601,10 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // ---------------------------------- iteration 1 -#if 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), xmm0) @@ -2853,7 +2629,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // ---------------------------------- iteration 2 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif @@ -2879,9 +2657,11 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // ---------------------------------- iteration 3 -#if 1 +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else prefetch(0, mem(rdx, rcx, 1, 5*8)) - lea(mem(rdx, r9, 4), rdx) // rdx += 4*cs_a; + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif vmovupd(mem(rbx, 0*32), xmm0) @@ -2925,7 +2705,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m #if 1 prefetch(0, mem(rdx, 5*8)) - add(r9, rdx) // rdx += cs_a; + add(r9, rdx) #endif vmovupd(mem(rbx, 0*32), xmm0) @@ -3001,33 +2781,33 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DROWSTORED) - vfmadd231pd(mem(rcx), xmm3, xmm4) - vmovupd(xmm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), xmm3, xmm6) - vmovupd(xmm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), xmm3, xmm8) - vmovupd(xmm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), xmm3, xmm10) - vmovupd(xmm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), xmm3, xmm12) - vmovupd(xmm12, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), xmm3, xmm14) - vmovupd(xmm14, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) + vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -3037,7 +2817,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) vunpcklpd(xmm10, xmm8, xmm2) @@ -3047,9 +2827,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) //lea(mem(rcx, rsi, 4), rcx) @@ -3057,9 +2837,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vunpcklpd(xmm14, xmm12, xmm0) vunpckhpd(xmm14, xmm12, xmm1) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) //lea(mem(rdx, rsi, 4), rdx) @@ -3081,24 +2861,27 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DROWSTORBZ) - vmovupd(xmm4, mem(rcx)) + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(xmm6, mem(rcx)) + + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(xmm8, mem(rcx)) + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(xmm10, mem(rcx)) + + vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(xmm12, mem(rcx)) + vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - vmovupd(xmm14, 
mem(rcx)) + + vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -3108,7 +2891,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) vunpcklpd(xmm10, xmm8, xmm2) @@ -3116,7 +2899,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vinsertf128(imm(0x1), xmm2, ymm0, ymm4) vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) //lea(mem(rcx, rsi, 4), rcx) @@ -3124,7 +2907,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vunpcklpd(xmm14, xmm12, xmm0) vunpckhpd(xmm14, xmm12, xmm1) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) //lea(mem(rdx, rsi, 4), rdx) @@ -3176,7 +2959,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -3246,7 +3029,6 @@ void bli_dgemmsup_rv_haswell_asm_6x2m } #endif -#if 1 dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -3267,67 +3049,6 @@ void bli_dgemmsup_rv_haswell_asm_6x2m ); return; -#else - if ( 5 <= m_left ) - { - const dim_t mr_cur = 5; - - bli_dgemmsup_rv_haswell_asm_5x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 4 <= m_left ) - { - const dim_t mr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_4x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 3 <= m_left ) - { - const dim_t mr_cur = 3; - - bli_dgemmsup_rv_haswell_asm_3x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 2 <= m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_2x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { - const dim_t mr_cur = 1; - - bli_dgemmsup_rv_haswell_asm_1x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - } -#endif } } diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c index 0740b7b26..e634fd053 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c @@ -80,23 +80,6 @@ // Prototype reference microkernels. GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) -#if 0 -// Define parameters and variables for edge case kernel map. 
-#define NUM_MR 4 -#define NUM_NR 4 -#define FUNCPTR_T dgemmsup_ker_ft - -static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; -static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; -static FUNCPTR_T kmap[NUM_MR][NUM_NR] = -{ /* 8 4 2 1 */ -/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8n, bli_dgemmsup_rv_haswell_asm_6x4n, bli_dgemmsup_rv_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref_6x1 }, -/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8n, bli_dgemmsup_rv_haswell_asm_4x4n, bli_dgemmsup_rv_haswell_asm_4x2n, bli_dgemmsup_r_haswell_ref_4x1 }, -/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8n, bli_dgemmsup_rv_haswell_asm_2x4n, bli_dgemmsup_rv_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref_2x1 }, -/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8n, bli_dgemmsup_rv_haswell_asm_1x4n, bli_dgemmsup_rv_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref_1x1 }, -}; -#endif - void bli_dgemmsup_rv_haswell_asm_6x8n ( @@ -116,16 +99,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8n { uint64_t m_left = m0 % 6; -#if 0 - bli_dgemmsup_r_haswell_ref - ( - conja, conjb, m0, n0, k0, - alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, - beta, c, rs_c0, cs_c0, data, cntx - ); return; -#endif - -//printf( "rv_6x8n: %d %d %d\n", (int)m0, (int)n0, (int)k0 ); // First check whether this is a edge case in the m dimension. If so, // dispatch other ?x8m kernels, as needed. if ( m_left ) @@ -181,7 +154,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8n } #endif -#if 1 dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -202,77 +174,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8n ); return; -#else - if ( 5 <= m_left ) - { - const dim_t mr_cur = 5; - - bli_dgemmsup_rv_haswell_asm_5x8n - ( - conja, conjb, mr_cur, n0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 4 <= m_left ) - { - const dim_t mr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_4x8n - ( - conja, conjb, mr_cur, n0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 3 <= m_left ) - { - const dim_t mr_cur = 3; - - bli_dgemmsup_rv_haswell_asm_3x8n - ( - conja, conjb, mr_cur, n0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 2 <= m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_2x8n - ( - conja, conjb, mr_cur, n0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { -#if 1 - const dim_t mr_cur = 1; - - bli_dgemmsup_rv_haswell_asm_1x8n - ( - conja, conjb, mr_cur, n0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); -#else - bli_dgemv_ex - ( - BLIS_TRANSPOSE, conja, k0, n0, - alpha, bj, rs_b0, cs_b0, ai, cs_a0, - beta, cij, cs_c0, cntx, NULL - ); -#endif - } - return; -#endif } //void* a_next = bli_auxinfo_next_a( data ); @@ -375,10 +276,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8n lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c 
prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c @@ -389,10 +290,10 @@ void bli_dgemmsup_rv_haswell_asm_6x8n lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -402,9 +303,15 @@ void bli_dgemmsup_rv_haswell_asm_6x8n label(.DPOSTPFETCH) // done prefetching c #if 1 - // use byte offsets from rbx to - // prefetch lines from next upanel - // of b. + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; #endif @@ -422,12 +329,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8n // ---------------------------------- iteration 0 #if 1 - //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b - prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -457,9 +363,8 @@ void bli_dgemmsup_rv_haswell_asm_6x8n // ---------------------------------- iteration 1 -#if 1 - //prefetch(0, mem(rdx, r10, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, r10, 1, 5*8)) #endif @@ -493,13 +398,12 @@ void bli_dgemmsup_rv_haswell_asm_6x8n // ---------------------------------- iteration 2 -#if 1 - //prefetch(0, mem(rdx, r10, 2, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, r10, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -529,13 +433,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8n // ---------------------------------- iteration 3 -#if 1 - //prefetch(0, mem(rdx, rcx, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) - //prefetch(0, mem(rdx, r9, 1, 7*8)) - //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -585,7 +487,8 @@ void bli_dgemmsup_rv_haswell_asm_6x8n label(.DLOOPKLEFT) // EDGE LOOP #if 1 - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -674,51 +577,51 @@ void bli_dgemmsup_rv_haswell_asm_6x8n label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - 
vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) - vmovupd(ymm13, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm15) - vmovupd(ymm15, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) @@ -728,7 +631,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -740,11 +643,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -756,18 +659,18 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -779,11 +682,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx ), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -795,11 +698,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx), xmm3, xmm0) + vfmadd231pd(mem(rdx ), xmm3, xmm0) vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, 
mem(rdx, rax, 1)) @@ -823,33 +726,33 @@ void bli_dgemmsup_rv_haswell_asm_6x8n label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx, rsi, 4)) + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm14, mem(rcx)) - vmovupd(ymm15, mem(rcx, rsi, 4)) + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) @@ -859,7 +762,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -869,7 +772,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) @@ -881,14 +784,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -898,7 +801,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -910,7 +813,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rdx)) + vmovupd(xmm0, mem(rdx )) vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -929,7 +832,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b mov(var(ps_b8), rbx) // load ps_b8 - lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 dec(r11) // jj -= 1; jne(.DLOOP6X8J) // iterate again if jj != 0. 
@@ -962,7 +865,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1011,7 +914,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8n } if ( 1 == n_left ) { -#if 1 + #if 1 const dim_t nr_cur = 1; bli_dgemmsup_r_haswell_ref_6x1 @@ -1020,14 +923,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8n alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_NO_TRANSPOSE, conjb, m0, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, beta, cij, rs_c0, cntx, NULL ); -#endif + #endif } } } @@ -1132,8 +1035,6 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vxorpd(ymm11, ymm11, ymm11) vxorpd(ymm12, ymm12, ymm12) vxorpd(ymm13, ymm13, ymm13) - //vxorpd(ymm14, ymm14, ymm14) - //vxorpd(ymm15, ymm15, ymm15) #endif mov(var(a), rax) // load address of a. @@ -1148,10 +1049,10 @@ void bli_dgemmsup_rv_haswell_asm_5x8n lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c jmp(.DPOSTPFETCH) // jump to end of prefetching c @@ -1161,10 +1062,10 @@ void bli_dgemmsup_rv_haswell_asm_5x8n lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 4*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 4*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 4*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -1174,20 +1075,15 @@ void bli_dgemmsup_rv_haswell_asm_5x8n label(.DPOSTPFETCH) // done prefetching c #if 1 - - // use byte offsets from rbx to - // prefetch lines from next upanel - // of b. -#else + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; - lea(mem(rbx, r10, 8), rdx) // use rdx for prefetching b. 
- lea(mem(rdx, r10, 8), rdx) // rdx = b + 16*rs_b; - - #if 0 - mov(r9, rsi) // rsi = rs_b; - sal(imm(5), rsi) // rsi = 16*rs_b; - lea(mem(rax, rsi, 1), rdx) // rdx = b + 16*rs_b; - #endif #endif @@ -1204,9 +1100,8 @@ void bli_dgemmsup_rv_haswell_asm_5x8n // ---------------------------------- iteration 0 -#if 1 - //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b - prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, 5*8)) #endif @@ -1237,9 +1132,8 @@ void bli_dgemmsup_rv_haswell_asm_5x8n // ---------------------------------- iteration 1 -#if 1 - //prefetch(0, mem(rdx, r10, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, r10, 1, 5*8)) #endif @@ -1270,9 +1164,8 @@ void bli_dgemmsup_rv_haswell_asm_5x8n // ---------------------------------- iteration 2 -#if 1 - //prefetch(0, mem(rdx, r10, 2, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, r10, 2, 5*8)) #endif @@ -1303,13 +1196,11 @@ void bli_dgemmsup_rv_haswell_asm_5x8n // ---------------------------------- iteration 3 -#if 1 - //prefetch(0, mem(rdx, rcx, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) - //prefetch(0, mem(rdx, r9, 1, 7*8)) - //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; +#if 0 + prefetch(0, mem(rdx, 5*8)) #else prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -1356,7 +1247,8 @@ void bli_dgemmsup_rv_haswell_asm_5x8n label(.DLOOPKLEFT) // EDGE LOOP #if 1 - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -1440,43 +1332,43 @@ void bli_dgemmsup_rv_haswell_asm_5x8n label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm13) - vmovupd(ymm13, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) //add(rdi, rcx) @@ -1486,7 +1378,7 @@ void bli_dgemmsup_rv_haswell_asm_5x8n label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) 
@@ -1498,33 +1390,18 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) -#if 0 - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(rdx), xmm3, xmm0) - vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) - vmovupd(xmm1, mem(rdx, rsi, 1)) - vmovupd(xmm2, mem(rdx, rsi, 2)) - vmovupd(xmm4, mem(rdx, rax, 1)) -#else - vmovlpd(mem(rdx), xmm0, xmm0) + vmovlpd(mem(rdx ), xmm0, xmm0) vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) @@ -1532,15 +1409,14 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vfmadd213pd(ymm12, ymm3, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rdx)) + vmovlpd(xmm0, mem(rdx )) vmovhpd(xmm0, mem(rdx, rsi, 1)) vmovlpd(xmm1, mem(rdx, rsi, 2)) vmovhpd(xmm1, mem(rdx, rax, 1)) -#endif lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1552,33 +1428,18 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx ), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) -#if 0 - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(rdx), xmm3, xmm0) - vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rdx)) - vmovupd(xmm1, mem(rdx, rsi, 1)) - vmovupd(xmm2, mem(rdx, rsi, 2)) - vmovupd(xmm4, mem(rdx, rax, 1)) -#else - vmovlpd(mem(rdx), xmm0, xmm0) + vmovlpd(mem(rdx ), xmm0, xmm0) vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) @@ -1586,11 +1447,10 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vfmadd213pd(ymm13, ymm3, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rdx)) + vmovlpd(xmm0, mem(rdx )) vmovhpd(xmm0, mem(rdx, rsi, 1)) vmovlpd(xmm1, mem(rdx, rsi, 2)) vmovhpd(xmm1, mem(rdx, rax, 1)) -#endif //lea(mem(rdx, rsi, 4), rdx) @@ -1611,28 +1471,28 @@ void bli_dgemmsup_rv_haswell_asm_5x8n label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 
1*32)) add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx, rsi, 4)) + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) //add(rdi, rcx) @@ -1642,7 +1502,7 @@ void bli_dgemmsup_rv_haswell_asm_5x8n label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1652,36 +1512,24 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) -#if 0 - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(rdx)) - vmovupd(xmm1, mem(rdx, rsi, 1)) - vmovupd(xmm2, mem(rdx, rsi, 2)) - vmovupd(xmm4, mem(rdx, rax, 1)) -#else vmovupd(ymm12, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rdx)) + vmovlpd(xmm0, mem(rdx )) vmovhpd(xmm0, mem(rdx, rsi, 1)) vmovlpd(xmm1, mem(rdx, rsi, 2)) vmovhpd(xmm1, mem(rdx, rax, 1)) -#endif lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1691,32 +1539,20 @@ void bli_dgemmsup_rv_haswell_asm_5x8n vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) -#if 0 - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(rdx)) - vmovupd(xmm1, mem(rdx, rsi, 1)) - vmovupd(xmm2, mem(rdx, rsi, 2)) - vmovupd(xmm4, mem(rdx, rax, 1)) -#else vmovupd(ymm13, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rdx)) + vmovlpd(xmm0, mem(rdx )) vmovhpd(xmm0, mem(rdx, rsi, 1)) vmovlpd(xmm1, mem(rdx, rsi, 2)) vmovhpd(xmm1, mem(rdx, rax, 1)) -#endif //lea(mem(rdx, rsi, 4), rdx) @@ -1732,7 +1568,7 @@ void bli_dgemmsup_rv_haswell_asm_5x8n //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b mov(var(ps_b8), rbx) // load ps_b8 - lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 dec(r11) // jj -= 1; jne(.DLOOP6X8J) // iterate again if jj != 0. 
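The .DROWSTORED and .DROWSTORBZ hunks above switch the second half of each row of C from the scaled index mem(rcx, rsi, 4) to the fixed byte offset mem(rcx, 1*32). A minimal C sketch of why the two are equivalent on this path, assuming (as the row-stored branch does) unit column stride for C; the pointer names below are illustrative, not from the kernel:

    // Sketch only; c_row and i are illustrative. On the row-stored path
    // cs_c == 1, so columns 4..7 of a row start 4 doubles (32 bytes) after
    // columns 0..3, and a fixed byte offset can replace the scaled index.
    double* c_row     = c + i*rs_c;
    double* c_cols0_3 = c_row;          // mem(rcx, 0*32)
    double* c_cols4_7 = c_row + 4;      // mem(rcx, 1*32), i.e. +32 bytes
    // Previously the same address was formed as c_row + 4*cs_c, i.e. mem(rcx, rsi, 4).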
@@ -1765,7 +1601,7 @@ void bli_dgemmsup_rv_haswell_asm_5x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1814,7 +1650,7 @@ void bli_dgemmsup_rv_haswell_asm_5x8n } if ( 1 == n_left ) { -#if 1 + #if 1 const dim_t nr_cur = 1; bli_dgemmsup_r_haswell_ref_5x1 @@ -1823,14 +1659,14 @@ void bli_dgemmsup_rv_haswell_asm_5x8n alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, beta, cij, rs_c0, cs_c0, data, cntx ); -#else + #else bli_dgemv_ex ( BLIS_NO_TRANSPOSE, conjb, m0, k0, alpha, ai, rs_a0, cs_a0, bj, rs_b0, beta, cij, rs_c0, cntx, NULL ); -#endif + #endif } } } @@ -1947,10 +1783,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8n lea(mem(r12, rdi, 2), rdx) // lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c jmp(.DPOSTPFETCH) // jump to end of prefetching c label(.DCOLPFETCH) // column-stored prefetching c @@ -1959,10 +1795,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8n lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 3*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -1972,11 +1808,15 @@ void bli_dgemmsup_rv_haswell_asm_4x8n label(.DPOSTPFETCH) // done prefetching c #if 1 - //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; - - // use byte offsets from rbx to - // prefetch lines from next upanel - // of b. + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. 
+ lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; #endif @@ -1993,9 +1833,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8n // ---------------------------------- iteration 0 -#if 1 - //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b - prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2020,9 +1861,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8n // ---------------------------------- iteration 1 -#if 1 - //prefetch(0, mem(rdx, r10, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2047,9 +1889,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8n // ---------------------------------- iteration 2 -#if 1 - //prefetch(0, mem(rdx, r10, 2, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2075,9 +1918,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8n // ---------------------------------- iteration 3 #if 1 - //prefetch(0, mem(rdx, rcx, 1, 11*8)) - //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2120,7 +1964,8 @@ void bli_dgemmsup_rv_haswell_asm_4x8n label(.DLOOPKLEFT) // EDGE LOOP #if 1 - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2198,35 +2043,35 @@ void bli_dgemmsup_rv_haswell_asm_4x8n label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) //add(rdi, rcx) @@ -2236,7 +2081,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8n label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -2248,18 +2093,18 @@ void bli_dgemmsup_rv_haswell_asm_4x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx ), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, 
mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -2271,11 +2116,11 @@ void bli_dgemmsup_rv_haswell_asm_4x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx ), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -2299,20 +2144,20 @@ void bli_dgemmsup_rv_haswell_asm_4x8n label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(ymm11, mem(rcx, rsi, 4)) + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) //add(rdi, rcx) @@ -2322,7 +2167,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8n label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -2332,14 +2177,14 @@ void bli_dgemmsup_rv_haswell_asm_4x8n vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) + vmovupd(ymm4, mem(rcx )) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -2349,7 +2194,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8n vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - vmovupd(ymm5, mem(rcx)) + vmovupd(ymm5, mem(rcx )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -2368,7 +2213,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8n //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b mov(var(ps_b8), rbx) // load ps_b8 - lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 dec(r11) // jj -= 1; jne(.DLOOP4X8J) // iterate again if jj != 0. 
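In the .DPOSTPFETCH block above, rdx is initialized to b + ps_b8 (the next packed upanel of B) and rcx to 3*rs_b; the k loops then prefetch lines of that next upanel at small multiples of rs_b, and the edge loop advances the pointer by rs_b each iteration. A rough C analogue of the edge-loop addressing, with illustrative names and byte-valued strides:

    // Rough sketch; b_pf and n_k are illustrative. ps_b8 and rs_b are byte
    // strides here, matching how the kernel scales them by sizeof(double).
    const char* b_pf = (const char*)b + ps_b8;   // rdx = b + ps_b8 (next upanel of B)
    for ( dim_t l = 0; l < n_k; ++l )
    {
        __builtin_prefetch( b_pf + 5*8 );        // prefetch(0, mem(rdx, 5*8))
        b_pf += rs_b;                            // add(r10, rdx): b_prefetch += rs_b
    }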
@@ -2401,7 +2246,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2578,7 +2423,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8n //lea(mem(r12, rdi, 2), rdx) // //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c @@ -2589,10 +2434,10 @@ void bli_dgemmsup_rv_haswell_asm_3x8n lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 2*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 2*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -2602,11 +2447,15 @@ void bli_dgemmsup_rv_haswell_asm_3x8n label(.DPOSTPFETCH) // done prefetching c #if 1 - //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; - - // use byte offsets from rbx to - // prefetch lines from next upanel - // of b. + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. 
+ lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; #endif @@ -2623,9 +2472,10 @@ void bli_dgemmsup_rv_haswell_asm_3x8n // ---------------------------------- iteration 0 -#if 1 - //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b - prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2647,9 +2497,10 @@ void bli_dgemmsup_rv_haswell_asm_3x8n // ---------------------------------- iteration 1 -#if 1 - //prefetch(0, mem(rdx, r10, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2671,9 +2522,10 @@ void bli_dgemmsup_rv_haswell_asm_3x8n // ---------------------------------- iteration 2 -#if 1 - //prefetch(0, mem(rdx, r10, 2, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2695,10 +2547,11 @@ void bli_dgemmsup_rv_haswell_asm_3x8n // ---------------------------------- iteration 3 -#if 1 - //prefetch(0, mem(rdx, rcx, 1, 11*8)) - //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2738,7 +2591,8 @@ void bli_dgemmsup_rv_haswell_asm_3x8n label(.DLOOPKLEFT) // EDGE LOOP #if 1 - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -2811,27 +2665,27 @@ void bli_dgemmsup_rv_haswell_asm_3x8n label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) //add(rdi, rcx) @@ -2841,7 +2695,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8n label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -2858,30 +2712,29 @@ void bli_dgemmsup_rv_haswell_asm_3x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx ), xmm3, xmm4) vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) - vmovupd(xmm4, mem(rcx)) + vmovupd(xmm4, mem(rcx )) vmovupd(xmm6, mem(rcx, rsi, 1)) vmovupd(xmm8, mem(rcx, rsi, 2)) vmovupd(xmm10, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx ), xmm3, xmm12) vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) - vmovsd(xmm12, 
mem(rdx)) + vmovsd(xmm12, mem(rdx )) vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -2898,22 +2751,22 @@ void bli_dgemmsup_rv_haswell_asm_3x8n vbroadcastsd(mem(rbx), ymm3) - vfmadd231pd(mem(rcx), xmm3, xmm5) + vfmadd231pd(mem(rcx ), xmm3, xmm5) vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) - vmovupd(xmm5, mem(rcx)) + vmovupd(xmm5, mem(rcx )) vmovupd(xmm7, mem(rcx, rsi, 1)) vmovupd(xmm9, mem(rcx, rsi, 2)) vmovupd(xmm11, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) - vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx ), xmm3, xmm12) vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) - vmovsd(xmm12, mem(rdx)) + vmovsd(xmm12, mem(rdx )) vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) @@ -2937,16 +2790,16 @@ void bli_dgemmsup_rv_haswell_asm_3x8n label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 4)) + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) //add(rdi, rcx) @@ -2956,7 +2809,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8n label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -2971,21 +2824,21 @@ void bli_dgemmsup_rv_haswell_asm_3x8n vextractf128(imm(0x1), ymm8, xmm14) vextractf128(imm(0x1), ymm10, xmm15) - vmovupd(xmm4, mem(rcx)) + vmovupd(xmm4, mem(rcx )) vmovupd(xmm6, mem(rcx, rsi, 1)) vmovupd(xmm8, mem(rcx, rsi, 2)) vmovupd(xmm10, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - vmovsd(xmm12, mem(rdx)) + vmovsd(xmm12, mem(rdx )) vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) lea(mem(rdx, rsi, 4), rdx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -3000,14 +2853,14 @@ void bli_dgemmsup_rv_haswell_asm_3x8n vextractf128(imm(0x1), ymm9, xmm14) vextractf128(imm(0x1), ymm11, xmm15) - vmovupd(xmm5, mem(rcx)) + vmovupd(xmm5, mem(rcx )) vmovupd(xmm7, mem(rcx, rsi, 1)) vmovupd(xmm9, mem(rcx, rsi, 2)) vmovupd(xmm11, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) - vmovsd(xmm12, mem(rdx)) + vmovsd(xmm12, mem(rdx )) vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) @@ -3026,7 +2879,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8n //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b mov(var(ps_b8), rbx) // load ps_b8 - lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 dec(r11) // jj -= 1; jne(.DLOOP4X8J) // iterate again if jj != 0. 
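The .DCOLSTORED and .DCOLSTORBZ paths now carry "begin I/O on columns 0-3" and "begin I/O on columns 4-7" markers around the vunpcklpd/vunpckhpd/vperm2f128 sequences that transpose four row-oriented accumulators so 4x4 blocks of the microtile can be read from and written to column-stored C. A hedged intrinsics sketch of that 4x4 double transpose (the register assignment here is illustrative, not the kernel's exact one):

    #include <immintrin.h>

    // Transpose four ymm registers r0..r3 (each one row of 4 doubles) into
    // c0..c3 (each one column), mirroring the vunpck*pd/vperm2f128 pattern.
    static inline void dtranspose_4x4
         (
           __m256d r0, __m256d r1, __m256d r2, __m256d r3,
           __m256d* c0, __m256d* c1, __m256d* c2, __m256d* c3
         )
    {
        __m256d t0 = _mm256_unpacklo_pd( r0, r1 );    // r00 r10 r02 r12
        __m256d t1 = _mm256_unpackhi_pd( r0, r1 );    // r01 r11 r03 r13
        __m256d t2 = _mm256_unpacklo_pd( r2, r3 );    // r20 r30 r22 r32
        __m256d t3 = _mm256_unpackhi_pd( r2, r3 );    // r21 r31 r23 r33

        *c0 = _mm256_permute2f128_pd( t0, t2, 0x20 ); // column 0: r00 r10 r20 r30
        *c1 = _mm256_permute2f128_pd( t1, t3, 0x20 ); // column 1: r01 r11 r21 r31
        *c2 = _mm256_permute2f128_pd( t0, t2, 0x31 ); // column 2: r02 r12 r22 r32
        *c3 = _mm256_permute2f128_pd( t1, t3, 0x31 ); // column 3: r03 r13 r23 r33
    }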
@@ -3059,7 +2912,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -3228,7 +3081,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8n //lea(mem(r12, rdi, 2), rdx) // //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c jmp(.DPOSTPFETCH) // jump to end of prefetching c @@ -3238,10 +3091,10 @@ void bli_dgemmsup_rv_haswell_asm_2x8n lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 1*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 1*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 1*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -3251,11 +3104,15 @@ void bli_dgemmsup_rv_haswell_asm_2x8n label(.DPOSTPFETCH) // done prefetching c #if 1 - //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; - - // use byte offsets from rbx to - // prefetch lines from next upanel - // of b. + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. 
+ lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; #endif @@ -3272,9 +3129,10 @@ void bli_dgemmsup_rv_haswell_asm_2x8n // ---------------------------------- iteration 0 -#if 1 - //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b - prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3292,9 +3150,10 @@ void bli_dgemmsup_rv_haswell_asm_2x8n // ---------------------------------- iteration 1 -#if 1 - //prefetch(0, mem(rdx, r10, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3312,9 +3171,10 @@ void bli_dgemmsup_rv_haswell_asm_2x8n // ---------------------------------- iteration 2 -#if 1 - //prefetch(0, mem(rdx, r10, 2, 11*8)) - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3332,10 +3192,11 @@ void bli_dgemmsup_rv_haswell_asm_2x8n // ---------------------------------- iteration 3 -#if 1 - //prefetch(0, mem(rdx, rcx, 1, 11*8)) - //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; - prefetch(0, mem(rbx, 11*8)) +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3371,7 +3232,8 @@ void bli_dgemmsup_rv_haswell_asm_2x8n label(.DLOOPKLEFT) // EDGE LOOP #if 1 - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3438,19 +3300,19 @@ void bli_dgemmsup_rv_haswell_asm_2x8n label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) //add(rdi, rcx) @@ -3460,34 +3322,34 @@ void bli_dgemmsup_rv_haswell_asm_2x8n label(.DCOLSTORED) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx ), xmm3, xmm0) vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rcx)) + vmovupd(xmm0, mem(rcx )) vmovupd(xmm1, mem(rcx, rsi, 1)) vmovupd(xmm2, mem(rcx, rsi, 2)) vmovupd(xmm4, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rcx), xmm3, xmm0) + vfmadd231pd(mem(rcx ), xmm3, xmm0) vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) - vmovupd(xmm0, mem(rcx)) + vmovupd(xmm0, mem(rcx )) vmovupd(xmm1, mem(rcx, rsi, 1)) vmovupd(xmm2, mem(rcx, rsi, 2)) vmovupd(xmm4, mem(rcx, rax, 1)) @@ -3511,12 +3373,12 @@ void bli_dgemmsup_rv_haswell_asm_2x8n 
label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 4)) + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) //add(rdi, rcx) @@ -3526,26 +3388,26 @@ void bli_dgemmsup_rv_haswell_asm_2x8n label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rcx)) + vmovupd(xmm0, mem(rcx )) vmovupd(xmm1, mem(rcx, rsi, 1)) vmovupd(xmm2, mem(rcx, rsi, 2)) vmovupd(xmm4, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - + // begin I/O on columns 4-7 vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(rcx)) + vmovupd(xmm0, mem(rcx )) vmovupd(xmm1, mem(rcx, rsi, 1)) vmovupd(xmm2, mem(rcx, rsi, 2)) vmovupd(xmm4, mem(rcx, rax, 1)) @@ -3564,7 +3426,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8n //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b mov(var(ps_b8), rbx) // load ps_b8 - lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 dec(r11) // jj -= 1; jne(.DLOOP2X8J) // iterate again if jj != 0. @@ -3597,7 +3459,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -3764,7 +3626,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8n //lea(mem(r12, rdi, 2), rdx) // //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c jmp(.DPOSTPFETCH) // jump to end of prefetching c label(.DCOLPFETCH) // column-stored prefetching c @@ -3773,10 +3635,10 @@ void bli_dgemmsup_rv_haswell_asm_1x8n lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) lea(mem(r12, rsi, 2), rdx) // lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, 0*8)) // prefetch c + 0*cs_c prefetch(0, mem(r12, rsi, 1, 0*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, rsi, 2, 0*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; @@ -3786,11 +3648,15 @@ void bli_dgemmsup_rv_haswell_asm_1x8n label(.DPOSTPFETCH) // done prefetching c #if 1 - //lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; - - // use byte offsets from rbx to - // prefetch lines from next upanel - // of b. + mov(var(ps_b8), rdx) // load ps_b8 + lea(mem(rbx, rdx, 1), rdx) // rdx = b + ps_b8 + lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; + // use rcx, rdx for prefetching lines + // from next upanel of b. +#else + lea(mem(rbx, r8, 8), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 8), rdx) // from next upanel of b. 
+ lea(mem(r10, r10, 2), rcx) // rcx = 3*rs_b; #endif @@ -3808,8 +3674,9 @@ void bli_dgemmsup_rv_haswell_asm_1x8n // ---------------------------------- iteration 0 #if 1 - //prefetch(0, mem(rdx, 11*8)) // prefetch line of next upanel of b - prefetch(0, mem(rbx, 11*8)) // prefetch line of next upanel of b + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3825,8 +3692,9 @@ void bli_dgemmsup_rv_haswell_asm_1x8n // ---------------------------------- iteration 1 #if 1 - //prefetch(0, mem(rdx, r10, 1, 11*8)) - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 1, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3842,8 +3710,9 @@ void bli_dgemmsup_rv_haswell_asm_1x8n // ---------------------------------- iteration 2 #if 1 - //prefetch(0, mem(rdx, r10, 2, 11*8)) - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r10, 2, 5*8)) #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3859,9 +3728,10 @@ void bli_dgemmsup_rv_haswell_asm_1x8n // ---------------------------------- iteration 3 #if 1 - //prefetch(0, mem(rdx, rcx, 1, 11*8)) - //lea(mem(rdx, r10, 4), rdx) // a_prefetch += 4*cs_a; - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r10, 4), rdx) // b_prefetch += 4*rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3894,7 +3764,8 @@ void bli_dgemmsup_rv_haswell_asm_1x8n label(.DLOOPKLEFT) // EDGE LOOP #if 1 - prefetch(0, mem(rbx, 11*8)) + prefetch(0, mem(rdx, 5*8)) + add(r10, rdx) // b_prefetch += rs_b; #endif vmovupd(mem(rbx, 0*32), ymm0) @@ -3956,11 +3827,11 @@ void bli_dgemmsup_rv_haswell_asm_1x8n label(.DROWSTORED) - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, rsi, 4), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) //add(rdi, rcx) @@ -3970,8 +3841,8 @@ void bli_dgemmsup_rv_haswell_asm_1x8n label(.DCOLSTORED) - - vmovlpd(mem(rcx), xmm0, xmm0) + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) @@ -3980,15 +3851,15 @@ void bli_dgemmsup_rv_haswell_asm_1x8n vfmadd213pd(ymm4, ymm3, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rcx)) + vmovlpd(xmm0, mem(rcx )) vmovhpd(xmm0, mem(rcx, rsi, 1)) vmovlpd(xmm1, mem(rcx, rsi, 2)) vmovhpd(xmm1, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - - vmovlpd(mem(rcx), xmm0, xmm0) + // begin I/O on columns 4-7 + vmovlpd(mem(rcx ), xmm0, xmm0) vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) @@ -3997,7 +3868,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8n vfmadd213pd(ymm5, ymm3, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rcx)) + vmovlpd(xmm0, mem(rcx )) vmovhpd(xmm0, mem(rcx, rsi, 1)) vmovlpd(xmm1, mem(rcx, rsi, 2)) vmovhpd(xmm1, mem(rcx, rax, 1)) @@ -4021,8 +3892,8 @@ void bli_dgemmsup_rv_haswell_asm_1x8n label(.DROWSTORBZ) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 4)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) //add(rdi, rcx) @@ -4032,22 +3903,22 @@ void bli_dgemmsup_rv_haswell_asm_1x8n label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vmovupd(ymm4, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rcx)) + vmovlpd(xmm0, mem(rcx )) vmovhpd(xmm0, 
mem(rcx, rsi, 1)) vmovlpd(xmm1, mem(rcx, rsi, 2)) vmovhpd(xmm1, mem(rcx, rax, 1)) lea(mem(rcx, rsi, 4), rcx) - + // begin I/O on columns 4-7 vmovupd(ymm5, ymm0) vextractf128(imm(1), ymm0, xmm1) - vmovlpd(xmm0, mem(rcx)) + vmovlpd(xmm0, mem(rcx )) vmovhpd(xmm0, mem(rcx, rsi, 1)) vmovlpd(xmm1, mem(rcx, rsi, 2)) vmovhpd(xmm1, mem(rcx, rax, 1)) @@ -4066,7 +3937,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8n //add(imm(8*8), r14) // b_jj = r14 += 8*cs_b mov(var(ps_b8), rbx) // load ps_b8 - lea(mem(r14, rbx, 1), r14) // a_ii = r14 += ps_b8 + lea(mem(r14, rbx, 1), r14) // b_jj = r14 += ps_b8 dec(r11) // jj -= 1; jne(.DLOOP1X8J) // iterate again if jj != 0. @@ -4099,7 +3970,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8n [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c new file mode 100644 index 000000000..69d543a99 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c @@ -0,0 +1,158 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. 
+ Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel. +// However, at least one other subconfiguration (zen) uses this kernel set, so +// we need to be able to call a set of "?x1" kernels that we know will actually +// exist regardless of which subconfiguration these kernels were used by. Thus, +// the compromise employed here is to inline the reference kernel so it gets +// compiled as part of the haswell kernel set, and hence can unconditionally be +// called by other kernels within that kernel set. + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, mdim ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + for ( dim_t i = 0; i < mdim; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + /* for ( dim_t j = 0; j < 1; ++j ) */ \ + { \ + ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \ + ctype* restrict bj = b /*[ j*cs_b ]*/ ; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(d,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ +} + +GENTFUNC( double, d, gemmsup_r_haswell_ref_6x1, 6 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_5x1, 5 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_4x1, 4 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_3x1, 3 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_2x1, 2 ) +GENTFUNC( double, d, gemmsup_r_haswell_ref_1x1, 1 ) + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c new file mode 100644 index 000000000..6e3c1a0e8 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c @@ -0,0 +1,1698 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
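    /* Worked example of the loop partitioning set up above (numbers are
       illustrative):
         k0 = 23
         k_iter16 = 23 / 16 = 1  -> one pass of the 16-update main loop
                                    (four unrolled iterations of 4 updates each)
         k_left16 = 23 % 16 = 7
         k_iter4  =  7 /  4 = 1  -> one 4-update ymm edge iteration
         k_left1  =  7 %  4 = 3  -> three scalar updates
       Total: 1*16 + 1*4 + 3 = 23 = k0. */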
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
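    /* Hedged scalar sketch of what the k loops above and below accumulate for
       one row i of the microtile (n0 == 1 here). The acc[] array plays the
       role of the four lanes of a ymm accumulator; the lanes are combined
       only at .DPOSTACCUM via vhaddpd/vextractf128/vaddpd. All names are
       illustrative, not part of the kernel. */
    static double ddot_row_sketch
         (
           dim_t k0,
           const double* restrict a_i, inc_t cs_a,
           const double* restrict b,   inc_t rs_b
         )
    {
        double acc[4] = { 0.0, 0.0, 0.0, 0.0 };
        dim_t  l = 0;

        for ( ; l + 4 <= k0; l += 4 )             // vectorized (16x/4x) iterations
            for ( dim_t p = 0; p < 4; ++p )
                acc[p] += a_i[ (l+p)*cs_a ] * b[ (l+p)*rs_b ];

        for ( ; l < k0; ++l )                     // scalar leftover (k_left1 loop)
            acc[0] += a_i[ l*cs_a ] * b[ l*rs_b ];

        return acc[0] + acc[1] + acc[2] + acc[3]; // horizontal reduction
    }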
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + // ymm6 + // ymm8 + // ymm10 + // ymm12 + // ymm14 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + // xmm8[0] = sum(ymm8) + // xmm10[0] = sum(ymm10) + // xmm12[0] = sum(ymm12) + // xmm14[0] = sum(ymm14) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm10) + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm12) + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm14) + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm8, ymm8, ymm8) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + // ymm6 + // ymm8 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + // xmm8[0] = sum(ymm8) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm8) + vmovsd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
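+	// NOTE: Because this kernel updates only one column of C (n = 1), each
+	// row of C in the store code above (and in the beta == 0 case below) is
+	// touched by a single scalar vmovsd, with rdi (rs_c in bytes) advancing
+	// rcx from row to row.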
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm6, ymm6, ymm6) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. 
+ + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + // ymm6 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4[0] = sum(ymm4) + // xmm6[0] = sum(ymm6) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm6) + vmovsd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovsd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. 
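+	// NOTE: Each iteration of the main loop above consumes 16 values of k:
+	// four unrolled steps of four elements each. Any remainder in k is
+	// handled by the k_iter4 and k_left1 loops that follow.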
+ + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + // xmm4[0] = sum(ymm4) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vmovsd(mem(rcx), xmm0) + vfmadd231pd(xmm0, xmm3, xmm4) + vmovsd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
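+	// NOTE: As an illustrative sketch only (not compiled code), and assuming
+	// the unit k strides that the address arithmetic above hard-codes, the
+	// net effect of this 1x1 kernel -- accumulation, alpha scaling, and the
+	// two store paths (beta != 0 above, beta == 0 below) -- corresponds to:
+	//
+	//   double ab = 0.0;
+	//   for ( dim_t l = 0; l < k0; ++l ) ab += a[ l ] * b[ l ];
+	//   if ( *beta == 0.0 ) c[ 0 ] = (*alpha) * ab;
+	//   else                c[ 0 ] = (*alpha) * ab + (*beta) * c[ 0 ];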
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovsd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c new file mode 100644 index 000000000..af498eb0e --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c @@ -0,0 +1,1794 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. 
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(rcx, rdi, 2), r10) // + lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
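+	// NOTE: In the main loop below, each unrolled step loads four k values
+	// from both columns of B (ymm0, ymm1) and from each of the six rows of A
+	// (ymm3, reused), issuing twelve fmas into the accumulators ymm4-ymm15.
+	// Conceptually, per step:
+	//
+	//   ab(ii,jj) += a(ii,l+0)*b(l+0,jj) + ... + a(ii,l+3)*b(l+3,jj)
+	//
+	// for ii = 0..5 and jj = 0..1.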
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), 
rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovupd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovupd(mem(rax, r15, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + vmovsd(mem(rax, r13, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rax, r8, 4), xmm3) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + + vmovsd(mem(rax, r15, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
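+	// NOTE: At this point the k dimension is fully accumulated. Each ymm
+	// accumulator holds four partial sums; .DPOSTACCUM below reduces each
+	// pair (e.g. ymm4/ymm5) with vhaddpd, vextractf128, and vaddpd so that
+	// xmm4[0:1] ends up holding the two finished dot products for that row
+	// of the 6x2 microtile.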
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + // ymm10 ymm11 + // ymm12 ymm13 + // ymm14 ymm15 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + vhaddpd( ymm11, ymm10, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm10 ) + + vhaddpd( ymm13, ymm12, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm12 ) + + vhaddpd( ymm15, ymm14, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm14 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + // xmm10[0:1] = sum(ymm10) sum(ymm11) + // xmm12[0:1] = sum(ymm12) sum(ymm13) + // xmm14[0:1] = sum(ymm14) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm10) + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm12) + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm14) + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
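+	// NOTE: With two columns of C (n = 2) and C stored by rows, each row of
+	// C in the store paths is read/written as one contiguous pair of doubles
+	// via 128-bit xmm accesses (vmovupd), stepping by rdi (rs_c in bytes)
+	// between rows.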
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm10, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm12, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm14, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. 
+ // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovupd(mem(rax, r8, 2), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vmovsd(mem(rax, r8, 2), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm8) + vfmadd231pd(ymm1, ymm3, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + // ymm8 ymm9 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + vhaddpd( ymm9, ymm8, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm8 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + // xmm8[0:1] = sum(ymm8) sum(ymm9) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm8) + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm8, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
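+	// NOTE: The k_iter4 edge loop below is a single, non-unrolled copy of
+	// one main-loop step; it consumes groups of four k values when fewer
+	// than sixteen remain, before the scalar k_left1 loop handles the last
+	// few.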
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rax, r8, 1), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rax, r8, 1), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + // ymm6 ymm7 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + vhaddpd( ymm7, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm6 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + // xmm6[0:1] = sum(ymm6) sum(ymm7) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), xmm3, xmm6) + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + //uint64_t m_iter = m0 / 6; + //uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) +#endif + + + //lea(mem(r12), rcx) // rcx = c_ii; + //lea(mem(r14), rax) // rax = a_ii; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + //lea(mem(rcx, rdi, 2), r10) // + //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rbx ), ymm0) + vmovupd(mem(rbx, r11, 1), ymm1) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + + vmovupd(mem(rax ), ymm3) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rbx ), xmm0) + vmovsd(mem(rbx, r11, 1), xmm1) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + + vmovsd(mem(rax ), xmm3) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
+ + + + + + + + label(.DPOSTACCUM) + + // ymm4 ymm5 + + vhaddpd( ymm5, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm4 ) + + // xmm4[0:1] = sum(ymm4) sum(ymm5) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx)) + //add(rdi, rcx) + + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c new file mode 100644 index 000000000..a3b56cb12 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -0,0 +1,1450 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrc: + -------- ------ | | | | | | | | + -------- ------ | | | | | | | | + -------- += ------ ... | | | | | | | | + -------- ------ | | | | | | | | + -------- ------ : + -------- ------ : + + Assumptions: + - C is row-stored and B is column-stored; + - A is row-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential microkernel is well-suited for + a dot-product-based accumulation that performs vector loads from + both A and B. +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rd_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + //uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. 
+ vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_ii; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), 
ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. 
+ + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // xmm5[0:3] = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + // xmm6[0:3] = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. 
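+#if 0
+	// Illustrative only (never compiled): the .DPOSTACCUM reduction above,
+	// rewritten with AVX intrinsics. Here acc0..acc3 stand for one row's four
+	// per-column ymm accumulators (e.g. ymm4, ymm7, ymm10, ymm13); the result
+	// holds { sum(acc0), sum(acc1), sum(acc2), sum(acc3) }, i.e. the four dot
+	// products of that row of the 3x4 block. Requires <immintrin.h>.
+	static inline __m256d reduce4_sketch
+	( __m256d acc0, __m256d acc1, __m256d acc2, __m256d acc3 )
+	{
+		__m256d t01 = _mm256_hadd_pd( acc0, acc1 );
+		__m128d s01 = _mm_add_pd( _mm256_castpd256_pd128( t01 ),
+		                          _mm256_extractf128_pd( t01, 1 ) );
+		__m256d t23 = _mm256_hadd_pd( acc2, acc3 );
+		__m128d s23 = _mm_add_pd( _mm256_castpd256_pd128( t23 ),
+		                          _mm256_extractf128_pd( t23, 1 ) );
+		return _mm256_insertf128_pd( _mm256_castpd128_pd256( s01 ), s23, 1 );
+	}
+#endif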
+ + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) +#endif + + + //lea(mem(r12), rcx) // rcx = c; + //lea(mem(r14), rax) // rax = a; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + // xmm5[0:3] = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. 
+ + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + //lea(mem(r12), rcx) // rcx = c; + //lea(mem(r14), rax) // rax = a; + //lea(mem(rdx), rbx) // rbx = b; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_b = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c new file mode 100644 index 000000000..571444bed --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c @@ -0,0 +1,1617 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ - Neither the name(s) of the copyright holder(s) nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+#define BLIS_ASM_SYNTAX_ATT
+#include "bli_x86_asm_macros.h"
+
+/*
+ rrc:
+ -------- ------ | | | | | | | |
+ -------- ------ | | | | | | | |
+ -------- += ------ ... | | | | | | | |
+ -------- ------ | | | | | | | |
+ -------- ------ :
+ -------- ------ :
+
+ Assumptions:
+ - C is row-stored and B is column-stored;
+ - A is row-stored;
+ - m0 and n0 are at most MR and NR, respectively.
+ Therefore, this (r)ow-preferential microkernel is well-suited for
+ a dot-product-based accumulation that performs vector loads from
+ both A and B.
+*/
+
+// Prototype reference microkernels.
+GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
+
+
+void bli_dgemmsup_rd_haswell_asm_6x8
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ uint64_t n_left = n0 % 8;
+
+ // First check whether this is an edge case in the n dimension. If so,
+ // dispatch other 6x? kernels, as needed.
+ if ( n_left ) + { + double* restrict cij = c; + double* restrict bj = b; + double* restrict ai = a; + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_dgemmsup_rd_haswell_asm_6x4 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_6x2 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + #if 1 + const dim_t nr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_6x1 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + #else + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + #endif + } + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + //mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + //mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + + + mov(var(m_iter), r9) // ii = m_iter; + + label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. 
+#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm6, ymm6, ymm6) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm8, ymm8, ymm8) + vxorpd(ymm9, ymm9, ymm9) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm11, ymm11, ymm11) + vxorpd(ymm12, ymm12, ymm12) + vxorpd(ymm13, ymm13, ymm13) + vxorpd(ymm14, ymm14, ymm14) + vxorpd(ymm15, ymm15, ymm15) +#endif + + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c +#endif + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) 
+ vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + jne(.DLOOP3X4I) // iterate again if ii != 0. + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. 
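+#if 0
+	// Illustrative only (never compiled): the jj/ii control flow implemented
+	// by the asm above, expressed as C-like pseudocode. The 6x8 tile of C is
+	// computed as m_iter x 2 blocks of size 3x4.
+	for ( dim_t jj = 0; jj < 8; jj += 4 )           // two 4-column panels
+	{
+		double* c_jj = c + jj*cs_c;                 // r12 at top of jj loop
+		double* b_jj = b + jj*cs_b;                 // rdx
+		double* a_ii = a;                           // r14
+		double* c_ii = c_jj;
+
+		for ( dim_t ii = 0; ii < m_iter; ii += 1 )  // 3-row panels
+		{
+			// ... accumulate one 3x4 block of C from a_ii and b_jj,
+			//     as in the k loops and .DPOSTACCUM code above ...
+			c_ii += 3*rs_c;
+			a_ii += 3*rs_a;
+		}
+	}
+#endif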
+ + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + //bli_dgemmsup_r_haswell_ref + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } +} + +void bli_dgemmsup_rd_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. 
+ //mov(var(rs_b), r10) // load rs_b
+ mov(var(cs_b), r11) // load cs_b
+ //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+ //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
+
+
+ mov(var(c), r12) // load address of c
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+
+
+ // r12 = rcx = c
+ // r14 = rax = a
+ // rdx = rbx = b
+ // r9 = unused
+ // r15 = n dim index jj
+
+ mov(imm(0), r15) // jj = 0;
+
+ label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ]
+
+
+#if 0
+ vzeroall() // zero all xmm/ymm registers.
+#else
+ // skylake can execute 3 vxorpd ipc with
+ // a latency of 1 cycle, while vzeroall
+ // has a latency of 12 cycles.
+ vxorpd(ymm4, ymm4, ymm4)
+ vxorpd(ymm5, ymm5, ymm5)
+ vxorpd(ymm7, ymm7, ymm7)
+ vxorpd(ymm8, ymm8, ymm8)
+ vxorpd(ymm10, ymm10, ymm10)
+ vxorpd(ymm11, ymm11, ymm11)
+ vxorpd(ymm13, ymm13, ymm13)
+ vxorpd(ymm14, ymm14, ymm14)
+#endif
+
+
+ lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
+ imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
+ lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c;
+
+ lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
+ imul(r11, rsi) // rsi *= cs_b;
+ lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b;
+
+ lea(mem(r14), rax) // rax = a;
+
+
+#if 1
+ //mov(var(rs_c), rdi) // load rs_c
+ //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+ prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
+#endif
+
+
+
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+
+
+ label(.DLOOPKITER16) // MAIN LOOP
+
+
+ // ---------------------------------- iteration 0
+
+#if 0
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
+ prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a
+#endif
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 1
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 2
+
+#if 0
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
+ prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a
+#endif
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ // ---------------------------------- iteration 3
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16) // iterate again if i != 0.
+
+
+
+
+
+
+ label(.DCONSIDKITER4)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+
+ label(.DLOOPKITER4) // EDGE LOOP (ymm)
+
+#if 0
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
+ prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a
+#endif
+
+ vmovupd(mem(rax ), ymm0)
+ vmovupd(mem(rax, r8, 1), ymm1)
+ add(imm(4*8), rax) // a += 4*cs_a = 4*8;
+
+ vmovupd(mem(rbx ), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovupd(mem(rbx, r11, 1), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovupd(mem(rbx, r11, 2), ymm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovupd(mem(rbx, r13, 1), ymm3)
+ add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4) // iterate again if i != 0.
+
+
+
+
+ label(.DCONSIDKLEFT1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+
+
+
+ label(.DLOOPKLEFT1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ vmovsd(mem(rax ), xmm0)
+ vmovsd(mem(rax, r8, 1), xmm1)
+ add(imm(1*8), rax) // a += 1*cs_a = 1*8;
+
+ vmovsd(mem(rbx ), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm4)
+ vfmadd231pd(ymm1, ymm3, ymm5)
+
+ vmovsd(mem(rbx, r11, 1), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm7)
+ vfmadd231pd(ymm1, ymm3, ymm8)
+
+ vmovsd(mem(rbx, r11, 2), xmm3)
+ vfmadd231pd(ymm0, ymm3, ymm10)
+ vfmadd231pd(ymm1, ymm3, ymm11)
+
+ vmovsd(mem(rbx, r13, 1), xmm3)
+ add(imm(1*8), rbx) // b += 1*rs_b = 1*8;
+ vfmadd231pd(ymm0, ymm3, ymm13)
+ vfmadd231pd(ymm1, ymm3, ymm14)
+
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1) // iterate again if i != 0.
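The .DPOSTACCUM block that follows reduces each row's group of four ymm accumulators (ymm4/ymm7/ymm10/ymm13 for row 0, ymm5/ymm8/ymm11/ymm14 for row 1), each holding four partial sums, to that row's four dot products packed in one register. A minimal C intrinsics sketch of the vhaddpd/vextractf128/vaddpd/vperm2f128 sequence, assuming AVX is available; the helper name is illustrative and not part of this patch or the BLIS API:

#include <immintrin.h>

// Collapse four ymm accumulators, each holding four partial sums, into
// { sum(acc0), sum(acc1), sum(acc2), sum(acc3) }.
static inline __m256d dsum4_sketch( __m256d acc0, __m256d acc1,
                                    __m256d acc2, __m256d acc3 )
{
	// vhaddpd: pairwise sums within each 128-bit lane.
	__m256d t01 = _mm256_hadd_pd( acc0, acc1 );
	__m256d t23 = _mm256_hadd_pd( acc2, acc3 );

	// vextractf128 + vaddpd: fold the upper lane onto the lower lane.
	__m128d s01 = _mm_add_pd( _mm256_castpd256_pd128( t01 ),
	                          _mm256_extractf128_pd( t01, 1 ) ); // { sum(acc0), sum(acc1) }
	__m128d s23 = _mm_add_pd( _mm256_castpd256_pd128( t23 ),
	                          _mm256_extractf128_pd( t23, 1 ) ); // { sum(acc2), sum(acc3) }

	// vperm2f128 with imm 0x20: low lane from s01, high lane from s23.
	return _mm256_insertf128_pd( _mm256_castpd128_pd256( s01 ), s23, 1 );
}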
+ + + + + + + + label(.DPOSTACCUM) + + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rd_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + //mov(var(rs_a), r8) // load rs_a + //mov(var(cs_a), r9) // load cs_a + //lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + //lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rdx) // load address of b. + //mov(var(rs_b), r10) // load rs_b + mov(var(cs_b), r11) // load cs_b + //lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + + mov(var(c), r12) // load address of c + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + // r12 = rcx = c + // r14 = rax = a + // rdx = rbx = b + // r9 = m dim index ii + // r15 = n dim index jj + + mov(imm(0), r15) // jj = 0; + + label(.DLOOP3X4J) // LOOP OVER jj = [ 0 1 ... ] + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vxorpd(ymm7, ymm7, ymm7) + vxorpd(ymm10, ymm10, ymm10) + vxorpd(ymm13, ymm13, ymm13) +#endif + + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rbx) // rbx = b + 4*jj*cs_b; + + lea(mem(r14), rax) // rax = a; + + +#if 1 + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c +#endif + + + + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
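Before the main loop below, it helps to keep in mind the scalar computation this 1x8 rd kernel performs: the single row of C receives eight dot products over k, with the k dimension traversed with stride cs_a in A and rs_b in B, exactly as the comments on the pointer updates indicate. A hedged scalar sketch, under the usual BLAS alpha/beta convention and with an illustrative (non-BLIS) function name; the assembly merely splits the k loop into the k_iter16/k_iter4/k_left1 pieces declared above and handles the eight columns four at a time via the jj loop:

#include "blis.h"

static void dgemmsup_rd_1x8_sketch
     (
       dim_t k0, double alpha,
       double* a, inc_t cs_a,
       double* b, inc_t rs_b, inc_t cs_b,
       double beta, double* c, inc_t cs_c
     )
{
	for ( dim_t j = 0; j < 8; ++j )
	{
		double ab = 0.0;

		for ( dim_t l = 0; l < k0; ++l )
			ab += a[ l*cs_a ] * b[ l*rs_b + j*cs_b ];

		// When beta == 0, c is overwritten without being read.
		if ( beta == 0.0 ) c[ j*cs_c ] =                      alpha * ab;
		else               c[ j*cs_c ] = beta * c[ j*cs_c ] + alpha * ab;
	}
}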
+ + + label(.DLOOPKITER16) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16) // iterate again if i != 0. + + + + + + + label(.DCONSIDKITER4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + + label(.DLOOPKITER4) // EDGE LOOP (ymm) + +#if 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*cs_a +#endif + + vmovupd(mem(rax ), ymm0) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4) // iterate again if i != 0. + + + + + label(.DCONSIDKLEFT1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. 
+ + + + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1) // iterate again if i != 0. + + + + + + + + label(.DPOSTACCUM) + + + + // ymm4 ymm7 ymm10 ymm13 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + + + + //mov(var(rs_c), rdi) // load rs_c + //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + //mov(var(cs_c), rsi) // load cs_c + //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx)) + //add(rdi, rcx) + + + + + label(.DDONE) + + + + + add(imm(4), r15) // jj += 4; + cmp(imm(8), r15) // compare jj to 8 + jl(.DLOOP3X4J) // if jj < 8, jump to beginning + // of jj loop; otherwise, loop ends. + + + + label(.DRETURN) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c new file mode 100644 index 000000000..eb1118196 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c @@ -0,0 +1,2496 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 1*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
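The main loop below unrolls the k dimension by four; each step applies the same rank-1 update to the 6x2 microtile: two contiguous elements of the current row of B are loaded once, and each of the six A elements in the current column is broadcast and FMA'd into its own accumulator. A minimal intrinsics sketch of one such step, assuming FMA3 support; the helper name and the element-indexed strides are illustrative only:

#include <immintrin.h>
#include "blis.h"

// One k step of the 6x2 rv kernel body: c_acc[i] += a(i,l) * b(l,0:1).
// The caller advances a by cs_a and b by rs_b after each step.
static inline void drv_6x2_kstep_sketch( const double* a, inc_t rs_a,
                                         const double* b, __m128d c_acc[6] )
{
	__m128d b0 = _mm_loadu_pd( b );                     // vmovupd(mem(rbx), xmm0)

	for ( int i = 0; i < 6; ++i )
	{
		__m128d ai = _mm_set1_pd( a[ i*rs_a ] );          // vbroadcastsd(mem(rax, ...))
		c_acc[ i ] = _mm_fmadd_pd( b0, ai, c_acc[ i ] );  // vfmadd231pd(xmm0, ..., xmm...)
	}
}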
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + vfmadd231pd(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
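The post-accumulation code that follows scales the accumulators by alpha and then compares beta against zero (vucomisd) so that, when beta == 0, C is written without ever being read; this follows the usual BLAS convention and also keeps uninitialized C contents (possibly NaN or Inf) from contaminating the result. A hedged scalar sketch of that policy for this 6x2 tile, with illustrative names only:

#include "blis.h"

static void dwrite_6x2_sketch( double alpha, const double ab[6][2],
                               double beta, double* c, inc_t rs_c, inc_t cs_c )
{
	for ( dim_t i = 0; i < 6; ++i )
		for ( dim_t j = 0; j < 2; ++j )
		{
			double* cij = c + i*rs_c + j*cs_c;

			if ( beta == 0.0 ) *cij =                   alpha * ab[ i ][ j ];
			else               *cij = beta * (*cij) +   alpha * ab[ i ][ j ];
		}
}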
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + vmulpd(xmm0, xmm14, xmm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) + vmovupd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
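Both the .DCOLSTORED case above and the .DCOLSTORBZ case below write C by columns, which requires first transposing the row-oriented accumulators in registers. A minimal intrinsics sketch of the vunpcklpd/vunpckhpd/vinsertf128 shuffle used for rows 0-3 (rows 4-5 are handled by a second, 2x2-sized pass); the helper name is illustrative:

#include <immintrin.h>

// Turn four 1x2 row vectors r0..r3 into two 4x1 column vectors.
static inline void dtranspose_4x2_sketch( __m128d r0, __m128d r1,
                                          __m128d r2, __m128d r3,
                                          __m256d* col0, __m256d* col1 )
{
	__m128d t0 = _mm_unpacklo_pd( r0, r1 );   // { r0[0], r1[0] }
	__m128d t1 = _mm_unpackhi_pd( r0, r1 );   // { r0[1], r1[1] }
	__m128d t2 = _mm_unpacklo_pd( r2, r3 );   // { r2[0], r3[0] }
	__m128d t3 = _mm_unpackhi_pd( r2, r3 );   // { r2[1], r3[1] }

	*col0 = _mm256_insertf128_pd( _mm256_castpd128_pd256( t0 ), t2, 1 ); // column 0
	*col1 = _mm256_insertf128_pd( _mm256_castpd128_pd256( t1 ), t3, 1 ); // column 1
}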
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(xmm14, xmm12, xmm0) + vunpckhpd(xmm14, xmm12, xmm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 1*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
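In the main loop above, rdx serves only as a prefetch pointer for A: the setup code points it at a + 16*cs_a, iterations 0 and 2 prefetch through it, and iteration 3 advances it by 4*cs_a, so the prefetched column stays a fixed sixteen k iterations ahead of the loads through rax. A hedged sketch of that bookkeeping, assuming the GCC/Clang __builtin_prefetch builtin; the function name is illustrative:

#include "blis.h"

static void drv_a_prefetch_sketch( const double* a, inc_t cs_a, dim_t k_iter )
{
	const double* a_pf = a + 16*cs_a;                    // rdx = a + 16*cs_a

	for ( dim_t l = 0; l < k_iter; ++l )
	{
		__builtin_prefetch( a_pf          + 5, 0, 3 );   // prefetch(0, mem(rdx, 5*8))
		__builtin_prefetch( a_pf + 2*cs_a + 5, 0, 3 );   // prefetch(0, mem(rdx, r9, 2, 5*8))

		// ... four unrolled k steps are performed here ...

		a_pf += 4*cs_a;                                  // lea(mem(rdx, r9, 4), rdx)
	}
}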
+ + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + vmulpd(xmm0, xmm12, xmm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // r13 = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm12, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(xmm12, xmm0) + + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + vfmadd231pd(xmm0, xmm3, xmm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + vmulpd(xmm0, xmm10, xmm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + vunpcklpd(xmm10, xmm8, xmm2) + vunpckhpd(xmm10, xmm8, xmm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + vmulpd(xmm0, xmm8, xmm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
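+	// A rough scalar sketch of what the k_iter/k_left loops above have
+	// accumulated for this 2x2 tile (illustrative only; ab() is shorthand
+	// for the not-yet-scaled product A*B):
+	//
+	//   for ( k = 0; k < k0; ++k )
+	//     for ( i = 0; i < 2; ++i )
+	//       for ( j = 0; j < 2; ++j )
+	//         ab( i, j ) += a( i, k ) * b( k, j );
+	//
+	// with row i of ab() held in xmm4 (i = 0) and xmm6 (i = 1).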
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + vmulpd(xmm0, xmm6, xmm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + + vunpcklpd(xmm6, xmm4, xmm0) + vunpckhpd(xmm6, xmm4, xmm1) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x2 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + //lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(xmm0, xmm2, xmm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(xmm0, xmm4, xmm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovupd(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-1 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm4, xmm3, xmm0) + + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-1 + vmovlpd(xmm4, mem(rcx )) + vmovhpd(xmm4, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c new file mode 100644 index 000000000..bdcf833e3 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c @@ -0,0 +1,2600 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
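+	// A rough sketch of the register assignment used by the loops below
+	// (illustrative, not a normative comment). Each k step loads one
+	// 4-element row of B into ymm0, broadcasts the six elements of the
+	// current column of A (rows 0-5, reached via r8 = rs_a, r13 = 3*rs_a,
+	// r15 = 5*rs_a), and accumulates
+	//
+	//   ymm4  += ymm0 * a(0,k)      ymm6  += ymm0 * a(1,k)
+	//   ymm8  += ymm0 * a(2,k)      ymm10 += ymm0 * a(3,k)
+	//   ymm12 += ymm0 * a(4,k)      ymm14 += ymm0 * a(5,k)
+	//
+	// so the even registers ymm4..ymm14 hold rows 0..5 of the 6x4 product.
+	// The main loop below unrolls this computation by 4 in k.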
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
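+	// What follows is the sup epilogue. Informally (a sketch, not the
+	// literal instruction sequence), it computes
+	//
+	//   c( i, j ) = beta * c( i, j ) + alpha * ab( i, j ),  0 <= i < 6, 0 <= j < 4,
+	//
+	// skipping the loads of c() entirely when beta == 0, and selecting
+	// row- or column-oriented IO on C according to whether rs_c == 1.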
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
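+	// Rough before/after of the in-register transpose reused below for the
+	// beta == 0 column-stored case (same shuffles as in .DCOLSTORED above;
+	// sketch only):
+	//
+	//   before:  ymm4  = ( ab00, ab01, ab02, ab03 )   (rows 0-3)
+	//            ymm6  = ( ab10, ab11, ab12, ab13 )
+	//            ymm8  = ( ab20, ab21, ab22, ab23 )
+	//            ymm10 = ( ab30, ab31, ab32, ab33 )
+	//   after:   ymm4  = ( ab00, ab10, ab20, ab30 )   (columns 0-3)
+	//            ymm6  = ( ab01, ab11, ab21, ab31 )
+	//            ymm8  = ( ab02, ab12, ab22, ab32 )
+	//            ymm10 = ( ab03, ab13, ab23, ab33 )
+	//
+	// Rows 4-5 (ymm12/ymm14) are unpacked into xmm pairs and written to
+	// rdx = c + 4*rs_c, one two-element column slice at a time.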
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + 
+ + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
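+	// NOTE (a sketch of the intent of the column-stored case below): rows
+	// 0-3 of the 5x4 tile are transposed into full ymm columns exactly as
+	// in the 6x4 kernel, while the lone remaining row 4 (held in ymm12) is
+	// gathered from c + 4*rs_c with vmovlpd/vmovhpd, combined as
+	// beta*c + ab via vfmadd213pd, and scattered back one element per
+	// column.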
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // 
Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
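+	// NOTE (a reading of the code above and below, offered as a sketch):
+	// for this 4x4 tile the transposed micro-tile fits exactly in four ymm
+	// registers (one column each in ymm4, ymm6, ymm8, ymm10), so the
+	// column-storage cases need no xmm extraction or scalar tail, unlike
+	// the 5x4 and 3x4 kernels.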
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
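+	// As in .DCOLSTORED above (described here only as an illustrative
+	// sketch), the beta == 0 column-stored path splits each transposed
+	// column of the 3x4 tile into an xmm holding rows 0-1 (written at rcx)
+	// and a scalar in xmm12..xmm15 holding row 2 (written at
+	// rdx = c + 2*rs_c).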
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
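+
+	// NOTE: with only two rows, the in-register "transpose" used by the
+	// column-storage cases reduces to a pair of unpacks plus two 128-bit
+	// extracts: vunpcklpd/vunpckhpd interleave rows 0 and 1 so that
+	//   xmm0 = { c00, c10 },  xmm1 = { c01, c11 },
+	// and vextractf128 of the upper halves yields
+	//   xmm2 = { c02, c12 },  xmm4 = { c03, c13 },
+	// i.e. one two-element column vector per column of the 2x4 micro-tile.
+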
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x4 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c new file mode 100644 index 000000000..8022bf065 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c @@ -0,0 +1,3095 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + + +void bli_dgemmsup_rv_haswell_asm_6x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
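+
+	// NOTE: rdi below holds rs_c scaled to bytes. The cmp(imm(8), rdi)
+	// tests that follow therefore check whether rs_c == 1, i.e. whether C
+	// is column-stored (or general-stride with unit row stride); if so,
+	// the column-oriented prefetch and IO branches are taken, otherwise
+	// the row-oriented branches are used.
+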
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), 
ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
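+
+	// In scalar terms, the main and edge loops above have now accumulated
+	// the (not yet alpha-scaled) product. A minimal sketch, assuming unit
+	// column stride for B as stated in the file header comment:
+	//
+	//   for ( l = 0; l < k0; ++l )
+	//     for ( i = 0; i < 6; ++i )
+	//       for ( j = 0; j < 6; ++j )
+	//         ab( i, j ) += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j ];
+	//
+	// Row i of ab is held in ymm(4+2i) (columns 0-3) and in the lower
+	// half of ymm(5+2i) (columns 4-5).
+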
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(xmm0, xmm15, xmm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
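+
+	// NOTE: the column-storage cases below write C by columns, so the
+	// register-resident micro-tile is first transposed in registers. The
+	// vunpck/vinsertf128/vperm2f128 sequence applied to four row registers
+	// (e.g. ymm4/6/8/10 for rows 0-3, columns 0-3) is equivalent to the
+	// following intrinsics sketch (illustrative only; r0..r3 and c0..c3
+	// are hypothetical names, with r_i = row i and c_j = column j):
+	//
+	//   __m256d t0 = _mm256_unpacklo_pd( r0, r1 );  // { r00, r10, r02, r12 }
+	//   __m256d t1 = _mm256_unpackhi_pd( r0, r1 );  // { r01, r11, r03, r13 }
+	//   __m256d t2 = _mm256_unpacklo_pd( r2, r3 );  // { r20, r30, r22, r32 }
+	//   __m256d t3 = _mm256_unpackhi_pd( r2, r3 );  // { r21, r31, r23, r33 }
+	//   __m256d c0 = _mm256_insertf128_pd( t0, _mm256_castpd256_pd128( t2 ), 1 );
+	//   __m256d c1 = _mm256_insertf128_pd( t1, _mm256_castpd256_pd128( t3 ), 1 );
+	//   __m256d c2 = _mm256_permute2f128_pd( t0, t2, 0x31 );
+	//   __m256d c3 = _mm256_permute2f128_pd( t1, t3, 0x31 );
+	//
+	// Rows 4-5 and columns 4-5 are handled analogously with 128-bit halves.
+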
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(xmm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vmovupd(xmm0, mem(rdx)) + vmovupd(xmm1, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
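+
+	// NOTE: the main loop below is unrolled by four k iterations. rdx acts
+	// as a prefetch pointer into A that starts 16 columns ahead of rax
+	// (see the lea pair above); iterations 0 and 2 issue prefetches from
+	// it, and iteration 3 advances it by 4*cs_a, matching the four columns
+	// of A consumed per pass through the loop.
+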
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(xmm0, xmm13, xmm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
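+
+	// NOTE: in the column-storage cases below, rows 0-3 go through the
+	// usual 4x4 in-register transpose, but row 4 (ymm12/xmm13) has no
+	// partner rows to transpose against. Its elements of C are instead
+	// gathered two at a time with vmovlpd/vmovhpd into an xmm pair,
+	// combined with beta*c via vfmadd213pd, and scattered back the same
+	// way; rdx points at c + 4*rs_c for these accesses.
+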
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm13, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm13, ymm0) + + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(xmm0, xmm11, xmm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* 
restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 5*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
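+ // Main loop, unrolled by a factor of 4 in k. Each iteration loads one row + // of b (ymm0: columns 0-3, xmm1: columns 4-5), broadcasts the three + // elements of the current column of a, and accumulates rows 0, 1, and 2 of + // the micro-tile into the register pairs ymm4/ymm5, ymm6/ymm7, and ymm8/ymm9.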
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
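+ // Post-accumulation: scale the accumulated micro-tile by alpha, then either + // update c with beta (row- or column-stored cases) or, when beta is zero, + // skip loading c and store the result directly.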
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(xmm0, xmm9, xmm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + //vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + //vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + //vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + //vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 
1), xmm3, xmm13) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 5*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
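+ // Main loop, unrolled by a factor of 4 in k. Each iteration loads one row + // of b (ymm0/xmm1) and broadcasts two elements of the current column of a, + // accumulating rows 0 and 1 into ymm4/ymm5 and ymm6/ymm7.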
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(xmm0, xmm7, xmm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) + vmovupd(xmm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x6 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(xmm0, xmm5, xmm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vmovupd(xmm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + + vfmadd213pd(xmm5, xmm3, xmm0) + + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-5 + vmovupd(xmm5, xmm0) + + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c new file mode 100644 index 000000000..a6c8f0e43 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c @@ -0,0 +1,3260 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +/* + rrr: + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : + + rcr: + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... 
-------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : + + Assumptions: + - B is row-stored; + - A is row- or column-stored; + - m0 and n0 are at most MR and NR, respectively. + Therefore, this (r)ow-preferential kernel is well-suited for contiguous + (v)ector loads on B and single-element broadcasts from A. + + NOTE: These kernels explicitly support column-oriented IO, implemented + via an in-register transpose. And thus they also support the crr and + ccr cases, though only crr is ever utilized (because ccr is handled by + transposing the operation and executing rcr, which does not incur the + cost of the in-register transpose). + + crr: + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : +*/ + +// Prototype reference microkernels. +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) + +// Define parameters and variables for edge case kernel map. +#define NUM_MR 4 +#define NUM_NR 4 +#define FUNCPTR_T dgemmsup_ker_ft + +static dim_t mrs[NUM_MR] = { 6, 4, 2, 1 }; +static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 }; +static FUNCPTR_T kmap[NUM_MR][NUM_NR] = +{ /* 8 4 2 1 */ +/* 6 */ { bli_dgemmsup_rv_haswell_asm_6x8, bli_dgemmsup_rv_haswell_asm_6x4, bli_dgemmsup_rv_haswell_asm_6x2, bli_dgemmsup_r_haswell_ref_6x1 }, +/* 4 */ { bli_dgemmsup_rv_haswell_asm_4x8, bli_dgemmsup_rv_haswell_asm_4x4, bli_dgemmsup_rv_haswell_asm_4x2, bli_dgemmsup_r_haswell_ref_4x1 }, +/* 2 */ { bli_dgemmsup_rv_haswell_asm_2x8, bli_dgemmsup_rv_haswell_asm_2x4, bli_dgemmsup_rv_haswell_asm_2x2, bli_dgemmsup_r_haswell_ref_2x1 }, +/* 1 */ { bli_dgemmsup_rv_haswell_asm_1x8, bli_dgemmsup_rv_haswell_asm_1x4, bli_dgemmsup_rv_haswell_asm_1x2, bli_dgemmsup_r_haswell_ref_1x1 }, +}; + + +void bli_dgemmsup_rv_haswell_asm_6x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + + // Use a reference kernel if this is an edge case in the m or n + // dimensions. + if ( m0 < 6 || n0 < 8 ) + { +#if 0 + bli_dgemmsup_r_haswell_ref + ( + conja, conjb, m0, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; +#endif + dim_t n_left = n0; + double* restrict cj = c; + double* restrict bj = b; + + // Iterate across columns (corresponding to elements of nrs) until + // n_left is zero. + for ( dim_t j = 0; n_left != 0; ++j ) + { + const dim_t nr_cur = nrs[ j ]; + + // Once we find the value of nrs that is less than (or equal to) + // n_left, we use the kernels in that column. + if ( nr_cur <= n_left ) + { + dim_t m_left = m0; + double* restrict cij = cj; + double* restrict ai = a; + + // Iterate down the current column (corresponding to elements + // of mrs) until m_left is zero. + for ( dim_t i = 0; m_left != 0; ++i ) + { + const dim_t mr_cur = mrs[ i ]; + + // Once we find the value of mrs that is less than (or equal + // to) m_left, we select that kernel. 
+ if ( mr_cur <= m_left ) + { + FUNCPTR_T ker_fp = kmap[i][j]; + + //printf( "executing %d x %d sup kernel.\n", (int)mr_cur, (int)nr_cur ); + + // Call the kernel using current mrs and nrs values. + ker_fp + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + // Advance C and A pointers by the mrs and nrs we just + // used, and decrement m_left. + cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + } + + // Advance C and B pointers by the mrs and nrs we just used, and + // decrement n_left. + cj += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + 
vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
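+ // At this point the accumulated 6x8 micro-tile of a*b resides in + // ymm4-ymm15: row i is held in the register pair ymm(4+2i) (columns 0-3) + // and ymm(5+2i) (columns 4-7). The code below scales by alpha and updates c.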
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
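+ // Column-stored case: rows 0-3 of the micro-tile are transposed in + // registers so that each ymm holds one full 4-element column of c and are + // updated through rcx, while rows 4-5 are transposed into xmm pairs and + // updated through rdx = c + 4*rs_c; columns 0-3 are handled first, then + // columns 4-7.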
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_5x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
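+ // Main loop, unrolled by a factor of 4 in k. Each iteration loads one row + // of b into ymm0/ymm1 (columns 0-7), broadcasts the five elements of the + // current column of a (offsets 0 through 4*rs_a), and accumulates the five + // rows into the register pairs ymm4/ymm5 through ymm12/ymm13.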
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
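The row-stored path above fuses the beta update with the stores: each vfmadd231pd folds four elements of C, scaled by beta (ymm3), into the already alpha-scaled accumulator before it is written back. An equivalent AVX/FMA intrinsics sketch for one 4-wide chunk, assuming immintrin.h (cij, vbeta, and acc are placeholder names):

    __m256d cvec = _mm256_loadu_pd( cij );         // load 4 doubles of C
    acc = _mm256_fmadd_pd( cvec, vbeta, acc );     // acc = beta*C + alpha*A*B
    _mm256_storeu_pd( cij, acc );                  // store the updated chunk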
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm12, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovlpd(mem(rdx ), xmm0, xmm0) + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm13, ymm3, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
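The .DCOLSTORED path above (and the .DCOLSTORBZ path that follows) first transposes each 4x4 block of accumulators so that whole columns of C can be read and written as ymm vectors; the vunpck/vinsertf128/vperm2f128 sequence is the standard AVX 4x4 transpose of doubles. An intrinsics sketch, assuming immintrin.h (r0-r3 are four accumulator rows; names are placeholders):

    __m256d t0 = _mm256_unpacklo_pd( r0, r1 );  // r00 r10 r02 r12
    __m256d t1 = _mm256_unpackhi_pd( r0, r1 );  // r01 r11 r03 r13
    __m256d t2 = _mm256_unpacklo_pd( r2, r3 );  // r20 r30 r22 r32
    __m256d t3 = _mm256_unpackhi_pd( r2, r3 );  // r21 r31 r23 r33
    __m256d c0 = _mm256_insertf128_pd( t0, _mm256_castpd256_pd128( t2 ), 1 );  // column 0
    __m256d c1 = _mm256_insertf128_pd( t1, _mm256_castpd256_pd128( t3 ), 1 );  // column 1
    __m256d c2 = _mm256_permute2f128_pd( t0, t2, 0x31 );                       // column 2
    __m256d c3 = _mm256_permute2f128_pd( t1, t3, 0x31 );                       // column 3

The fifth row of the micro-tile (ymm12/ymm13) does not fit the 4x4 pattern and is instead gathered and scattered one element per column with vmovlpd/vmovhpd.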
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm12, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovupd(ymm13, ymm0) + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rdx )) + vmovhpd(xmm0, mem(rdx, rsi, 1)) + vmovlpd(xmm1, mem(rdx, rsi, 2)) + vmovhpd(xmm1, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + 
vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm4) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm6) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm8) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm10) + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), xmm3, xmm5) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm7) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm9) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm11) + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vfmadd231sd(mem(rdx ), xmm3, xmm12) + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14) + vfmadd231sd(mem(rdx, rax, 1), xmm3, xmm15) + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vextractf128(imm(0x1), ymm4, xmm12) + vextractf128(imm(0x1), ymm6, xmm13) + vextractf128(imm(0x1), ymm8, xmm14) + vextractf128(imm(0x1), ymm10, xmm15) + + vmovupd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vextractf128(imm(0x1), ymm5, xmm12) + vextractf128(imm(0x1), ymm7, xmm13) + vextractf128(imm(0x1), ymm9, xmm14) + vextractf128(imm(0x1), ymm11, xmm15) + + vmovupd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vmovupd(xmm11, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + vmovsd(xmm12, mem(rdx )) + vmovsd(xmm13, mem(rdx, rsi, 1)) + vmovsd(xmm14, mem(rdx, rsi, 2)) + vmovsd(xmm15, mem(rdx, rax, 1)) + + //lea(mem(rdx, rsi, 4), rdx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi","rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 1*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + vmovupd(xmm2, mem(rcx, rsi, 2)) + vmovupd(xmm4, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi","rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x8 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + //lea(mem(rcx, rdi, 2), rdx) // + //lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rcx, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm4, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vmovlpd(mem(rcx ), xmm0, xmm0) + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1) + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) + + vfmadd213pd(ymm5, ymm3, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
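For this single-row kernel, the .DCOLSTORED path above gathers four elements of C (one per column) with vmovlpd/vmovhpd, applies beta with a single FMA, and scatters the results back. Its scalar equivalent, with placeholder names (ab[] already holds alpha times the accumulated A*B row):

    for ( dim_t j = 0; j < 8; ++j )
        c[ j*cs_c ] = beta * c[ j*cs_c ] + ab[ j ];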
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vmovupd(ymm4, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vmovupd(ymm5, ymm0) + + vextractf128(imm(1), ymm0, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovhpd(xmm0, mem(rcx, rsi, 1)) + vmovlpd(xmm1, mem(rcx, rsi, 2)) + vmovhpd(xmm1, mem(rcx, rax, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + + + + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi","rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c similarity index 99% rename from kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8.c rename to kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c index 1b80af8b7..87ef7309b 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8.c +++ b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -219,7 +219,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8 mov(var(c), r12) // load address of c lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; - imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; @@ -487,7 +487,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8 vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -788,6 +788,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8 lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a mov(var(c), r12) // load address of c @@ -826,7 +827,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8 lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; - imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; @@ -1018,7 +1019,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8 vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1274,7 +1275,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8 lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; - imul(imm(1*8), rsi) // rsi *= cs_c = 1*8 + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c; lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; @@ -1438,7 +1439,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8 // which would destory intermediate results. 
vmovsd(mem(rax ), xmm0) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1908,7 +1909,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2398,7 +2399,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -2783,7 +2784,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // which would destory intermediate results. vmovsd(mem(rax ), xmm0) - add(imm(1*8), rax) // a += 1*cs_b = 1*8; + add(imm(1*8), rax) // a += 1*cs_a = 1*8; vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8.c b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c similarity index 99% rename from kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8.c rename to kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c index ebe396317..fe61fbc31 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8.c +++ b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c @@ -257,6 +257,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8 prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c jmp(.DPOSTPFETCH) // jump to end of prefetching c label(.DCOLPFETCH) // column-stored prefetching c diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index df49a77dd..d37f0f877 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -65,6 +65,17 @@ GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_haswell_asm_6x8 ) // -- level-3 sup -------------------------------------------------------------- +// -- double real -- + +// gemmsup_r + +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 ) + // gemmsup_rv GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8 ) @@ -95,13 +106,6 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x2 ) -GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 ) -GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 ) -GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 ) -GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 ) -GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 ) -GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 ) - // gemmsup_rv (mkernel in m dim) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m ) @@ -133,6 +137,11 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x1 ) 
+GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x1 )
+GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x1 )
+GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 )
+
 // gemmsup_rd (mkernel in m dim)
 
 GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m )
diff --git a/ref_kernels/bli_cntx_ref.c b/ref_kernels/bli_cntx_ref.c
index 36cdd52dc..dbbc2130e 100644
--- a/ref_kernels/bli_cntx_ref.c
+++ b/ref_kernels/bli_cntx_ref.c
@@ -418,7 +418,8 @@ void GENBARNAME(cntx_init)
     gen_func_init( &funcs[ BLIS_TRSM_L_UKR ], trsm_l_ukr_name );
     gen_func_init( &funcs[ BLIS_TRSM_U_UKR ], trsm_u_ukr_name );
-    bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE );
+    //                                           s      d      c      z
+    bli_mbool_init( &mbools[ BLIS_GEMM_UKR ],       TRUE,  TRUE,  TRUE,  TRUE  );
     bli_mbool_init( &mbools[ BLIS_GEMMTRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE );
     bli_mbool_init( &mbools[ BLIS_GEMMTRSM_U_UKR ], FALSE, FALSE, FALSE, FALSE );
     bli_mbool_init( &mbools[ BLIS_TRSM_L_UKR ],     FALSE, FALSE, FALSE, FALSE );
@@ -427,11 +428,14 @@ void GENBARNAME(cntx_init)
 
     // -- Set level-3 small/unpacked thresholds --------------------------------
 
-    // NOTE: The default thresholds are set very low so that the sup framework
-    // only actives for exceedingly small dimensions. If a sub-configuration
-    // registers optimized sup kernels, then that sub-configuration should also
-    // register new (probably larger) thresholds that are almost surely more
-    // appropriate that these default values.
+    // NOTE: The default thresholds are set to zero so that the sup framework
+    // does not activate by default. Note that the semantic meaning of the
+    // thresholds is that the sup code path is executed if a dimension is
+    // strictly less than its corresponding threshold. So actually, the
+    // thresholds specify the minimum dimension size that will still dispatch
+    // the non-sup/large code path. This "strictly less than" behavior was
+    // chosen over "less than or equal to" so that threshold values of 0 would
+    // effectively disable sup (even for matrix dimensions of 0).
     //                                          s     d     c     z
     bli_blksz_init_easy( &thresh[ BLIS_MT ],    0,    0,    0,    0 );
     bli_blksz_init_easy( &thresh[ BLIS_NT ],    0,    0,    0,    0 );
@@ -486,10 +490,10 @@ void GENBARNAME(cntx_init)
     gen_func_init( &funcs[ BLIS_RRC ], gemmsup_rv_ukr_name );
     gen_func_init( &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name );
     gen_func_init( &funcs[ BLIS_RCC ], gemmsup_rv_ukr_name );
-    gen_func_init( &funcs[ BLIS_CRR ], gemmsup_cv_ukr_name );
-    gen_func_init( &funcs[ BLIS_CRC ], gemmsup_cv_ukr_name );
-    gen_func_init( &funcs[ BLIS_CCR ], gemmsup_cv_ukr_name );
-    gen_func_init( &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name );
+    gen_func_init( &funcs[ BLIS_CRR ], gemmsup_rv_ukr_name );
+    gen_func_init( &funcs[ BLIS_CRC ], gemmsup_rv_ukr_name );
+    gen_func_init( &funcs[ BLIS_CCR ], gemmsup_rv_ukr_name );
+    gen_func_init( &funcs[ BLIS_CCC ], gemmsup_rv_ukr_name );
 
     // Register the general-stride/generic ukernel to the "catch-all" slot
     // associated with the BLIS_XXX enum value. This slot will be queried if
@@ -498,16 +502,17 @@ void GENBARNAME(cntx_init)
 
     // Set the l3 sup ukernel storage preferences.
-    bli_mbool_init( &mbools[ BLIS_RRR ], TRUE,  TRUE,  TRUE,  TRUE  );
-    bli_mbool_init( &mbools[ BLIS_RRC ], TRUE,  TRUE,  TRUE,  TRUE  );
-    bli_mbool_init( &mbools[ BLIS_RCR ], TRUE,  TRUE,  TRUE,  TRUE  );
-    bli_mbool_init( &mbools[ BLIS_RCC ], TRUE,  TRUE,  TRUE,  TRUE  );
-    bli_mbool_init( &mbools[ BLIS_CRR ], FALSE, FALSE, FALSE, FALSE );
-    bli_mbool_init( &mbools[ BLIS_CRC ], FALSE, FALSE, FALSE, FALSE );
-    bli_mbool_init( &mbools[ BLIS_CCR ], FALSE, FALSE, FALSE, FALSE );
-    bli_mbool_init( &mbools[ BLIS_CCC ], FALSE, FALSE, FALSE, FALSE );
+    //                                    s      d      c      z
+    bli_mbool_init( &mbools[ BLIS_RRR ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_RRC ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_RCR ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_RCC ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_CRR ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_CRC ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_CCR ], TRUE,  TRUE,  TRUE,  TRUE  );
+    bli_mbool_init( &mbools[ BLIS_CCC ], TRUE,  TRUE,  TRUE,  TRUE  );
 
-    bli_mbool_init( &mbools[ BLIS_XXX ], FALSE, FALSE, FALSE, FALSE );
+    bli_mbool_init( &mbools[ BLIS_XXX ], TRUE,  TRUE,  TRUE,  TRUE  );
 
     // -- Set level-1f kernels --------------------------------------------------
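
A note on the threshold semantics documented in the bli_cntx_ref.c hunk above: the sup code
path is taken only when a dimension is strictly less than its corresponding threshold, which
is why default thresholds of zero disable sup entirely. The sketch below is only an
illustration of that rule in C; the helper name and its parameters are hypothetical and do
not correspond to an actual BLIS routine, and it simply mirrors the comment's wording
("a dimension ... strictly less than its corresponding threshold").

    #include <stdbool.h>

    typedef long dim_t; /* stand-in for the BLIS integer type */

    /* Hypothetical sketch: return true if the sup (small/unpacked) path
       would be dispatched under the "strictly less than" rule. */
    static bool use_sup_path( dim_t m, dim_t n, dim_t k,
                              dim_t mt, dim_t nt, dim_t kt )
    {
        /* A dimension must be strictly less than its threshold to select
           the sup path; otherwise the conventional (large) path is used. */
        return ( m < mt || n < nt || k < kt );
    }

With the default thresholds mt = nt = kt = 0, no dimension, not even 0, satisfies the strict
inequality, so use_sup_path( 0, 0, 0, 0, 0, 0 ) is false and the large code path is always
chosen, which is exactly the behavior the comment describes.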