mirror of
https://github.com/amd/blis.git
synced 2026-05-11 09:39:59 +00:00
All fixed.
This commit is contained in:
@@ -37,7 +37,7 @@
|
||||
|
||||
#include "bli_avx512_macros.h"
|
||||
|
||||
#define UNROLL_K 1
|
||||
#define UNROLL_K 32
|
||||
|
||||
#define PREFETCH_A_L2 0
|
||||
#define PREFETCH_B_L2 0
|
||||
@@ -225,6 +225,8 @@
|
||||
\
|
||||
TEST(RDI, RDI) \
|
||||
JZ(NAME##_DONE) \
|
||||
\
|
||||
LABEL(NAME##_TAIL) \
|
||||
\
|
||||
SUBITER(0,1,0,RAX) \
|
||||
\
|
||||
@@ -399,8 +401,7 @@
|
||||
#define LOOP_K(M,K,NAME) LOOP_K_(M,K)(NAME)
|
||||
|
||||
#define MAIN_LOOP_L2 LOOP_K(MAIN_LOOP_,UNROLL_K,MAIN_LOOP_L2)
|
||||
//#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,C_L1_ITERS,MAIN_LOOP_L1)
|
||||
#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,1,MAIN_LOOP_L1)
|
||||
#define MAIN_LOOP_L1 LOOP_K(MAIN_LOOP_,C_L1_ITERS,MAIN_LOOP_L1)
|
||||
|
||||
//This is an array used for the scatter/gather instructions.
|
||||
extern int32_t offsets[24];
|
||||
@@ -418,22 +419,6 @@ void bli_dgemm_opt_24x8(
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
/*
|
||||
for (dim_t i = 0;i < 24;i++)
|
||||
{
|
||||
for (dim_t j = 0;j < 8;j++)
|
||||
{
|
||||
c[i*rs_c+j*cs_c] *= *beta;
|
||||
for (dim_t p = 0;p < k;p++)
|
||||
{
|
||||
c[i*rs_c+j*cs_c] += (*alpha)*a[i+p*24]*b[j+p*8];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
*/
|
||||
|
||||
const double * a_next = bli_auxinfo_next_a( data );
|
||||
const double * b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
@@ -462,21 +447,21 @@ void bli_dgemm_opt_24x8(
|
||||
VMOVAPS(ZMM(11), ZMM(8)) MOV(RAX, VAR(a)) //load address of a
|
||||
VMOVAPS(ZMM(12), ZMM(8))
|
||||
VMOVAPS(ZMM(13), ZMM(8)) MOV(RBX, VAR(b)) //load address of b
|
||||
VMOVAPS(ZMM(14), ZMM(8)) //VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b
|
||||
VMOVAPS(ZMM(14), ZMM(8)) VMOVAPD(ZMM(0), MEM(RBX)) //pre-load b
|
||||
VMOVAPS(ZMM(15), ZMM(8))
|
||||
VMOVAPS(ZMM(16), ZMM(8)) MOV(RCX, VAR(c)) //load address of c
|
||||
VMOVAPS(ZMM(17), ZMM(8)) //set up indexing information for prefetching C
|
||||
VMOVAPS(ZMM(18), ZMM(8)) //MOV(RDI, VAR(offsetPtr))
|
||||
VMOVAPS(ZMM(19), ZMM(8)) //VBROADCASTSS(ZMM(4), VAR(rs_c))
|
||||
VMOVAPS(ZMM(20), ZMM(8)) //VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15)
|
||||
VMOVAPS(ZMM(21), ZMM(8)) //VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15)
|
||||
VMOVAPS(ZMM(22), ZMM(8)) //VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23)
|
||||
VMOVAPS(ZMM(23), ZMM(8)) //VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23)
|
||||
VMOVAPS(ZMM(18), ZMM(8)) MOV(RDI, VAR(offsetPtr))
|
||||
VMOVAPS(ZMM(19), ZMM(8)) VBROADCASTSS(ZMM(4), VAR(rs_c))
|
||||
VMOVAPS(ZMM(20), ZMM(8)) VMOVAPS(ZMM(2), MEM(RDI)) //at this point zmm2 contains (0...15)
|
||||
VMOVAPS(ZMM(21), ZMM(8)) VPMULLD(ZMM(2), ZMM(2), ZMM(4)) //and now zmm2 contains (rs_c*0...15)
|
||||
VMOVAPS(ZMM(22), ZMM(8)) VMOVAPS(YMM(3), MEM(RDI,64)) //at this point ymm3 contains (16...23)
|
||||
VMOVAPS(ZMM(23), ZMM(8)) VPMULLD(YMM(3), YMM(3), YMM(4)) //and now ymm3 contains (rs_c*16...23)
|
||||
VMOVAPS(ZMM(24), ZMM(8))
|
||||
VMOVAPS(ZMM(25), ZMM(8)) //MOV(R8, IMM(4*24*8)) //offset for 4 iterations
|
||||
VMOVAPS(ZMM(26), ZMM(8)) //LEA(R9, MEM(R8,R8,2)) //*3
|
||||
VMOVAPS(ZMM(27), ZMM(8)) //LEA(R10, MEM(R8,R8,4)) //*5
|
||||
VMOVAPS(ZMM(28), ZMM(8)) //LEA(R11, MEM(R9,R8,4)) //*7
|
||||
VMOVAPS(ZMM(25), ZMM(8)) MOV(R8, IMM(4*24*8)) //offset for 4 iterations
|
||||
VMOVAPS(ZMM(26), ZMM(8)) LEA(R9, MEM(R8,R8,2)) //*3
|
||||
VMOVAPS(ZMM(27), ZMM(8)) LEA(R10, MEM(R8,R8,4)) //*5
|
||||
VMOVAPS(ZMM(28), ZMM(8)) LEA(R11, MEM(R9,R8,4)) //*7
|
||||
VMOVAPS(ZMM(29), ZMM(8))
|
||||
VMOVAPS(ZMM(30), ZMM(8))
|
||||
VMOVAPS(ZMM(31), ZMM(8))
|
||||
@@ -488,66 +473,30 @@ void bli_dgemm_opt_24x8(
|
||||
#endif
|
||||
|
||||
//need 0+... to satisfy preprocessor
|
||||
//CMP(RSI, IMM(0+C_MIN_L2_ITERS))
|
||||
//JLE(PREFETCH_C_L1)
|
||||
CMP(RSI, IMM(0+C_MIN_L2_ITERS))
|
||||
JLE(PREFETCH_C_L1)
|
||||
|
||||
//SUB(RSI, IMM(0+C_L1_ITERS))
|
||||
SUB(RSI, IMM(0+C_L1_ITERS))
|
||||
|
||||
//prefetch C into L2
|
||||
//KXNORW(K(1), K(0), K(0))
|
||||
//KXNORW(K(2), K(0), K(0))
|
||||
//VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1))
|
||||
//VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2))
|
||||
KXNORW(K(1), K(0), K(0))
|
||||
KXNORW(K(2), K(0), K(0))
|
||||
VSCATTERPFDPS(1, MEM(RCX,ZMM(2),8) MASK_K(1))
|
||||
VSCATTERPFDPD(1, MEM(RCX,YMM(3),8) MASK_K(2))
|
||||
|
||||
//MAIN_LOOP_L2
|
||||
MAIN_LOOP_L2
|
||||
|
||||
//MOV(RSI, IMM(0+C_L1_ITERS))
|
||||
MOV(RSI, IMM(0+C_L1_ITERS))
|
||||
|
||||
//LABEL(PREFETCH_C_L1)
|
||||
LABEL(PREFETCH_C_L1)
|
||||
|
||||
//prefetch C into L1
|
||||
//KXNORW(K(1), K(0), K(0))
|
||||
//KXNORW(K(2), K(0), K(0))
|
||||
//VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1))
|
||||
//VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2))
|
||||
KXNORW(K(1), K(0), K(0))
|
||||
KXNORW(K(2), K(0), K(0))
|
||||
VSCATTERPFDPS(0, MEM(RCX,ZMM(2),8) MASK_K(1))
|
||||
VSCATTERPFDPD(0, MEM(RCX,YMM(3),8) MASK_K(2))
|
||||
|
||||
//MAIN_LOOP_L1
|
||||
|
||||
LABEL(MAINLOOP)
|
||||
|
||||
VMOVAPD(ZMM(0), MEM(RBX))
|
||||
|
||||
VFMADD231PD(ZMM( 8), ZMM(0), MEM_1TO8(RAX, 0*8))
|
||||
VFMADD231PD(ZMM( 9), ZMM(0), MEM_1TO8(RAX, 1*8))
|
||||
VFMADD231PD(ZMM(10), ZMM(0), MEM_1TO8(RAX, 2*8))
|
||||
VFMADD231PD(ZMM(11), ZMM(0), MEM_1TO8(RAX, 3*8))
|
||||
VFMADD231PD(ZMM(12), ZMM(0), MEM_1TO8(RAX, 4*8))
|
||||
VFMADD231PD(ZMM(13), ZMM(0), MEM_1TO8(RAX, 5*8))
|
||||
VFMADD231PD(ZMM(14), ZMM(0), MEM_1TO8(RAX, 6*8))
|
||||
VFMADD231PD(ZMM(15), ZMM(0), MEM_1TO8(RAX, 7*8))
|
||||
VFMADD231PD(ZMM(16), ZMM(0), MEM_1TO8(RAX, 8*8))
|
||||
VFMADD231PD(ZMM(17), ZMM(0), MEM_1TO8(RAX, 9*8))
|
||||
VFMADD231PD(ZMM(18), ZMM(0), MEM_1TO8(RAX,10*8))
|
||||
VFMADD231PD(ZMM(19), ZMM(0), MEM_1TO8(RAX,11*8))
|
||||
VFMADD231PD(ZMM(20), ZMM(0), MEM_1TO8(RAX,12*8))
|
||||
VFMADD231PD(ZMM(21), ZMM(0), MEM_1TO8(RAX,13*8))
|
||||
VFMADD231PD(ZMM(22), ZMM(0), MEM_1TO8(RAX,14*8))
|
||||
VFMADD231PD(ZMM(23), ZMM(0), MEM_1TO8(RAX,15*8))
|
||||
VFMADD231PD(ZMM(24), ZMM(0), MEM_1TO8(RAX,16*8))
|
||||
VFMADD231PD(ZMM(25), ZMM(0), MEM_1TO8(RAX,17*8))
|
||||
VFMADD231PD(ZMM(26), ZMM(0), MEM_1TO8(RAX,18*8))
|
||||
VFMADD231PD(ZMM(27), ZMM(0), MEM_1TO8(RAX,19*8))
|
||||
VFMADD231PD(ZMM(28), ZMM(0), MEM_1TO8(RAX,20*8))
|
||||
VFMADD231PD(ZMM(29), ZMM(0), MEM_1TO8(RAX,21*8))
|
||||
VFMADD231PD(ZMM(30), ZMM(0), MEM_1TO8(RAX,22*8))
|
||||
VFMADD231PD(ZMM(31), ZMM(0), MEM_1TO8(RAX,23*8))
|
||||
|
||||
ADD(RAX, IMM(24*8))
|
||||
ADD(RBX, IMM( 8*8))
|
||||
|
||||
SUB(RSI, IMM(1))
|
||||
|
||||
JNZ(MAINLOOP)
|
||||
MAIN_LOOP_L1
|
||||
|
||||
LABEL(POSTACCUM)
|
||||
|
||||
@@ -564,15 +513,15 @@ void bli_dgemm_opt_24x8(
|
||||
|
||||
// Check whether C has unit column stride (cs_c == 1). If not, jump to the slow scattered update
|
||||
MOV(RAX, VAR(rs_c))
|
||||
LEA(RAX, MEM(,RAX,8))
|
||||
MOV(RBX, VAR(cs_c))
|
||||
LEA(RDI, MEM(RAX,RAX,2))
|
||||
CMP(RBX, IMM(1))
|
||||
//JNE(SCATTEREDUPDATE)
|
||||
JMP(SCATTEREDUPDATE)
|
||||
JNE(SCATTEREDUPDATE)
|
||||
|
||||
VMOVQ(RDX, XMM(1))
|
||||
SAL1(RDX) //shift out sign bit
|
||||
//JZ(COLSTORBZ)
|
||||
JZ(COLSTORBZ)
|
||||
|
||||
UPDATE_C_FOUR_ROWS( 8, 9,10,11)
|
||||
UPDATE_C_FOUR_ROWS(12,13,14,15)
|
||||
@@ -599,12 +548,12 @@ void bli_dgemm_opt_24x8(
|
||||
MOV(RDI, VAR(offsetPtr))
|
||||
VMOVAPS(ZMM(2), MEM(RDI))
|
||||
/* Note that this ignores the upper 32 bits in cs_c */
|
||||
VBROADCASTSS(ZMM(3), VAR(cs_c))
|
||||
VPBROADCASTD(ZMM(3), EBX)
|
||||
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
|
||||
|
||||
VMOVQ(RDX, XMM(1))
|
||||
SAL1(RDX) //shift out sign bit
|
||||
//JZ(SCATTERBZ)
|
||||
JZ(SCATTERBZ)
|
||||
|
||||
UPDATE_C_ROW_SCATTERED( 8)
|
||||
UPDATE_C_ROW_SCATTERED( 9)
|
||||
|
||||
Reference in New Issue
Block a user