diff --git a/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c b/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c
index d1cdcbc65..8f07d067b 100644
--- a/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c
+++ b/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c
@@ -39,12 +39,15 @@
 #define UNROLL_K 32
 
+#define SCATTER_PREFETCH_AB 0
+#define SCATTER_PREFETCH_C 1
+
 #define PREFETCH_A_L2 0
 #define PREFETCH_B_L2 0
-#define L2_PREFETCH_DIST 16
+#define L2_PREFETCH_DIST 64
 
-#define A_L1_PREFETCH_DIST 4
-#define B_L1_PREFETCH_DIST 2
+#define A_L1_PREFETCH_DIST 10
+#define B_L1_PREFETCH_DIST 30
 
 #define C_MIN_L2_ITERS 40 //C is not prefetched into L2 for k <= this
 #define C_L1_ITERS 16 //number of iterations before the end to prefetch C into L1
 
@@ -97,11 +100,9 @@
     VSCATTERDPD(MEM(RCX,YMM(2),8) MASK_K(1), ZMM(NUM)) \
     ADD(RCX, RAX)
 
-#define PREFETCH_A_L1(n) \
-\
-    PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8)) \
-    PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+64)) \
-    PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+128))
+#define PREFETCH_A_L1_1(n) PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8))
+#define PREFETCH_A_L1_2(n) PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+64))
+#define PREFETCH_A_L1_3(n) PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+128))
 
 #if PREFETCH_A_L2
 #undef PREFETCH_A_L2
@@ -129,6 +130,36 @@
 #define PREFETCH_B_L2(...)
 #endif
 
+#if SCATTER_PREFETCH_AB
+#undef SCATTER_PREFETCH_AB
+#undef PREFETCH_A_L1_1
+#undef PREFETCH_A_L1_2
+#undef PREFETCH_A_L1_3
+#undef PREFETCH_B_L1
+
+#define SCATTER_PREFETCH_AB(n) \
+\
+    KXNORW(K(1), K(0), K(0)) \
+    VGATHERPFDPS(0, MEM(RAX,ZMM(4),8,((3*n  )*16+3*A_L1_PREFETCH_DIST)*64) MASK_K(1)) \
+    KXNORW(K(2), K(0), K(0)) \
+    VGATHERPFDPS(0, MEM(RAX,ZMM(4),8,((3*n+1)*16+3*A_L1_PREFETCH_DIST)*64) MASK_K(2)) \
+    KXNORW(K(3), K(0), K(0)) \
+    VGATHERPFDPS(0, MEM(RAX,ZMM(4),8,((3*n+2)*16+3*A_L1_PREFETCH_DIST)*64) MASK_K(3)) \
+    KXNORW(K(4), K(0), K(0)) \
+    VGATHERPFDPS(0, MEM(RBX,ZMM(4),8,(   n   *16+  B_L1_PREFETCH_DIST)*64) MASK_K(4))
+
+#define PREFETCH_A_L1_1(...)
+#define PREFETCH_A_L1_2(...)
+#define PREFETCH_A_L1_3(...)
+#define PREFETCH_B_L1(...)
+
+#else
+#undef SCATTER_PREFETCH_AB
+
+#define SCATTER_PREFETCH_AB(...)
+
+#endif
+
 //
 // n: index in unrolled loop (for prefetching offsets)
 //
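The gather-prefetch variant added above (left disabled by default via SCATTER_PREFETCH_AB 0) relies on an indexing trick: zmm4 is loaded with the integers 0..15 and shifted left by 3 in the setup code further down (the VPSLLD), and the gather operand applies a further scale of 8, so the 16 lanes of one VGATHERPFDPS touch 16 consecutive 64-byte cache lines. Two invocations per UNROLL_K = 32 block then cover all 96 lines of the packed A panel (24 doubles, i.e. 3 lines, per iteration) and all 32 lines of the B panel (8 doubles, i.e. 1 line, per iteration), shifted ahead by the L1 prefetch distances. A minimal plain-C model of the byte offsets touched; the function name and printf tracing are illustrative only, not part of the kernel:

    #include <stdio.h>

    #define A_L1_PREFETCH_DIST 10  /* iterations ahead; one iteration of A = 3 cache lines */
    #define B_L1_PREFETCH_DIST 30  /* cache lines ahead; one iteration of B = 1 cache line */

    /* Offsets relative to the packed panels addressed by RAX (A) and RBX (B).
       Lane i of zmm4 holds i<<3; the gather's scale of 8 turns that into i*64. */
    static void model_scatter_prefetch_ab(int n)
    {
        for (int g = 0; g < 3; g++)          /* three gathers over A */
            for (int i = 0; i < 16; i++)
                printf("A: +%d\n", ((3*n + g)*16 + 3*A_L1_PREFETCH_DIST)*64 + i*64);
        for (int i = 0; i < 16; i++)         /* one gather over B */
            printf("B: +%d\n", (n*16 + B_L1_PREFETCH_DIST)*64 + i*64);
    }

    int main(void)
    {
        model_scatter_prefetch_ab(0);        /* iterations 0-15 of an UNROLL_K block */
        model_scatter_prefetch_ab(1);        /* iterations 16-31 */
        return 0;
    }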
@@ -139,37 +170,38 @@
 //
 #define SUBITER(n,a,b,...) \
 \
-    VMOVAPD(ZMM(0), MEM(RBX,(n+1-1)*64)) \
-\
-    PREFETCH_A_L1(n) \
-    PREFETCH_B_L1(n) \
     PREFETCH_A_L2(n) \
-    PREFETCH_B_L2(n) \
 \
-    VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 0)*8)) \
-    VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 1)*8)) \
-    VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 2)*8)) \
-    VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 3)*8)) \
-    VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 4)*8)) \
-    VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 5)*8)) \
-    VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 6)*8)) \
-    VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 7)*8)) \
-    VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 8)*8)) \
-    VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 9)*8)) \
-    VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+10)*8)) \
-    VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+11)*8)) \
-    VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+12)*8)) \
-    VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+13)*8)) \
-    VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+14)*8)) \
-    VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+15)*8)) \
-    VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+16)*8)) \
-    VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+17)*8)) \
-    VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+18)*8)) \
-    VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+19)*8)) \
-    VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+20)*8)) \
-    VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+21)*8)) \
-    VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+22)*8)) \
-    VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+23)*8))
+    VMOVAPD(ZMM(a), MEM(RBX,(n+1)*64)) \
+    VFMADD231PD(ZMM( 8), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 0)*8)) \
+    VFMADD231PD(ZMM( 9), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 1)*8)) \
+    VFMADD231PD(ZMM(10), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 2)*8)) \
+    VFMADD231PD(ZMM(11), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 3)*8)) \
+    PREFETCH_A_L1_1(n) \
+    VFMADD231PD(ZMM(12), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 4)*8)) \
+    VFMADD231PD(ZMM(13), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 5)*8)) \
+    VFMADD231PD(ZMM(14), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 6)*8)) \
+    VFMADD231PD(ZMM(15), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 7)*8)) \
+    PREFETCH_A_L1_2(n) \
+    VFMADD231PD(ZMM(16), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 8)*8)) \
+    VFMADD231PD(ZMM(17), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 9)*8)) \
+    VFMADD231PD(ZMM(18), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+10)*8)) \
+    VFMADD231PD(ZMM(19), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+11)*8)) \
+    PREFETCH_A_L1_3(n) \
+    VFMADD231PD(ZMM(20), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+12)*8)) \
+    VFMADD231PD(ZMM(21), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+13)*8)) \
+    VFMADD231PD(ZMM(22), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+14)*8)) \
+    VFMADD231PD(ZMM(23), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+15)*8)) \
+    PREFETCH_B_L1(n) \
+    VFMADD231PD(ZMM(24), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+16)*8)) \
+    VFMADD231PD(ZMM(25), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+17)*8)) \
+    VFMADD231PD(ZMM(26), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+18)*8)) \
+    VFMADD231PD(ZMM(27), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+19)*8)) \
+    PREFETCH_B_L2(n) \
+    VFMADD231PD(ZMM(28), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+20)*8)) \
+    VFMADD231PD(ZMM(29), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+21)*8)) \
+    VFMADD231PD(ZMM(30), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+22)*8)) \
+    VFMADD231PD(ZMM(31), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+23)*8))
 
 #define TAIL_LOOP(NAME) \
 \
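The rewritten SUBITER also reworks the register scheme: the a/b arguments double-buffer the 8-wide row of B, so the VMOVAPD for iteration n+1 targets ZMM(a) while the 24 FMAs of iteration n read ZMM(b) (hence the alternating 1,0 / 0,1 arguments in the loop bodies below), and the L1/L2 prefetches are now interleaved between groups of four FMAs instead of being issued in a burst at the top. Numerically, each SUBITER is still the rank-1 update sketched here in hedged plain C, where acc stands in for zmm8-zmm31 and the a_panel load for the MEM_1TO8 broadcast:

    /* One SUBITER at unroll index k: a rank-1 update of the 24x8 micro-tile.
       a_panel holds 24 packed doubles per iteration, b_panel holds 8. */
    void model_subiter(double acc[24][8],
                       const double *a_panel, const double *b_panel, int k)
    {
        for (int i = 0; i < 24; i++)     /* one accumulator ZMM per row of C */
            for (int j = 0; j < 8; j++)  /* 8 doubles per 512-bit register */
                acc[i][j] += a_panel[24*k + i] * b_panel[8*k + j];
    }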
@@ -308,6 +340,8 @@
 \
     LOOP_ALIGN \
     LABEL(NAME##_LOOP) \
+\
+    SCATTER_PREFETCH_AB(0) \
 \
     SUBITER( 0,1,0,RAX) \
     SUBITER( 1,0,1,RAX) \
@@ -335,6 +369,8 @@
 \
     TEST(RDI, RDI) \
     JZ(NAME##_DONE) \
+\
+    SCATTER_PREFETCH_AB(0) \
 \
     TAIL_LOOP(NAME##_TAIL) \
 \
@@ -349,6 +385,8 @@
 \
     LOOP_ALIGN \
     LABEL(NAME##_LOOP) \
+\
+    SCATTER_PREFETCH_AB(0) \
 \
     SUBITER( 0,1,0,RAX) \
     SUBITER( 1,0,1,RAX) \
@@ -366,6 +404,9 @@
     SUBITER(13,0,1,RAX,R9,1) \
     SUBITER(14,1,0,RAX,R9,1) \
     SUBITER(15,0,1,RAX,R9,1) \
+\
+    SCATTER_PREFETCH_AB(1) \
+\
     SUBITER(16,1,0,RAX,R8,4) \
     SUBITER(17,0,1,RAX,R8,4) \
     SUBITER(18,1,0,RAX,R8,4) \
@@ -392,6 +433,9 @@
 \
     TEST(RDI, RDI) \
     JZ(NAME##_DONE) \
+\
+    SCATTER_PREFETCH_AB(0) \
+    SCATTER_PREFETCH_AB(1) \
 \
     TAIL_LOOP(NAME##_TAIL) \
 \
@@ -445,19 +489,29 @@ void bli_dgemm_opt_24x8(
     VMOVAPS(ZMM( 9), ZMM(8))
     VMOVAPS(ZMM(10), ZMM(8)) MOV(RSI, VAR(k)) //loop index
     VMOVAPS(ZMM(11), ZMM(8)) MOV(RAX, VAR(a)) //load address of a
-    VMOVAPS(ZMM(12), ZMM(8))
-    VMOVAPS(ZMM(13), ZMM(8)) MOV(RBX, VAR(b)) //load address of b
+    VMOVAPS(ZMM(12), ZMM(8)) MOV(RBX, VAR(b)) //load address of b
+    VMOVAPS(ZMM(13), ZMM(8)) MOV(RCX, VAR(c)) //load address of c
     VMOVAPS(ZMM(14), ZMM(8)) VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b
-    VMOVAPS(ZMM(15), ZMM(8))
-    VMOVAPS(ZMM(16), ZMM(8)) MOV(RCX, VAR(c)) //load address of c
-    VMOVAPS(ZMM(17), ZMM(8)) //set up indexing information for prefetching C
-    VMOVAPS(ZMM(18), ZMM(8)) MOV(RDI, VAR(offsetPtr))
-    VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(rs_c))
-    VMOVAPS(ZMM(20), ZMM(8)) VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15)
-    VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15)
-    VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23)
-    VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23)
-    VMOVAPS(ZMM(24), ZMM(8))
+    VMOVAPS(ZMM(15), ZMM(8)) MOV(RDI, VAR(offsetPtr))
+    VMOVAPS(ZMM(16), ZMM(8)) VMOVAPS(ZMM(4), MEM(RDI))
+#if SCATTER_PREFETCH_C
+    VMOVAPS(ZMM(17), ZMM(8))
+    VMOVAPS(ZMM(18), ZMM(8))
+    VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(5), VAR(rs_c))
+    VMOVAPS(ZMM(20), ZMM(8))
+    VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(4), ZMM(5))
+    VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64))
+    VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(5))
+#else
+    VMOVAPS(ZMM(17), ZMM(8)) MOV(R12, VAR(rs_c))
+    VMOVAPS(ZMM(18), ZMM(8)) LEA(R13, MEM(R12,R12,2))
+    VMOVAPS(ZMM(19), ZMM(8)) LEA(R14, MEM(R12,R12,4))
+    VMOVAPS(ZMM(20), ZMM(8)) LEA(R15, MEM(R13,R12,4))
+    VMOVAPS(ZMM(21), ZMM(8)) LEA(RDX, MEM(RCX,R12,8))
+    VMOVAPS(ZMM(22), ZMM(8)) LEA(RDI, MEM(RDX,R12,8))
+    VMOVAPS(ZMM(23), ZMM(8))
+#endif
+    VMOVAPS(ZMM(24), ZMM(8)) VPSLLD(ZMM(4), ZMM(4), IMM(3))
     VMOVAPS(ZMM(25), ZMM(8)) MOV(R8, IMM(4*24*8)) //offset for 4 iterations
     VMOVAPS(ZMM(26), ZMM(8)) LEA(R9, MEM(R8,R8,2)) //*3
     VMOVAPS(ZMM(27), ZMM(8)) LEA(R10, MEM(R8,R8,4)) //*5
@@ -479,10 +533,37 @@ void bli_dgemm_opt_24x8(
 
     SUB(RSI, IMM(0+C_L1_ITERS))
 
     //prefetch C into L2
+#if SCATTER_PREFETCH_C
     KXNORW(K(1), K(0), K(0))
     KXNORW(K(2), K(0), K(0))
     VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1))
     VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2))
+#else
+    PREFETCH(1, MEM(RCX      ))
+    PREFETCH(1, MEM(RCX,R12,1))
+    PREFETCH(1, MEM(RCX,R12,2))
+    PREFETCH(1, MEM(RCX,R13,1))
+    PREFETCH(1, MEM(RCX,R12,4))
+    PREFETCH(1, MEM(RCX,R14,1))
+    PREFETCH(1, MEM(RCX,R13,2))
+    PREFETCH(1, MEM(RCX,R15,1))
+    PREFETCH(1, MEM(RDX      ))
+    PREFETCH(1, MEM(RDX,R12,1))
+    PREFETCH(1, MEM(RDX,R12,2))
+    PREFETCH(1, MEM(RDX,R13,1))
+    PREFETCH(1, MEM(RDX,R12,4))
+    PREFETCH(1, MEM(RDX,R14,1))
+    PREFETCH(1, MEM(RDX,R13,2))
+    PREFETCH(1, MEM(RDX,R15,1))
+    PREFETCH(1, MEM(RDI      ))
+    PREFETCH(1, MEM(RDI,R12,1))
+    PREFETCH(1, MEM(RDI,R12,2))
+    PREFETCH(1, MEM(RDI,R13,1))
+    PREFETCH(1, MEM(RDI,R12,4))
+    PREFETCH(1, MEM(RDI,R14,1))
+    PREFETCH(1, MEM(RDI,R13,2))
+    PREFETCH(1, MEM(RDI,R15,1))
+#endif
 
     MAIN_LOOP_L2
 
@@ -491,10 +572,37 @@
 
     LABEL(PREFETCH_C_L1)
 
     //prefetch C into L1
+#if SCATTER_PREFETCH_C
     KXNORW(K(1), K(0), K(0))
     KXNORW(K(2), K(0), K(0))
     VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1))
     VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2))
+#else
+    PREFETCH(0, MEM(RCX      ))
+    PREFETCH(0, MEM(RCX,R12,1))
+    PREFETCH(0, MEM(RCX,R12,2))
+    PREFETCH(0, MEM(RCX,R13,1))
+    PREFETCH(0, MEM(RCX,R12,4))
+    PREFETCH(0, MEM(RCX,R14,1))
+    PREFETCH(0, MEM(RCX,R13,2))
+    PREFETCH(0, MEM(RCX,R15,1))
+    PREFETCH(0, MEM(RDX      ))
+    PREFETCH(0, MEM(RDX,R12,1))
+    PREFETCH(0, MEM(RDX,R12,2))
+    PREFETCH(0, MEM(RDX,R13,1))
+    PREFETCH(0, MEM(RDX,R12,4))
+    PREFETCH(0, MEM(RDX,R14,1))
+    PREFETCH(0, MEM(RDX,R13,2))
+    PREFETCH(0, MEM(RDX,R15,1))
+    PREFETCH(0, MEM(RDI      ))
+    PREFETCH(0, MEM(RDI,R12,1))
+    PREFETCH(0, MEM(RDI,R12,2))
+    PREFETCH(0, MEM(RDI,R13,1))
+    PREFETCH(0, MEM(RDI,R12,4))
+    PREFETCH(0, MEM(RDI,R14,1))
+    PREFETCH(0, MEM(RDI,R13,2))
+    PREFETCH(0, MEM(RDI,R15,1))
+#endif
 
     MAIN_LOOP_L1
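Both C-prefetch paths request the same 24 cache lines: the first line of each row of the 24x8 micro-tile of C. The scatter form keeps rs_c*(0..15) in zmm2 and rs_c*(16..23) in ymm3 and lets the operand's scale of 8 convert elements to bytes; the fallback reaches the same rows from three base pointers spaced 8 rows apart (RCX, RDX, RDI) plus the LEA-built multiples R12..R15 holding rs_c*{1,3,5,7}, since every row offset 0..7 is one of those times an x86 scale of 1, 2, or 4 (e.g. 6*rs_c = MEM(base,R13,2)). A hedged C model of the resulting address list; it assumes a byte stride of 8*rs_c per row (how R12 ends up scaled happens outside the hunks shown):

    #include <stdint.h>
    #include <stdio.h>

    /* The 24 cache lines both paths request: one per row of the C micro-tile.
       base[] mirrors RCX/RDX/RDI; within each group of 8 rows the offsets use
       only rs_c*{1,3,5,7} combined with the hardware scales 1, 2 and 4. */
    static void model_prefetch_c(uintptr_t c, long rs_c)
    {
        uintptr_t base[3] = { c, c + 8*8*(uintptr_t)rs_c, c + 16*8*(uintptr_t)rs_c };
        for (int g = 0; g < 3; g++)
            for (int r = 0; r < 8; r++)
                printf("C row %2d: %#llx\n", 8*g + r,
                       (unsigned long long)(base[g] + 8*(uintptr_t)(r*rs_c)));
    }

    int main(void)
    {
        double C[24][8];                 /* stand-in tile: rs_c = 8 doubles */
        model_prefetch_c((uintptr_t)&C[0][0], 8);
        return 0;
    }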