From 451bde076f0320d60cd2475cfb048ac4a2b798bb Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Fri, 22 Jul 2016 15:43:00 -0500 Subject: [PATCH] Add some more knobs to twiddle for KNL microkernel. --- kernels/x86_64/knl/3/bli_avx512_macros.h | 2 + kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c | 182 +++++++++++++++++++--- 2 files changed, 162 insertions(+), 22 deletions(-) diff --git a/kernels/x86_64/knl/3/bli_avx512_macros.h b/kernels/x86_64/knl/3/bli_avx512_macros.h index 248d9b546..9ab46eff4 100644 --- a/kernels/x86_64/knl/3/bli_avx512_macros.h +++ b/kernels/x86_64/knl/3/bli_avx512_macros.h @@ -101,6 +101,8 @@ #define JZ(_0) ASM(jz _0) #define JNE(_0) ASM(jne _0) #define JE(_0) ASM(je _0) +#define JNC(_0) ASM(jnc _0) +#define JC(_0) ASM(jc _0) #define JMP(_0) ASM(jmp _0) #define VGATHERDPS(_0, _1) ASM(vgatherdps _1, _0) #define VSCATTERDPS(_0, _1) ASM(vscatterdps _1, _0) diff --git a/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c b/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c index c9a6e778f..18da568d0 100644 --- a/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c +++ b/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c @@ -39,6 +39,9 @@ extern int32_t offsets[24]; #define A_PREFETCH_DIST 5 +#define PREFETCH_A 1 +#define PIPELINE_A 1 +#define UNROLL_X2 1 #define UPDATE_SCATTERED(n) \ KMOV(K(1), ESI) \ @@ -78,36 +81,48 @@ void bli_dgemm_opt_8x24 //so that both lines are ready in all m_r iterations but the first //use vscatterpfdps to prefetch 12 lines at once VPXORD(ZMM(8), ZMM(8), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(cs_c)) - VMOVAPD(ZMM( 9), ZMM(8)) MOV(RDI, VAR(offsetPtr)) - VMOVAPD(ZMM(10), ZMM(8)) VMOVUPS(ZMM(5), MEM(RDI)) - VMOVAPD(ZMM(11), ZMM(8)) VMOVUPS(ZMM(6), MEM(RDI,12*4)) - VMOVAPD(ZMM(12), ZMM(8)) VPMULLD(ZMM(5), ZMM(5), ZMM(4)) - VMOVAPD(ZMM(13), ZMM(8)) VPMULLD(ZMM(6), ZMM(6), ZMM(4)) - VMOVAPD(ZMM(14), ZMM(8)) MOV(RDX, IMM(0xFFF)) - VMOVAPD(ZMM(15), ZMM(8)) KMOV(K(1), EDX) - VMOVAPD(ZMM(16), ZMM(8)) //KMOV(K(2), EDX) - VMOVAPD(ZMM(17), ZMM(8)) //VSCATTERPFDPS(0, MEM(RCX,ZMM(5),8,0*8) MASK_K(2)) - VMOVAPD(ZMM(18), ZMM(8)) VSCATTERPFDPS(0, MEM(RCX,ZMM(5),8,7*8) MASK_K(1)) - VMOVAPD(ZMM(19), ZMM(8)) //KMOV(K(1), EDX) - VMOVAPD(ZMM(20), ZMM(8)) KMOV(K(2), EDX) - VMOVAPD(ZMM(21), ZMM(8)) //VSCATTERPFDPS(0, MEM(RCX,ZMM(6),8,0*8) MASK_K(1)) - VMOVAPD(ZMM(22), ZMM(8)) VSCATTERPFDPS(0, MEM(RCX,ZMM(6),8,7*8) MASK_K(2)) - VMOVAPD(ZMM(23), ZMM(8)) MOV(RAX, VAR(a)) - VMOVAPD(ZMM(24), ZMM(8)) MOV(RBX, VAR(b)) - VMOVAPD(ZMM(25), ZMM(8)) ADD(RBX, IMM(15*8)) - VMOVAPD(ZMM(26), ZMM(8)) VMOVAPD(ZMM(0), MEM(RAX)) - VMOVAPD(ZMM(27), ZMM(8)) ADD(RAX, IMM(8*8)) - VMOVAPD(ZMM(28), ZMM(8)) MOV(RCX, VAR(c)) - VMOVAPD(ZMM(29), ZMM(8)) MOV(RDI, RCX) - VMOVAPD(ZMM(30), ZMM(8)) MOV(RSI, VAR(k)) + VMOVAPD(ZMM( 9), ZMM(8)) MOV(RCX, VAR(c)) + VMOVAPD(ZMM(10), ZMM(8)) MOV(RDI, VAR(offsetPtr)) + VMOVAPD(ZMM(11), ZMM(8)) VMOVUPS(ZMM(5), MEM(RDI)) + VMOVAPD(ZMM(12), ZMM(8)) VMOVUPS(ZMM(6), MEM(RDI,12*4)) + VMOVAPD(ZMM(13), ZMM(8)) VPMULLD(ZMM(5), ZMM(5), ZMM(4)) + VMOVAPD(ZMM(14), ZMM(8)) VPMULLD(ZMM(6), ZMM(6), ZMM(4)) + VMOVAPD(ZMM(15), ZMM(8)) MOV(RDX, IMM(0xFFF)) + VMOVAPD(ZMM(16), ZMM(8)) KMOV(K(1), EDX) + VMOVAPD(ZMM(17), ZMM(8)) //KMOV(K(2), EDX) + VMOVAPD(ZMM(18), ZMM(8)) //VSCATTERPFDPS(0, MEM(RCX,ZMM(5),8,0*8) MASK_K(2)) + VMOVAPD(ZMM(19), ZMM(8)) VSCATTERPFDPS(0, MEM(RCX,ZMM(5),8,7*8) MASK_K(1)) + VMOVAPD(ZMM(20), ZMM(8)) //KMOV(K(1), EDX) + VMOVAPD(ZMM(21), ZMM(8)) KMOV(K(2), EDX) + VMOVAPD(ZMM(22), ZMM(8)) //VSCATTERPFDPS(0, MEM(RCX,ZMM(6),8,0*8) MASK_K(1)) + VMOVAPD(ZMM(23), ZMM(8)) VSCATTERPFDPS(0, MEM(RCX,ZMM(6),8,7*8) MASK_K(2)) + VMOVAPD(ZMM(24), ZMM(8)) MOV(RAX, VAR(a)) + VMOVAPD(ZMM(25), ZMM(8)) MOV(RBX, VAR(b)) + VMOVAPD(ZMM(26), ZMM(8)) ADD(RBX, IMM(15*8)) + VMOVAPD(ZMM(27), ZMM(8)) VMOVAPD(ZMM(0), MEM(RAX)) + VMOVAPD(ZMM(28), ZMM(8)) ADD(RAX, IMM(8*8)) + VMOVAPD(ZMM(29), ZMM(8)) MOV(RSI, VAR(k)) + VMOVAPD(ZMM(30), ZMM(8)) VMOVAPD(ZMM(31), ZMM(8)) + TEST(RSI, RSI) + JZ(.DPOSTACCUM) + +#if !UNROLL_X2 || !PIPELINE_A + ALIGN32 LABEL(.DLOOPKITER) +#if PREFETCH_A PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) +#endif +#if PIPELINE_A VMOVAPD(ZMM(1), MEM(RAX)) +#else + VMOVAPD(ZMM(0), MEM(RAX)) +#endif + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RBX,-15*8)) VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RBX,-14*8)) VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RBX,-13*8)) @@ -132,7 +147,10 @@ void bli_dgemm_opt_8x24 VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RBX, 6*8)) VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RBX, 7*8)) VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX, 8*8)) + +#if PIPELINE_A VMOVAPD(ZMM(0), ZMM(1)) +#endif ADD(RAX, IMM(8*8)) ADD(RBX, IMM(24*8)) @@ -140,6 +158,126 @@ void bli_dgemm_opt_8x24 SUB(RSI, IMM(1)) JNZ(.DLOOPKITER) +#else // UNROLL_X2 && PIPELINE_A + + SAR1(RSI) // k -> k/2, jump to .DEXTRAITER if k was odd + JC(.DEXTRAITER) + + ALIGN32 + LABEL(.DLOOPKITER) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) +#endif + VMOVAPD(ZMM(1), MEM(RAX)) + + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RBX,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RBX,-14*8)) + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RBX,-13*8)) + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RBX,-12*8)) + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RBX,-11*8)) + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RBX,-10*8)) + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RBX, -9*8)) + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RBX, -8*8)) + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RBX, -7*8)) + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RBX, -6*8)) + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RBX, -5*8)) + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RBX, -4*8)) + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RBX, -3*8)) + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RBX, -2*8)) + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RBX, -1*8)) + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RBX, 0*8)) + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RBX, 1*8)) + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RBX, 2*8)) + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RBX, 3*8)) + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RBX, 4*8)) + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RBX, 5*8)) + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RBX, 6*8)) + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RBX, 7*8)) + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX, 8*8)) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) +#endif + ADD(RBX, IMM(24*8)) + VMOVAPD(ZMM(0), MEM(RAX,8*8)) + + VFMADD231PD(ZMM( 8), ZMM(1), MEM_1TO8(RBX,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(1), MEM_1TO8(RBX,-14*8)) + VFMADD231PD(ZMM(10), ZMM(1), MEM_1TO8(RBX,-13*8)) + VFMADD231PD(ZMM(11), ZMM(1), MEM_1TO8(RBX,-12*8)) + VFMADD231PD(ZMM(12), ZMM(1), MEM_1TO8(RBX,-11*8)) + VFMADD231PD(ZMM(13), ZMM(1), MEM_1TO8(RBX,-10*8)) + VFMADD231PD(ZMM(14), ZMM(1), MEM_1TO8(RBX, -9*8)) + VFMADD231PD(ZMM(15), ZMM(1), MEM_1TO8(RBX, -8*8)) + VFMADD231PD(ZMM(16), ZMM(1), MEM_1TO8(RBX, -7*8)) + VFMADD231PD(ZMM(17), ZMM(1), MEM_1TO8(RBX, -6*8)) + VFMADD231PD(ZMM(18), ZMM(1), MEM_1TO8(RBX, -5*8)) + VFMADD231PD(ZMM(19), ZMM(1), MEM_1TO8(RBX, -4*8)) + VFMADD231PD(ZMM(20), ZMM(1), MEM_1TO8(RBX, -3*8)) + VFMADD231PD(ZMM(21), ZMM(1), MEM_1TO8(RBX, -2*8)) + VFMADD231PD(ZMM(22), ZMM(1), MEM_1TO8(RBX, -1*8)) + VFMADD231PD(ZMM(23), ZMM(1), MEM_1TO8(RBX, 0*8)) + VFMADD231PD(ZMM(24), ZMM(1), MEM_1TO8(RBX, 1*8)) + VFMADD231PD(ZMM(25), ZMM(1), MEM_1TO8(RBX, 2*8)) + VFMADD231PD(ZMM(26), ZMM(1), MEM_1TO8(RBX, 3*8)) + VFMADD231PD(ZMM(27), ZMM(1), MEM_1TO8(RBX, 4*8)) + VFMADD231PD(ZMM(28), ZMM(1), MEM_1TO8(RBX, 5*8)) + VFMADD231PD(ZMM(29), ZMM(1), MEM_1TO8(RBX, 6*8)) + VFMADD231PD(ZMM(30), ZMM(1), MEM_1TO8(RBX, 7*8)) + VFMADD231PD(ZMM(31), ZMM(1), MEM_1TO8(RBX, 8*8)) + + ADD(RAX, IMM(2*8*8)) + ADD(RBX, IMM(24*8)) + + SUB(RSI, IMM(1)) + JNZ(.DLOOPKITER) + + JMP(.DPOSTACCUM) + + LABEL(.DEXTRAITER) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) +#endif + VMOVAPD(ZMM(1), MEM(RAX)) + + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RBX,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RBX,-14*8)) + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RBX,-13*8)) + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RBX,-12*8)) + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RBX,-11*8)) + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RBX,-10*8)) + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RBX, -9*8)) + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RBX, -8*8)) + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RBX, -7*8)) + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RBX, -6*8)) + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RBX, -5*8)) + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RBX, -4*8)) + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RBX, -3*8)) + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RBX, -2*8)) + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RBX, -1*8)) + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RBX, 0*8)) + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RBX, 1*8)) + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RBX, 2*8)) + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RBX, 3*8)) + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RBX, 4*8)) + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RBX, 5*8)) + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RBX, 6*8)) + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RBX, 7*8)) + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX, 8*8)) + + VMOVAPD(ZMM(0), ZMM(1)) + ADD(RAX, IMM(8*8)) + ADD(RBX, IMM(24*8)) + + TEST(RSI, RSI) + JNZ(.DLOOPKITER) + +#endif + + LABEL(.DPOSTACCUM) + MOV(RAX, VAR(alpha)) MOV(RBX, VAR(beta)) VBROADCASTSD(ZMM(0), MEM(RAX))