From 8ff2e069c48c12fd06b9c48c6b3aeb4ea9b0e6e1 Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Fri, 22 Jul 2016 16:22:26 -0500 Subject: [PATCH] Add 4x unrolled variant for KNL microkernel. --- kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c | 196 +++++++++++++++++++++- 1 file changed, 192 insertions(+), 4 deletions(-) diff --git a/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c b/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c index 45926311e..f12ba4af6 100644 --- a/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c +++ b/kernels/x86_64/knl/3/bli_dgemm_opt_8x24.c @@ -41,7 +41,8 @@ extern int32_t offsets[24]; #define A_PREFETCH_DIST 5 #define PREFETCH_A 1 #define PIPELINE_A 1 -#define UNROLL_X2 1 +#define UNROLL_X2 0 +#define UNROLL_X4 1 #define UPDATE_SCATTERED(n) \ KMOV(K(1), ESI) \ @@ -108,7 +109,7 @@ void bli_dgemm_opt_8x24 TEST(RSI, RSI) JZ(.DPOSTACCUM) -#if !UNROLL_X2 || !PIPELINE_A +#if !(UNROLL_X2 || UNROLL_X4) || !PIPELINE_A ALIGN32 LABEL(.DLOOPKITER) @@ -158,11 +159,13 @@ void bli_dgemm_opt_8x24 SUB(RSI, IMM(1)) JNZ(.DLOOPKITER) -#else // UNROLL_X2 && PIPELINE_A +#elif UNROLL_X2 SAR1(RSI) // k -> k/2, jump to .DEXTRAITER if k was odd JC(.DEXTRAITER) + LABEL(.DMAINLOOP) + MOV(RDI, IMM(24*8)) ALIGN32 @@ -199,7 +202,7 @@ void bli_dgemm_opt_8x24 VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX, 8*8)) #if PREFETCH_A - PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8+8*8)) #endif VMOVAPD(ZMM(0), MEM(RAX,8*8)) @@ -273,8 +276,193 @@ void bli_dgemm_opt_8x24 ADD(RBX, IMM(24*8)) TEST(RSI, RSI) + JNZ(.DMAINLOOP) + +#elif UNROLL_X4 + + MOV(RDI, RSI) + SAR(RSI, IMM(2)) // k/4 + AND(RDI, IMM(3)) // k%4 + JNZ(.DEXTRALOOP) + + LABEL(.DMAINLOOP) + + MOV(RDI, IMM(24*8)) + LEA(RDX, MEM(RDI,RDI,2)) + + ALIGN32 + LABEL(.DLOOPKITER) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) +#endif + VMOVAPD(ZMM(1), MEM(RAX)) + + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RBX,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RBX,-14*8)) + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RBX,-13*8)) + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RBX,-12*8)) + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RBX,-11*8)) + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RBX,-10*8)) + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RBX, -9*8)) + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RBX, -8*8)) + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RBX, -7*8)) + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RBX, -6*8)) + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RBX, -5*8)) + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RBX, -4*8)) + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RBX, -3*8)) + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RBX, -2*8)) + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RBX, -1*8)) + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RBX, 0*8)) + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RBX, 1*8)) + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RBX, 2*8)) + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RBX, 3*8)) + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RBX, 4*8)) + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RBX, 5*8)) + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RBX, 6*8)) + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RBX, 7*8)) + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX, 8*8)) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8+8*8)) +#endif + VMOVAPD(ZMM(0), MEM(RAX,8*8)) + + VFMADD231PD(ZMM( 8), ZMM(1), MEM_1TO8(RBX,RDI,1,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(1), MEM_1TO8(RBX,RDI,1,-14*8)) + VFMADD231PD(ZMM(10), ZMM(1), MEM_1TO8(RBX,RDI,1,-13*8)) + VFMADD231PD(ZMM(11), ZMM(1), MEM_1TO8(RBX,RDI,1,-12*8)) + VFMADD231PD(ZMM(12), ZMM(1), MEM_1TO8(RBX,RDI,1,-11*8)) + VFMADD231PD(ZMM(13), ZMM(1), MEM_1TO8(RBX,RDI,1,-10*8)) + VFMADD231PD(ZMM(14), ZMM(1), MEM_1TO8(RBX,RDI,1, -9*8)) + VFMADD231PD(ZMM(15), ZMM(1), MEM_1TO8(RBX,RDI,1, -8*8)) + VFMADD231PD(ZMM(16), ZMM(1), MEM_1TO8(RBX,RDI,1, -7*8)) + VFMADD231PD(ZMM(17), ZMM(1), MEM_1TO8(RBX,RDI,1, -6*8)) + VFMADD231PD(ZMM(18), ZMM(1), MEM_1TO8(RBX,RDI,1, -5*8)) + VFMADD231PD(ZMM(19), ZMM(1), MEM_1TO8(RBX,RDI,1, -4*8)) + VFMADD231PD(ZMM(20), ZMM(1), MEM_1TO8(RBX,RDI,1, -3*8)) + VFMADD231PD(ZMM(21), ZMM(1), MEM_1TO8(RBX,RDI,1, -2*8)) + VFMADD231PD(ZMM(22), ZMM(1), MEM_1TO8(RBX,RDI,1, -1*8)) + VFMADD231PD(ZMM(23), ZMM(1), MEM_1TO8(RBX,RDI,1, 0*8)) + VFMADD231PD(ZMM(24), ZMM(1), MEM_1TO8(RBX,RDI,1, 1*8)) + VFMADD231PD(ZMM(25), ZMM(1), MEM_1TO8(RBX,RDI,1, 2*8)) + VFMADD231PD(ZMM(26), ZMM(1), MEM_1TO8(RBX,RDI,1, 3*8)) + VFMADD231PD(ZMM(27), ZMM(1), MEM_1TO8(RBX,RDI,1, 4*8)) + VFMADD231PD(ZMM(28), ZMM(1), MEM_1TO8(RBX,RDI,1, 5*8)) + VFMADD231PD(ZMM(29), ZMM(1), MEM_1TO8(RBX,RDI,1, 6*8)) + VFMADD231PD(ZMM(30), ZMM(1), MEM_1TO8(RBX,RDI,1, 7*8)) + VFMADD231PD(ZMM(31), ZMM(1), MEM_1TO8(RBX,RDI,1, 8*8)) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8+2*8*8)) +#endif + VMOVAPD(ZMM(1), MEM(RAX,2*8*8)) + + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RBX,RDI,2,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RBX,RDI,2,-14*8)) + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RBX,RDI,2,-13*8)) + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RBX,RDI,2,-12*8)) + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RBX,RDI,2,-11*8)) + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RBX,RDI,2,-10*8)) + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RBX,RDI,2, -9*8)) + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RBX,RDI,2, -8*8)) + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RBX,RDI,2, -7*8)) + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RBX,RDI,2, -6*8)) + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RBX,RDI,2, -5*8)) + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RBX,RDI,2, -4*8)) + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RBX,RDI,2, -3*8)) + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RBX,RDI,2, -2*8)) + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RBX,RDI,2, -1*8)) + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RBX,RDI,2, 0*8)) + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RBX,RDI,2, 1*8)) + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RBX,RDI,2, 2*8)) + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RBX,RDI,2, 3*8)) + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RBX,RDI,2, 4*8)) + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RBX,RDI,2, 5*8)) + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RBX,RDI,2, 6*8)) + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RBX,RDI,2, 7*8)) + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX,RDI,2, 8*8)) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8+3*8*8)) +#endif + VMOVAPD(ZMM(0), MEM(RAX,3*8*8)) + + VFMADD231PD(ZMM( 8), ZMM(1), MEM_1TO8(RBX,RDX,1,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(1), MEM_1TO8(RBX,RDX,1,-14*8)) + VFMADD231PD(ZMM(10), ZMM(1), MEM_1TO8(RBX,RDX,1,-13*8)) + VFMADD231PD(ZMM(11), ZMM(1), MEM_1TO8(RBX,RDX,1,-12*8)) + VFMADD231PD(ZMM(12), ZMM(1), MEM_1TO8(RBX,RDX,1,-11*8)) + VFMADD231PD(ZMM(13), ZMM(1), MEM_1TO8(RBX,RDX,1,-10*8)) + VFMADD231PD(ZMM(14), ZMM(1), MEM_1TO8(RBX,RDX,1, -9*8)) + VFMADD231PD(ZMM(15), ZMM(1), MEM_1TO8(RBX,RDX,1, -8*8)) + VFMADD231PD(ZMM(16), ZMM(1), MEM_1TO8(RBX,RDX,1, -7*8)) + VFMADD231PD(ZMM(17), ZMM(1), MEM_1TO8(RBX,RDX,1, -6*8)) + VFMADD231PD(ZMM(18), ZMM(1), MEM_1TO8(RBX,RDX,1, -5*8)) + VFMADD231PD(ZMM(19), ZMM(1), MEM_1TO8(RBX,RDX,1, -4*8)) + VFMADD231PD(ZMM(20), ZMM(1), MEM_1TO8(RBX,RDX,1, -3*8)) + VFMADD231PD(ZMM(21), ZMM(1), MEM_1TO8(RBX,RDX,1, -2*8)) + VFMADD231PD(ZMM(22), ZMM(1), MEM_1TO8(RBX,RDX,1, -1*8)) + VFMADD231PD(ZMM(23), ZMM(1), MEM_1TO8(RBX,RDX,1, 0*8)) + VFMADD231PD(ZMM(24), ZMM(1), MEM_1TO8(RBX,RDX,1, 1*8)) + VFMADD231PD(ZMM(25), ZMM(1), MEM_1TO8(RBX,RDX,1, 2*8)) + VFMADD231PD(ZMM(26), ZMM(1), MEM_1TO8(RBX,RDX,1, 3*8)) + VFMADD231PD(ZMM(27), ZMM(1), MEM_1TO8(RBX,RDX,1, 4*8)) + VFMADD231PD(ZMM(28), ZMM(1), MEM_1TO8(RBX,RDX,1, 5*8)) + VFMADD231PD(ZMM(29), ZMM(1), MEM_1TO8(RBX,RDX,1, 6*8)) + VFMADD231PD(ZMM(30), ZMM(1), MEM_1TO8(RBX,RDX,1, 7*8)) + VFMADD231PD(ZMM(31), ZMM(1), MEM_1TO8(RBX,RDX,1, 8*8)) + + ADD(RAX, IMM(4*8*8)) + ADD(RBX, IMM(4*24*8)) + + SUB(RSI, IMM(1)) JNZ(.DLOOPKITER) + JMP(.DPOSTACCUM) + + LABEL(.DEXTRALOOP) + +#if PREFETCH_A + PREFETCH(0, MEM(RAX,A_PREFETCH_DIST*8*8)) +#endif + VMOVAPD(ZMM(1), MEM(RAX)) + + VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RBX,-15*8)) + VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RBX,-14*8)) + VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RBX,-13*8)) + VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RBX,-12*8)) + VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RBX,-11*8)) + VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RBX,-10*8)) + VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RBX, -9*8)) + VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RBX, -8*8)) + VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RBX, -7*8)) + VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RBX, -6*8)) + VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RBX, -5*8)) + VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RBX, -4*8)) + VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RBX, -3*8)) + VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RBX, -2*8)) + VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RBX, -1*8)) + VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RBX, 0*8)) + VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RBX, 1*8)) + VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RBX, 2*8)) + VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RBX, 3*8)) + VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RBX, 4*8)) + VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RBX, 5*8)) + VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RBX, 6*8)) + VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RBX, 7*8)) + VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RBX, 8*8)) + + VMOVAPD(ZMM(0), ZMM(1)) + ADD(RAX, IMM(8*8)) + ADD(RBX, IMM(24*8)) + + SUB(RDI, IMM(1)) + JNZ(.DEXTRALOOP) + + TEST(RSI, RSI) + JNZ(.DMAINLOOP) + #endif LABEL(.DPOSTACCUM)