This version gets ~26GF on one core.

This commit is contained in:
Devin Matthews
2016-07-27 11:44:54 -05:00
parent a7d8ca97b8
commit 2c9de740ed

View File

@@ -39,12 +39,15 @@
#define UNROLL_K 32
#define SCATTER_PREFETCH_AB 0
#define SCATTER_PREFETCH_C 1
#define PREFETCH_A_L2 0
#define PREFETCH_B_L2 0
#define L2_PREFETCH_DIST 16
#define L2_PREFETCH_DIST 64
#define A_L1_PREFETCH_DIST 4
#define B_L1_PREFETCH_DIST 2
#define A_L1_PREFETCH_DIST 10
#define B_L1_PREFETCH_DIST 30
#define C_MIN_L2_ITERS 40 //C is not prefetched into L2 for k <= this
#define C_L1_ITERS 16 //number of iterations before the end to prefetch C into L1
@@ -97,11 +100,9 @@
VSCATTERDPD(MEM(RCX,YMM(2),8) MASK_K(1), ZMM(NUM)) \
ADD(RCX, RAX)
#define PREFETCH_A_L1(n) \
\
PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8)) \
PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+64)) \
PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+128))
#define PREFETCH_A_L1_1(n) PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8))
#define PREFETCH_A_L1_2(n) PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+64))
#define PREFETCH_A_L1_3(n) PREFETCH(0, MEM(RAX,(A_L1_PREFETCH_DIST+n)*24*8+128))
#if PREFETCH_A_L2
#undef PREFETCH_A_L2
@@ -129,6 +130,36 @@
#define PREFETCH_B_L2(...)
#endif
#if SCATTER_PREFETCH_AB
#undef SCATTER_PREFETCH_AB
#undef PREFETCH_A_L1_1
#undef PREFETCH_A_L1_2
#undef PREFETCH_A_L1_3
#undef PREFETCH_B_L1
#define SCATTER_PREFETCH_AB(n) \
\
KXNORW(K(1), K(0), K(0)) \
VGATHERPFDPS(0, MEM(RAX,ZMM(4),8,((3*n )*16+3*A_L1_PREFETCH_DIST)*64) MASK_K(1)) \
KXNORW(K(2), K(0), K(0)) \
VGATHERPFDPS(0, MEM(RAX,ZMM(4),8,((3*n+1)*16+3*A_L1_PREFETCH_DIST)*64) MASK_K(2)) \
KXNORW(K(3), K(0), K(0)) \
VGATHERPFDPS(0, MEM(RAX,ZMM(4),8,((3*n+2)*16+3*A_L1_PREFETCH_DIST)*64) MASK_K(3)) \
KXNORW(K(4), K(0), K(0)) \
VGATHERPFDPS(0, MEM(RBX,ZMM(4),8,( n *16+ B_L1_PREFETCH_DIST)*64) MASK_K(4))
#define PREFETCH_A_L1_1(...)
#define PREFETCH_A_L1_2(...)
#define PREFETCH_A_L1_3(...)
#define PREFETCH_B_L1(...)
#else
#undef SCATTER_PREFETCH_AB
#define SCATTER_PREFETCH_AB(...)
#endif
//
// n: index in unrolled loop (for prefetching offsets)
//
@@ -139,37 +170,38 @@
//
#define SUBITER(n,a,b,...) \
\
VMOVAPD(ZMM(0), MEM(RBX,(n+1-1)*64)) \
\
PREFETCH_A_L1(n) \
PREFETCH_B_L1(n) \
PREFETCH_A_L2(n) \
PREFETCH_B_L2(n) \
\
VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 0)*8)) \
VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 1)*8)) \
VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 2)*8)) \
VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 3)*8)) \
VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 4)*8)) \
VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 5)*8)) \
VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 6)*8)) \
VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 7)*8)) \
VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 8)*8)) \
VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 9)*8)) \
VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+10)*8)) \
VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+11)*8)) \
VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+12)*8)) \
VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+13)*8)) \
VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+14)*8)) \
VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+15)*8)) \
VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+16)*8)) \
VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+17)*8)) \
VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+18)*8)) \
VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+19)*8)) \
VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+20)*8)) \
VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+21)*8)) \
VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+22)*8)) \
VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(__VA_ARGS__,((n%%4)*24+23)*8))
VMOVAPD(ZMM(a), MEM(RBX,(n+1)*64)) \
VFMADD231PD(ZMM( 8), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 0)*8)) \
VFMADD231PD(ZMM( 9), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 1)*8)) \
VFMADD231PD(ZMM(10), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 2)*8)) \
VFMADD231PD(ZMM(11), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 3)*8)) \
PREFETCH_A_L1_1(n) \
VFMADD231PD(ZMM(12), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 4)*8)) \
VFMADD231PD(ZMM(13), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 5)*8)) \
VFMADD231PD(ZMM(14), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 6)*8)) \
VFMADD231PD(ZMM(15), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 7)*8)) \
PREFETCH_A_L1_2(n) \
VFMADD231PD(ZMM(16), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 8)*8)) \
VFMADD231PD(ZMM(17), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+ 9)*8)) \
VFMADD231PD(ZMM(18), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+10)*8)) \
VFMADD231PD(ZMM(19), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+11)*8)) \
PREFETCH_A_L1_3(n) \
VFMADD231PD(ZMM(20), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+12)*8)) \
VFMADD231PD(ZMM(21), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+13)*8)) \
VFMADD231PD(ZMM(22), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+14)*8)) \
VFMADD231PD(ZMM(23), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+15)*8)) \
PREFETCH_B_L1(n) \
VFMADD231PD(ZMM(24), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+16)*8)) \
VFMADD231PD(ZMM(25), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+17)*8)) \
VFMADD231PD(ZMM(26), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+18)*8)) \
VFMADD231PD(ZMM(27), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+19)*8)) \
PREFETCH_B_L2(n) \
VFMADD231PD(ZMM(28), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+20)*8)) \
VFMADD231PD(ZMM(29), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+21)*8)) \
VFMADD231PD(ZMM(30), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+22)*8)) \
VFMADD231PD(ZMM(31), ZMM(b), MEM_1TO8(__VA_ARGS__,((n%%4)*24+23)*8))
#define TAIL_LOOP(NAME) \
\
@@ -308,6 +340,8 @@
\
LOOP_ALIGN \
LABEL(NAME##_LOOP) \
\
SCATTER_PREFETCH_AB(0) \
\
SUBITER( 0,1,0,RAX) \
SUBITER( 1,0,1,RAX) \
@@ -335,6 +369,8 @@
\
TEST(RDI, RDI) \
JZ(NAME##_DONE) \
\
SCATTER_PREFETCH_AB(0) \
\
TAIL_LOOP(NAME##_TAIL) \
\
@@ -349,6 +385,8 @@
\
LOOP_ALIGN \
LABEL(NAME##_LOOP) \
\
SCATTER_PREFETCH_AB(0) \
\
SUBITER( 0,1,0,RAX) \
SUBITER( 1,0,1,RAX) \
@@ -366,6 +404,9 @@
SUBITER(13,0,1,RAX,R9,1) \
SUBITER(14,1,0,RAX,R9,1) \
SUBITER(15,0,1,RAX,R9,1) \
\
SCATTER_PREFETCH_AB(1) \
\
SUBITER(16,1,0,RAX,R8,4) \
SUBITER(17,0,1,RAX,R8,4) \
SUBITER(18,1,0,RAX,R8,4) \
@@ -392,6 +433,9 @@
\
TEST(RDI, RDI) \
JZ(NAME##_DONE) \
\
SCATTER_PREFETCH_AB(0) \
SCATTER_PREFETCH_AB(1) \
\
TAIL_LOOP(NAME##_TAIL) \
\
@@ -445,19 +489,29 @@ void bli_dgemm_opt_24x8(
VMOVAPS(ZMM( 9), ZMM(8))
VMOVAPS(ZMM(10), ZMM(8)) MOV(RSI, VAR(k)) //loop index
VMOVAPS(ZMM(11), ZMM(8)) MOV(RAX, VAR(a)) //load address of a
VMOVAPS(ZMM(12), ZMM(8))
VMOVAPS(ZMM(13), ZMM(8)) MOV(RBX, VAR(b)) //load address of b
VMOVAPS(ZMM(12), ZMM(8)) MOV(RBX, VAR(b)) //load address of b
VMOVAPS(ZMM(13), ZMM(8)) MOV(RCX, VAR(c)) //load address of c
VMOVAPS(ZMM(14), ZMM(8)) VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b
VMOVAPS(ZMM(15), ZMM(8))
VMOVAPS(ZMM(16), ZMM(8)) MOV(RCX, VAR(c)) //load address of c
VMOVAPS(ZMM(17), ZMM(8)) //set up indexing information for prefetching C
VMOVAPS(ZMM(18), ZMM(8)) MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(rs_c))
VMOVAPS(ZMM(20), ZMM(8)) VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15)
VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15)
VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23)
VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23)
VMOVAPS(ZMM(24), ZMM(8))
VMOVAPS(ZMM(15), ZMM(8)) MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(16), ZMM(8)) VMOVAPS(ZMM(4), MEM(RDI))
#if SCATTER_PREFETCH_C
VMOVAPS(ZMM(17), ZMM(8))
VMOVAPS(ZMM(18), ZMM(8))
VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(5), VAR(rs_c))
VMOVAPS(ZMM(20), ZMM(8))
VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(4), ZMM(5))
VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64))
VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(5))
#else
VMOVAPS(ZMM(17), ZMM(8)) MOV(R12, VAR(rs_c))
VMOVAPS(ZMM(18), ZMM(8)) LEA(R13, MEM(R12,R12,2))
VMOVAPS(ZMM(19), ZMM(8)) LEA(R14, MEM(R12,R12,4))
VMOVAPS(ZMM(20), ZMM(8)) LEA(R15, MEM(R13,R12,4))
VMOVAPS(ZMM(21), ZMM(8)) LEA(RDX, MEM(RCX,R12,8))
VMOVAPS(ZMM(22), ZMM(8)) LEA(RDI, MEM(RDX,R12,8))
VMOVAPS(ZMM(23), ZMM(8))
#endif
VMOVAPS(ZMM(24), ZMM(8)) VPSLLD(ZMM(4), ZMM(4), IMM(3))
VMOVAPS(ZMM(25), ZMM(8)) MOV(R8, IMM(4*24*8)) //offset for 4 iterations
VMOVAPS(ZMM(26), ZMM(8)) LEA(R9, MEM(R8,R8,2)) //*3
VMOVAPS(ZMM(27), ZMM(8)) LEA(R10, MEM(R8,R8,4)) //*5
@@ -479,10 +533,37 @@ void bli_dgemm_opt_24x8(
SUB(RSI, IMM(0+C_L1_ITERS))
//prefetch C into L2
#if SCATTER_PREFETCH_C
KXNORW(K(1), K(0), K(0))
KXNORW(K(2), K(0), K(0))
VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1))
VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2))
#else
PREFETCH(1, MEM(RCX ))
PREFETCH(1, MEM(RCX,R12,1))
PREFETCH(1, MEM(RCX,R12,2))
PREFETCH(1, MEM(RCX,R13,1))
PREFETCH(1, MEM(RCX,R12,4))
PREFETCH(1, MEM(RCX,R14,1))
PREFETCH(1, MEM(RCX,R13,2))
PREFETCH(1, MEM(RCX,R15,1))
PREFETCH(1, MEM(RDX ))
PREFETCH(1, MEM(RDX,R12,1))
PREFETCH(1, MEM(RDX,R12,2))
PREFETCH(1, MEM(RDX,R13,1))
PREFETCH(1, MEM(RDX,R12,4))
PREFETCH(1, MEM(RDX,R14,1))
PREFETCH(1, MEM(RDX,R13,2))
PREFETCH(1, MEM(RDX,R15,1))
PREFETCH(1, MEM(RDI ))
PREFETCH(1, MEM(RDI,R12,1))
PREFETCH(1, MEM(RDI,R12,2))
PREFETCH(1, MEM(RDI,R13,1))
PREFETCH(1, MEM(RDI,R12,4))
PREFETCH(1, MEM(RDI,R14,1))
PREFETCH(1, MEM(RDI,R13,2))
PREFETCH(1, MEM(RDI,R15,1))
#endif
MAIN_LOOP_L2
@@ -491,10 +572,37 @@ void bli_dgemm_opt_24x8(
LABEL(PREFETCH_C_L1)
//prefetch C into L1
#if SCATTER_PREFETCH_C
KXNORW(K(1), K(0), K(0))
KXNORW(K(2), K(0), K(0))
VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1))
VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2))
#else
PREFETCH(0, MEM(RCX ))
PREFETCH(0, MEM(RCX,R12,1))
PREFETCH(0, MEM(RCX,R12,2))
PREFETCH(0, MEM(RCX,R13,1))
PREFETCH(0, MEM(RCX,R12,4))
PREFETCH(0, MEM(RCX,R14,1))
PREFETCH(0, MEM(RCX,R13,2))
PREFETCH(0, MEM(RCX,R15,1))
PREFETCH(0, MEM(RDX ))
PREFETCH(0, MEM(RDX,R12,1))
PREFETCH(0, MEM(RDX,R12,2))
PREFETCH(0, MEM(RDX,R13,1))
PREFETCH(0, MEM(RDX,R12,4))
PREFETCH(0, MEM(RDX,R14,1))
PREFETCH(0, MEM(RDX,R13,2))
PREFETCH(0, MEM(RDX,R15,1))
PREFETCH(0, MEM(RDI ))
PREFETCH(0, MEM(RDI,R12,1))
PREFETCH(0, MEM(RDI,R12,2))
PREFETCH(0, MEM(RDI,R13,1))
PREFETCH(0, MEM(RDI,R12,4))
PREFETCH(0, MEM(RDI,R14,1))
PREFETCH(0, MEM(RDI,R13,2))
PREFETCH(0, MEM(RDI,R15,1))
#endif
MAIN_LOOP_L1