diff --git a/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c b/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c index 5846f4a82..d1cdcbc65 100644 --- a/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c +++ b/kernels/x86_64/knl/3/bli_dgemm_opt_24x8.c @@ -37,7 +37,7 @@ #include "bli_avx512_macros.h" -#define UNROLL_K 1 +#define UNROLL_K 32 #define PREFETCH_A_L2 0 #define PREFETCH_B_L2 0 @@ -225,6 +225,8 @@ \ TEST(RDI, RDI) \ JZ(NAME##_DONE) \ +\ + LABEL(NAME##_TAIL) \ \ SUBITER(0,1,0,RAX) \ \ @@ -399,8 +401,7 @@ #define LOOP_K(M,K,NAME) LOOP_K_(M,K)(NAME) #define MAIN_LOOP_L2 LOOP_K(MAIN_LOOP_,UNROLL_K,MAIN_LOOP_L2) -//#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,C_L1_ITERS,MAIN_LOOP_L1) -#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,1,MAIN_LOOP_L1) +#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,C_L1_ITERS,MAIN_LOOP_L1) //This is an array used for the scatter/gather instructions. extern int32_t offsets[24]; @@ -418,22 +419,6 @@ void bli_dgemm_opt_24x8( cntx_t* restrict cntx ) { - /* - for (dim_t i = 0;i < 24;i++) - { - for (dim_t j = 0;j < 8;j++) - { - c[i*rs_c+j*cs_c] *= *beta; - for (dim_t p = 0;p < k;p++) - { - c[i*rs_c+j*cs_c] += (*alpha)*a[i+p*24]*b[j+p*8]; - } - } - } - - return; - */ - const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); @@ -462,21 +447,21 @@ void bli_dgemm_opt_24x8( VMOVAPS(ZMM(11), ZMM(8)) MOV(RAX, VAR(a)) //load address of a VMOVAPS(ZMM(12), ZMM(8)) VMOVAPS(ZMM(13), ZMM(8)) MOV(RBX, VAR(b)) //load address of b - VMOVAPS(ZMM(14), ZMM(8)) //VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b + VMOVAPS(ZMM(14), ZMM(8)) VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b VMOVAPS(ZMM(15), ZMM(8)) VMOVAPS(ZMM(16), ZMM(8)) MOV(RCX, VAR(c)) //load address of c VMOVAPS(ZMM(17), ZMM(8)) //set up indexing information for prefetching C - VMOVAPS(ZMM(18), ZMM(8)) //MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(19), ZMM(8)) //VBROADCASTSS(ZMM(4), VAR(rs_c)) - VMOVAPS(ZMM(20), ZMM(8)) //VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15) - VMOVAPS(ZMM(21), ZMM(8)) //VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15) - VMOVAPS(ZMM(22), ZMM(8)) //VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23) - VMOVAPS(ZMM(23), ZMM(8)) //VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23) + VMOVAPS(ZMM(18), ZMM(8)) MOV(RDI, VAR(offsetPtr)) + VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(rs_c)) + VMOVAPS(ZMM(20), ZMM(8)) VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15) + VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15) + VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23) + VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23) VMOVAPS(ZMM(24), ZMM(8)) - VMOVAPS(ZMM(25), ZMM(8)) //MOV(R8, IMM(4*24*8)) //offset for 4 iterations - VMOVAPS(ZMM(26), ZMM(8)) //LEA(R9, MEM(R8,R8,2)) //*3 - VMOVAPS(ZMM(27), ZMM(8)) //LEA(R10, MEM(R8,R8,4)) //*5 - VMOVAPS(ZMM(28), ZMM(8)) //LEA(R11, MEM(R9,R8,4)) //*7 + VMOVAPS(ZMM(25), ZMM(8)) MOV(R8, IMM(4*24*8)) //offset for 4 iterations + VMOVAPS(ZMM(26), ZMM(8)) LEA(R9, MEM(R8,R8,2)) //*3 + VMOVAPS(ZMM(27), ZMM(8)) LEA(R10, MEM(R8,R8,4)) //*5 + VMOVAPS(ZMM(28), ZMM(8)) LEA(R11, MEM(R9,R8,4)) //*7 VMOVAPS(ZMM(29), ZMM(8)) VMOVAPS(ZMM(30), ZMM(8)) VMOVAPS(ZMM(31), ZMM(8)) @@ -488,66 +473,30 @@ void bli_dgemm_opt_24x8( #endif //need 0+... to satisfy preprocessor - //CMP(RSI, IMM(0+C_MIN_L2_ITERS)) - //JLE(PREFETCH_C_L1) + CMP(RSI, IMM(0+C_MIN_L2_ITERS)) + JLE(PREFETCH_C_L1) - //SUB(RSI, IMM(0+C_L1_ITERS)) + SUB(RSI, IMM(0+C_L1_ITERS)) //prefetch C into L2 - //KXNORW(K(1), K(0), K(0)) - //KXNORW(K(2), K(0), K(0)) - //VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1)) - //VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2)) + KXNORW(K(1), K(0), K(0)) + KXNORW(K(2), K(0), K(0)) + VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1)) + VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2)) - //MAIN_LOOP_L2 + MAIN_LOOP_L2 - //MOV(RSI, IMM(0+C_L1_ITERS)) + MOV(RSI, IMM(0+C_L1_ITERS)) - //LABEL(PREFETCH_C_L1) + LABEL(PREFETCH_C_L1) //prefetch C into L1 - //KXNORW(K(1), K(0), K(0)) - //KXNORW(K(2), K(0), K(0)) - //VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1)) - //VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2)) + KXNORW(K(1), K(0), K(0)) + KXNORW(K(2), K(0), K(0)) + VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1)) + VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2)) - //MAIN_LOOP_L1 - - LABEL(MAINLOOP) - - VMOVAPD(ZMM(0), MEM(RBX)) - - VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RAX, 0*8)) - VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RAX, 1*8)) - VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RAX, 2*8)) - VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RAX, 3*8)) - VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RAX, 4*8)) - VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RAX, 5*8)) - VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RAX, 6*8)) - VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RAX, 7*8)) - VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RAX, 8*8)) - VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RAX, 9*8)) - VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RAX,10*8)) - VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RAX,11*8)) - VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RAX,12*8)) - VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RAX,13*8)) - VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RAX,14*8)) - VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RAX,15*8)) - VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RAX,16*8)) - VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RAX,17*8)) - VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RAX,18*8)) - VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RAX,19*8)) - VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RAX,20*8)) - VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RAX,21*8)) - VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RAX,22*8)) - VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RAX,23*8)) - - ADD(RAX, IMM(24*8)) - ADD(RBX, IMM( 8*8)) - - SUB(RSI, IMM(1)) - - JNZ(MAINLOOP) + MAIN_LOOP_L1 LABEL(POSTACCUM) @@ -564,15 +513,15 @@ void bli_dgemm_opt_24x8( // Check if C is row stride. If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) + LEA(RAX, MEM(,RAX,8)) MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) CMP(RBX, IMM(1)) - //JNE(SCATTEREDUPDATE) - JMP(SCATTEREDUPDATE) + JNE(SCATTEREDUPDATE) VMOVQ(RDX, XMM(1)) SAL1(RDX) //shift out sign bit - //JZ(COLSTORBZ) + JZ(COLSTORBZ) UPDATE_C_FOUR_ROWS( 8, 9,10,11) UPDATE_C_FOUR_ROWS(12,13,14,15) @@ -599,12 +548,12 @@ void bli_dgemm_opt_24x8( MOV(RDI, VAR(offsetPtr)) VMOVAPS(ZMM(2), MEM(RDI)) /* Note that this ignores the upper 32 bits in cs_c */ - VBROADCASTSS(ZMM(3), VAR(cs_c)) + VPBROADCASTD(ZMM(3), EBX) VPMULLD(ZMM(2), ZMM(3), ZMM(2)) VMOVQ(RDX, XMM(1)) SAL1(RDX) //shift out sign bit - //JZ(SCATTERBZ) + JZ(SCATTERBZ) UPDATE_C_ROW_SCATTERED( 8) UPDATE_C_ROW_SCATTERED( 9)