All fixed.

This commit is contained in:
Devin Matthews
2016-07-25 15:15:13 -05:00
parent 963d0393b0
commit a7d8ca97b8

View File

@@ -37,7 +37,7 @@
#include "bli_avx512_macros.h"
#define UNROLL_K 1
#define UNROLL_K 32
#define PREFETCH_A_L2 0
#define PREFETCH_B_L2 0
@@ -225,6 +225,8 @@
\
TEST(RDI, RDI) \
JZ(NAME##_DONE) \
\
LABEL(NAME##_TAIL) \
\
SUBITER(0,1,0,RAX) \
\
@@ -399,8 +401,7 @@
#define LOOP_K(M,K,NAME) LOOP_K_(M,K)(NAME)
#define MAIN_LOOP_L2 LOOP_K(MAIN_LOOP_,UNROLL_K,MAIN_LOOP_L2)
//#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,C_L1_ITERS,MAIN_LOOP_L1)
#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,1,MAIN_LOOP_L1)
#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,C_L1_ITERS,MAIN_LOOP_L1)
//This is an array used for the scatter/gather instructions.
extern int32_t offsets[24];
@@ -418,22 +419,6 @@ void bli_dgemm_opt_24x8(
cntx_t* restrict cntx
)
{
/*
for (dim_t i = 0;i < 24;i++)
{
for (dim_t j = 0;j < 8;j++)
{
c[i*rs_c+j*cs_c] *= *beta;
for (dim_t p = 0;p < k;p++)
{
c[i*rs_c+j*cs_c] += (*alpha)*a[i+p*24]*b[j+p*8];
}
}
}
return;
*/
const double * a_next = bli_auxinfo_next_a( data );
const double * b_next = bli_auxinfo_next_b( data );
@@ -462,21 +447,21 @@ void bli_dgemm_opt_24x8(
VMOVAPS(ZMM(11), ZMM(8)) MOV(RAX, VAR(a)) //load address of a
VMOVAPS(ZMM(12), ZMM(8))
VMOVAPS(ZMM(13), ZMM(8)) MOV(RBX, VAR(b)) //load address of b
VMOVAPS(ZMM(14), ZMM(8)) //VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b
VMOVAPS(ZMM(14), ZMM(8)) VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b
VMOVAPS(ZMM(15), ZMM(8))
VMOVAPS(ZMM(16), ZMM(8)) MOV(RCX, VAR(c)) //load address of c
VMOVAPS(ZMM(17), ZMM(8)) //set up indexing information for prefetching C
VMOVAPS(ZMM(18), ZMM(8)) //MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(19), ZMM(8)) //VBROADCASTSS(ZMM(4), VAR(rs_c))
VMOVAPS(ZMM(20), ZMM(8)) //VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15)
VMOVAPS(ZMM(21), ZMM(8)) //VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15)
VMOVAPS(ZMM(22), ZMM(8)) //VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23)
VMOVAPS(ZMM(23), ZMM(8)) //VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23)
VMOVAPS(ZMM(18), ZMM(8)) MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(rs_c))
VMOVAPS(ZMM(20), ZMM(8)) VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15)
VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15)
VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23)
VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23)
VMOVAPS(ZMM(24), ZMM(8))
VMOVAPS(ZMM(25), ZMM(8)) //MOV(R8, IMM(4*24*8)) //offset for 4 iterations
VMOVAPS(ZMM(26), ZMM(8)) //LEA(R9, MEM(R8,R8,2)) //*3
VMOVAPS(ZMM(27), ZMM(8)) //LEA(R10, MEM(R8,R8,4)) //*5
VMOVAPS(ZMM(28), ZMM(8)) //LEA(R11, MEM(R9,R8,4)) //*7
VMOVAPS(ZMM(25), ZMM(8)) MOV(R8, IMM(4*24*8)) //offset for 4 iterations
VMOVAPS(ZMM(26), ZMM(8)) LEA(R9, MEM(R8,R8,2)) //*3
VMOVAPS(ZMM(27), ZMM(8)) LEA(R10, MEM(R8,R8,4)) //*5
VMOVAPS(ZMM(28), ZMM(8)) LEA(R11, MEM(R9,R8,4)) //*7
VMOVAPS(ZMM(29), ZMM(8))
VMOVAPS(ZMM(30), ZMM(8))
VMOVAPS(ZMM(31), ZMM(8))
@@ -488,66 +473,30 @@ void bli_dgemm_opt_24x8(
#endif
//need 0+... to satisfy preprocessor
//CMP(RSI, IMM(0+C_MIN_L2_ITERS))
//JLE(PREFETCH_C_L1)
CMP(RSI, IMM(0+C_MIN_L2_ITERS))
JLE(PREFETCH_C_L1)
//SUB(RSI, IMM(0+C_L1_ITERS))
SUB(RSI, IMM(0+C_L1_ITERS))
//prefetch C into L2
//KXNORW(K(1), K(0), K(0))
//KXNORW(K(2), K(0), K(0))
//VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1))
//VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2))
KXNORW(K(1), K(0), K(0))
KXNORW(K(2), K(0), K(0))
VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1))
VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2))
//MAIN_LOOP_L2
MAIN_LOOP_L2
//MOV(RSI, IMM(0+C_L1_ITERS))
MOV(RSI, IMM(0+C_L1_ITERS))
//LABEL(PREFETCH_C_L1)
LABEL(PREFETCH_C_L1)
//prefetch C into L1
//KXNORW(K(1), K(0), K(0))
//KXNORW(K(2), K(0), K(0))
//VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1))
//VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2))
KXNORW(K(1), K(0), K(0))
KXNORW(K(2), K(0), K(0))
VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1))
VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2))
//MAIN_LOOP_L1
LABEL(MAINLOOP)
VMOVAPD(ZMM(0), MEM(RBX))
VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RAX, 0*8))
VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RAX, 1*8))
VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RAX, 2*8))
VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RAX, 3*8))
VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RAX, 4*8))
VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RAX, 5*8))
VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RAX, 6*8))
VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RAX, 7*8))
VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RAX, 8*8))
VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RAX, 9*8))
VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RAX,10*8))
VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RAX,11*8))
VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RAX,12*8))
VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RAX,13*8))
VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RAX,14*8))
VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RAX,15*8))
VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RAX,16*8))
VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RAX,17*8))
VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RAX,18*8))
VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RAX,19*8))
VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RAX,20*8))
VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RAX,21*8))
VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RAX,22*8))
VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RAX,23*8))
ADD(RAX, IMM(24*8))
ADD(RBX, IMM( 8*8))
SUB(RSI, IMM(1))
JNZ(MAINLOOP)
MAIN_LOOP_L1
LABEL(POSTACCUM)
@@ -564,15 +513,15 @@ void bli_dgemm_opt_24x8(
// Check if C has unit column stride (cs_c == 1). If not, jump to the slow scattered update
MOV(RAX, VAR(rs_c))
LEA(RAX, MEM(,RAX,8))
MOV(RBX, VAR(cs_c))
LEA(RDI, MEM(RAX,RAX,2))
CMP(RBX, IMM(1))
//JNE(SCATTEREDUPDATE)
JMP(SCATTEREDUPDATE)
JNE(SCATTEREDUPDATE)
VMOVQ(RDX, XMM(1))
SAL1(RDX) //shift out sign bit
//JZ(COLSTORBZ)
JZ(COLSTORBZ)
UPDATE_C_FOUR_ROWS( 8, 9,10,11)
UPDATE_C_FOUR_ROWS(12,13,14,15)
@@ -599,12 +548,12 @@ void bli_dgemm_opt_24x8(
MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(2), MEM(RDI))
/* Note that only the low 32 bits of cs_c are used; the upper 32 bits are ignored */
VBROADCASTSS(ZMM(3), VAR(cs_c))
VPBROADCASTD(ZMM(3), EBX)
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
VMOVQ(RDX, XMM(1))
SAL1(RDX) //shift out sign bit
//JZ(SCATTERBZ)
JZ(SCATTERBZ)
UPDATE_C_ROW_SCATTERED( 8)
UPDATE_C_ROW_SCATTERED( 9)