mirror of
https://github.com/amd/blis.git
synced 2026-05-12 01:59:59 +00:00
Bugfixes, cleanup of sup dgemm ukernels.
Details:
- Fixed a few not-really-bugs:
- Previously, the d6x8m kernels were still prefetching the next upanel
of A using MR*rs_a instead of ps_a (same for prefetching of next
upanel of B in d6x8n kernels using NR*cs_b instead of ps_b). Given
that the upanels might be packed, using ps_a or ps_b is the correct
way to compute the prefetch address.
- Fixed an obscure bug in the rd_d6x8m kernel that, by dumb luck,
executed as intended even though it was based on a faulty pointer
management. Basically, in the rd_d6x8m kernel, the pointer for B
(stored in rdx) was loaded only once, outside of the jj loop, and in
the second iteration its new position was calculated by incrementing
rdx by the *absolute* offset (four columns), which happened to be the
same as the relative offset (also four columns) that was needed. It
worked only because that loop only executed twice. A similar issue
was fixed in the rd_d6x8n kernels.
- Various cleanups and additions, including:
- Factored out the loading of rs_c into rdi in rd_d6x8[mn] kernels so
that it is loaded only once outside of the loops rather than
multiple times inside the loops.
- Changed outer loop in rd kernels so that the jump/comparison and
loop bounds more closely mimic what you'd see in higher-level source
code. That is, something like:
for( i = 0; i < 6; i+=3 )
rather than something like:
for( i = 0; i <= 3; i+=3 )
- Switched row-based IO to use byte offsets instead of byte column
strides (e.g. via rsi register), which were known to be 8 anyway
since otherwise that conditional branch wouldn't have executed.
- Cleaned up and homogenized prefetching a bit.
- Updated the comments that show the before and after of the
in-register transpositions.
- Added comments to column-based IO cases to indicate which columns
are being accessed/updated.
- Added rbp register to clobber lists.
- Removed some dead (commented out) code.
- Fixed some copy-paste typos in comments in the rv_6x8n kernels.
- Cleaned up whitespace (including leading ws -> tabs).
- Moved edge case (non-milli) kernels to their own directory, d6x8,
and split them into separate files based on the "NR" value of the
kernels (Mx8, Mx4, Mx2, etc.).
- Moved config-specific reference Mx1 kernels into their own file
(e.g. bli_gemmsup_r_haswell_ref_dMx1.c) inside the d6x8 directory.
- Added rd_dMx1 assembly kernels, which seem marginally faster than
the corresponding reference kernels.
- Updated comments in ref_kernels/bli_cntx_ref.c and changed to using
the row-oriented reference kernels for all storage combos.
This commit is contained in:
@@ -65,23 +65,6 @@
|
||||
// Prototype reference microkernels.
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
|
||||
|
||||
#if 0
|
||||
// Define parameters and variables for edge case kernel map.
|
||||
#define NUM_MR 4
|
||||
#define NUM_NR 4
|
||||
#define FUNCPTR_T dgemmsup_ker_ft
|
||||
|
||||
static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
|
||||
static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
|
||||
static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
|
||||
{ /* 8 4 2 1 */
|
||||
/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 },
|
||||
/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref_3x1 },
|
||||
/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 },
|
||||
/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 }
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
(
|
||||
@@ -135,7 +118,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
}
|
||||
if ( 1 == n_left )
|
||||
{
|
||||
#if 0
|
||||
#if 0
|
||||
const dim_t nr_cur = 1;
|
||||
|
||||
bli_dgemmsup_r_haswell_ref
|
||||
@@ -144,14 +127,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
beta, cij, rs_c0, cs_c0, data, cntx
|
||||
);
|
||||
#else
|
||||
#else
|
||||
bli_dgemv_ex
|
||||
(
|
||||
BLIS_NO_TRANSPOSE, conjb, m0, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
|
||||
beta, cij, rs_c0, cntx, NULL
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -193,7 +176,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
//lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
|
||||
|
||||
mov(var(b), rdx) // load address of b.
|
||||
//mov(var(b), rdx) // load address of b.
|
||||
//mov(var(rs_b), r10) // load rs_b
|
||||
mov(var(cs_b), r11) // load cs_b
|
||||
//lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
|
||||
@@ -204,8 +187,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
|
||||
|
||||
//mov(var(c), r12) // load address of c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
|
||||
|
||||
@@ -222,10 +205,11 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
|
||||
|
||||
mov(var(a), r14) // load address of a
|
||||
mov(var(b), rdx) // load address of b
|
||||
mov(var(c), r12) // load address of c
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
|
||||
lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c;
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
@@ -266,14 +250,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
lea(mem(rdx), rbx) // rbx = b_jj;
|
||||
|
||||
|
||||
#if 0
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
#if 1
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
|
||||
#endif
|
||||
lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a
|
||||
lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
|
||||
|
||||
|
||||
|
||||
@@ -290,15 +274,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 1
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
|
||||
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -327,7 +311,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -354,15 +338,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
// ---------------------------------- iteration 2
|
||||
|
||||
#if 1
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
|
||||
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -391,7 +375,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -436,15 +420,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
label(.DLOOPKITER4) // EDGE LOOP (ymm)
|
||||
|
||||
#if 1
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
|
||||
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -493,7 +477,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
vmovsd(mem(rax, r8, 2), xmm2)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -527,8 +511,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
// ymm4 ymm7 ymm10 ymm13
|
||||
// ymm5 ymm8 ymm11 ymm14
|
||||
@@ -536,14 +518,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
|
||||
vhaddpd( ymm7, ymm4, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
|
||||
vaddpd( xmm0, xmm1, xmm0 )
|
||||
|
||||
vhaddpd( ymm13, ymm10, ymm2 )
|
||||
vextractf128(imm(1), ymm2, xmm1 )
|
||||
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
|
||||
|
||||
// xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7)
|
||||
// xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13)
|
||||
|
||||
vhaddpd( ymm8, ymm5, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -554,6 +537,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
|
||||
// xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8)
|
||||
// xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14)
|
||||
|
||||
|
||||
vhaddpd( ymm9, ymm6, ymm0 )
|
||||
@@ -565,15 +550,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
|
||||
// xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9)
|
||||
// xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15)
|
||||
|
||||
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
|
||||
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
|
||||
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
|
||||
|
||||
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
@@ -661,8 +645,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
|
||||
|
||||
add(imm(4), r15) // jj += 4;
|
||||
cmp(imm(4), r15) // compare jj to 4
|
||||
jle(.DLOOP3X4J) // if jj <= 4, jump to beginning
|
||||
cmp(imm(8), r15) // compare jj to 8
|
||||
jl(.DLOOP3X4J) // if jj < 8, jump to beginning
|
||||
// of jj loop; otherwise, loop ends.
|
||||
|
||||
|
||||
@@ -692,7 +676,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -803,8 +787,8 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
|
||||
|
||||
mov(var(c), r12) // load address of c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
|
||||
|
||||
@@ -813,13 +797,13 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
// rdx = rbx = b
|
||||
// r9 = m dim index ii
|
||||
// r15 = n dim index jj
|
||||
// r10 = unused
|
||||
|
||||
mov(var(m_iter), r9) // ii = m_iter;
|
||||
|
||||
label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ]
|
||||
|
||||
|
||||
|
||||
#if 0
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
#else
|
||||
@@ -846,14 +830,14 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
lea(mem(rdx), rbx) // rbx = b;
|
||||
|
||||
|
||||
#if 0
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
#if 1
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
|
||||
#endif
|
||||
lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a
|
||||
lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
|
||||
|
||||
|
||||
|
||||
@@ -870,15 +854,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 1
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
|
||||
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -907,7 +891,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -934,15 +918,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
// ---------------------------------- iteration 2
|
||||
|
||||
#if 1
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
|
||||
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -971,7 +955,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1014,11 +998,17 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
|
||||
|
||||
label(.DLOOPKITER4) // EDGE LOOP (ymm)
|
||||
|
||||
#if 1
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1067,7 +1057,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
vmovsd(mem(rax, r8, 2), xmm2)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1101,8 +1091,6 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
// ymm4 ymm7 ymm10 ymm13
|
||||
// ymm5 ymm8 ymm11 ymm14
|
||||
@@ -1110,14 +1098,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
|
||||
vhaddpd( ymm7, ymm4, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
|
||||
vaddpd( xmm0, xmm1, xmm0 )
|
||||
|
||||
vhaddpd( ymm13, ymm10, ymm2 )
|
||||
vextractf128(imm(1), ymm2, xmm1 )
|
||||
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
|
||||
|
||||
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
|
||||
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
|
||||
|
||||
vhaddpd( ymm8, ymm5, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -1128,7 +1117,8 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
|
||||
|
||||
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
|
||||
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
|
||||
|
||||
vhaddpd( ymm9, ymm6, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -1139,15 +1129,14 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
|
||||
// ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9)
|
||||
// ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15)
|
||||
|
||||
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
|
||||
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
|
||||
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
|
||||
|
||||
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
@@ -1259,7 +1248,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -1366,6 +1355,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
|
||||
|
||||
//lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
|
||||
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
|
||||
|
||||
|
||||
mov(var(c), r12) // load address of c
|
||||
@@ -1375,41 +1365,44 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
|
||||
|
||||
// r12 = rcx = c
|
||||
// r14 = rax = a
|
||||
// rdx = rbx = b
|
||||
// r9 = m dim index ii
|
||||
// r14 = rax = a
|
||||
// rdx = rbx = b
|
||||
// r9 = m dim index ii
|
||||
|
||||
mov(var(m_iter), r9) // ii = m_iter;
|
||||
|
||||
label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
|
||||
#if 0
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
#else
|
||||
// skylake can execute 3 vxorpd ipc with
|
||||
// a latency of 1 cycle, while vzeroall
|
||||
// has a latency of 12 cycles.
|
||||
vxorpd(ymm4, ymm4, ymm4)
|
||||
vxorpd(ymm5, ymm5, ymm5)
|
||||
vxorpd(ymm6, ymm6, ymm6)
|
||||
vxorpd(ymm7, ymm7, ymm7)
|
||||
vxorpd(ymm8, ymm8, ymm8)
|
||||
vxorpd(ymm9, ymm9, ymm9)
|
||||
vxorpd(ymm10, ymm10, ymm10)
|
||||
vxorpd(ymm11, ymm11, ymm11)
|
||||
vxorpd(ymm12, ymm12, ymm12)
|
||||
vxorpd(ymm13, ymm13, ymm13)
|
||||
vxorpd(ymm14, ymm14, ymm14)
|
||||
vxorpd(ymm15, ymm15, ymm15)
|
||||
// skylake can execute 3 vxorpd ipc with
|
||||
// a latency of 1 cycle, while vzeroall
|
||||
// has a latency of 12 cycles.
|
||||
vxorpd(ymm4, ymm4, ymm4)
|
||||
vxorpd(ymm5, ymm5, ymm5)
|
||||
vxorpd(ymm6, ymm6, ymm6)
|
||||
vxorpd(ymm7, ymm7, ymm7)
|
||||
vxorpd(ymm8, ymm8, ymm8)
|
||||
vxorpd(ymm9, ymm9, ymm9)
|
||||
vxorpd(ymm10, ymm10, ymm10)
|
||||
vxorpd(ymm11, ymm11, ymm11)
|
||||
vxorpd(ymm12, ymm12, ymm12)
|
||||
vxorpd(ymm13, ymm13, ymm13)
|
||||
vxorpd(ymm14, ymm14, ymm14)
|
||||
vxorpd(ymm15, ymm15, ymm15)
|
||||
#endif
|
||||
|
||||
|
||||
lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c;
|
||||
lea(mem(r14), rax) // rax = a + 6*ii*rs_a;
|
||||
lea(mem(rdx), rbx) // rbx = b;
|
||||
lea(mem(r12), rcx) // rcx = c_ii;
|
||||
lea(mem(r14), rax) // rax = a_ii;
|
||||
lea(mem(rdx), rbx) // rbx = b;
|
||||
|
||||
|
||||
#if 1
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
lea(mem(rcx, rdi, 2), r10) //
|
||||
lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c;
|
||||
prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c
|
||||
@@ -1418,6 +1411,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c
|
||||
prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c
|
||||
prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
@@ -1433,6 +1427,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rbx ), ymm0)
|
||||
vmovupd(mem(rbx, r11, 1), ymm1)
|
||||
add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
|
||||
@@ -1497,6 +1497,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
|
||||
// ---------------------------------- iteration 2
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rbx ), ymm0)
|
||||
vmovupd(mem(rbx, r11, 1), ymm1)
|
||||
add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
|
||||
@@ -1578,6 +1584,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
|
||||
|
||||
label(.DLOOPKITER4) // EDGE LOOP (ymm)
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
|
||||
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
|
||||
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rbx ), ymm0)
|
||||
vmovupd(mem(rbx, r11, 1), ymm1)
|
||||
@@ -1703,15 +1715,18 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm14 )
|
||||
|
||||
// xmm4 = sum(ymm4) sum(ymm5)
|
||||
// xmm6 = sum(ymm6) sum(ymm7)
|
||||
// xmm8 = sum(ymm8) sum(ymm9)
|
||||
// xmm10 = sum(ymm10) sum(ymm11)
|
||||
// xmm12 = sum(ymm12) sum(ymm13)
|
||||
// xmm14 = sum(ymm14) sum(ymm15)
|
||||
// xmm4[0:1] = sum(ymm4) sum(ymm5)
|
||||
// xmm6[0:1] = sum(ymm6) sum(ymm7)
|
||||
// xmm8[0:1] = sum(ymm8) sum(ymm9)
|
||||
// xmm10[0:1] = sum(ymm10) sum(ymm11)
|
||||
// xmm12[0:1] = sum(ymm12) sum(ymm13)
|
||||
// xmm14[0:1] = sum(ymm14) sum(ymm15)
|
||||
|
||||
|
||||
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
|
||||
@@ -1810,13 +1825,13 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
|
||||
|
||||
lea(mem(r12, rdi, 4), r12) //
|
||||
lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c
|
||||
lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c
|
||||
|
||||
lea(mem(r14, r8, 4), r14) //
|
||||
lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a
|
||||
lea(mem(r14, r8, 4), r14) //
|
||||
lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a
|
||||
|
||||
dec(r9) // ii -= 1;
|
||||
jne(.DLOOP3X4I) // iterate again if ii != 0.
|
||||
dec(r9) // ii -= 1;
|
||||
jne(.DLOOP3X4I) // iterate again if ii != 0.
|
||||
|
||||
|
||||
|
||||
@@ -1846,7 +1861,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -1872,6 +1887,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
const dim_t mr_cur = 3;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_3x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -1884,6 +1900,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
const dim_t mr_cur = 2;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_2x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -1896,6 +1913,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
|
||||
const dim_t mr_cur = 1;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_1x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
|
||||
@@ -65,23 +65,6 @@
|
||||
// Prototype reference microkernels.
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
|
||||
|
||||
#if 0
|
||||
// Define parameters and variables for edge case kernel map.
|
||||
#define NUM_MR 4
|
||||
#define NUM_NR 4
|
||||
#define FUNCPTR_T dgemmsup_ker_ft
|
||||
|
||||
static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
|
||||
static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
|
||||
static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
|
||||
{ /* 8 4 2 1 */
|
||||
/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8n, bli_dgemmsup_rd_haswell_asm_6x4n, bli_dgemmsup_rd_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref_6x1 },
|
||||
/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8n, bli_dgemmsup_rd_haswell_asm_3x4n, bli_dgemmsup_rd_haswell_asm_3x2n, bli_dgemmsup_r_haswell_ref_3x1 },
|
||||
/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8n, bli_dgemmsup_rd_haswell_asm_2x4n, bli_dgemmsup_rd_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref_2x1 },
|
||||
/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8n, bli_dgemmsup_rd_haswell_asm_1x4n, bli_dgemmsup_rd_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref_1x1 }
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
(
|
||||
@@ -161,6 +144,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
const dim_t mr_cur = 3;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, n0, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -173,6 +157,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
const dim_t mr_cur = 2;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, n0, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -182,23 +167,24 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
}
|
||||
if ( 1 == m_left )
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
const dim_t mr_cur = 1;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, n0, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
beta, cij, rs_c0, cs_c0, data, cntx
|
||||
);
|
||||
#else
|
||||
#else
|
||||
bli_dgemv_ex
|
||||
(
|
||||
BLIS_TRANSPOSE, conja, k0, n0,
|
||||
alpha, bj, rs_b0, cs_b0, ai, cs_a0,
|
||||
beta, cij, cs_c0, cntx, NULL
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -231,7 +217,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
//vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
mov(var(a), rdx) // load address of a.
|
||||
//mov(var(a), rdx) // load address of a.
|
||||
mov(var(rs_a), r8) // load rs_a
|
||||
//mov(var(cs_a), r9) // load cs_a
|
||||
lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
|
||||
@@ -247,11 +233,12 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
|
||||
|
||||
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
|
||||
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
|
||||
|
||||
|
||||
//mov(var(c), r12) // load address of c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
|
||||
|
||||
@@ -267,12 +254,10 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
|
||||
|
||||
mov(var(a), rdx) // load address of a
|
||||
mov(var(b), r14) // load address of b
|
||||
mov(var(c), r12) // load address of c
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii;
|
||||
imul(rdi, rsi) // rsi *= rs_c
|
||||
lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c;
|
||||
@@ -309,19 +294,20 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vxorpd(ymm15, ymm15, ymm15)
|
||||
#endif
|
||||
|
||||
|
||||
lea(mem(r12), rcx) // rcx = c_iijj;
|
||||
lea(mem(rdx), rax) // rax = a_ii;
|
||||
lea(mem(r14), rbx) // rbx = b_jj;
|
||||
|
||||
|
||||
#if 1
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
|
||||
#endif
|
||||
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
|
||||
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
|
||||
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
|
||||
|
||||
|
||||
@@ -338,13 +324,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
|
||||
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
|
||||
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
|
||||
#else
|
||||
#if 1
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
#endif
|
||||
@@ -352,7 +332,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -386,7 +366,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -420,7 +400,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -455,7 +435,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -502,7 +482,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -551,7 +531,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
vmovsd(mem(rax, r8, 2), xmm2)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -585,8 +565,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
// ymm4 ymm7 ymm10 ymm13
|
||||
// ymm5 ymm8 ymm11 ymm14
|
||||
@@ -594,14 +572,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
vhaddpd( ymm7, ymm4, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
|
||||
vaddpd( xmm0, xmm1, xmm0 )
|
||||
|
||||
vhaddpd( ymm13, ymm10, ymm2 )
|
||||
vextractf128(imm(1), ymm2, xmm1 )
|
||||
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
|
||||
|
||||
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
|
||||
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
|
||||
|
||||
vhaddpd( ymm8, ymm5, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -612,7 +591,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
|
||||
|
||||
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
|
||||
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
|
||||
|
||||
vhaddpd( ymm9, ymm6, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -623,15 +603,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
|
||||
// ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9)
|
||||
// ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15)
|
||||
|
||||
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
|
||||
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
|
||||
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
|
||||
|
||||
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
@@ -716,8 +695,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
|
||||
add(imm(3), r9) // ii += 3;
|
||||
cmp(imm(3), r9) // compare ii to 3
|
||||
jle(.DLOOP3X4I) // if ii <= 3, jump to beginning
|
||||
cmp(imm(6), r9) // compare ii to 6
|
||||
jl(.DLOOP3X4I) // if ii < 6, jump to beginning
|
||||
// of ii loop; otherwise, loop ends.
|
||||
|
||||
|
||||
@@ -748,7 +727,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -759,7 +738,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
// Handle edge cases in the m dimension, if they exist.
|
||||
// Handle edge cases in the n dimension, if they exist.
|
||||
if ( n_left )
|
||||
{
|
||||
const dim_t mr_cur = 6;
|
||||
@@ -774,6 +753,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
const dim_t nr_cur = 2;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_6x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -783,24 +763,24 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
|
||||
}
|
||||
if ( 1 == n_left )
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
const dim_t nr_cur = 1;
|
||||
|
||||
//bli_dgemmsup_rd_haswell_asm_6x1n
|
||||
bli_dgemmsup_r_haswell_ref
|
||||
bli_dgemmsup_rd_haswell_asm_6x1
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
beta, cij, rs_c0, cs_c0, data, cntx
|
||||
);
|
||||
#else
|
||||
#else
|
||||
bli_dgemv_ex
|
||||
(
|
||||
BLIS_NO_TRANSPOSE, conjb, mr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
|
||||
beta, cij, rs_c0, cntx, NULL
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -865,11 +845,12 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
|
||||
|
||||
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
|
||||
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
|
||||
|
||||
|
||||
mov(var(c), r12) // load address of c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
|
||||
|
||||
@@ -905,19 +886,20 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vxorpd(ymm15, ymm15, ymm15)
|
||||
#endif
|
||||
|
||||
|
||||
lea(mem(r12), rcx) // rcx = c_iijj;
|
||||
lea(mem(rdx), rax) // rax = a_ii;
|
||||
lea(mem(r14), rbx) // rbx = b_jj;
|
||||
|
||||
|
||||
#if 1
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
|
||||
#endif
|
||||
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
|
||||
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
|
||||
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
|
||||
|
||||
|
||||
@@ -934,13 +916,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
|
||||
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
|
||||
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
|
||||
#else
|
||||
#if 1
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
#endif
|
||||
@@ -948,7 +924,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -982,7 +958,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1016,7 +992,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1051,7 +1027,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1098,7 +1074,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
vmovupd(mem(rax, r8, 2), ymm2)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1147,7 +1123,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
vmovsd(mem(rax, r8, 2), xmm2)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1181,8 +1157,6 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
// ymm4 ymm7 ymm10 ymm13
|
||||
// ymm5 ymm8 ymm11 ymm14
|
||||
@@ -1190,14 +1164,15 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
|
||||
vhaddpd( ymm7, ymm4, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
|
||||
vaddpd( xmm0, xmm1, xmm0 )
|
||||
|
||||
vhaddpd( ymm13, ymm10, ymm2 )
|
||||
vextractf128(imm(1), ymm2, xmm1 )
|
||||
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
|
||||
|
||||
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
|
||||
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
|
||||
|
||||
vhaddpd( ymm8, ymm5, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -1208,7 +1183,8 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
|
||||
|
||||
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
|
||||
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
|
||||
|
||||
vhaddpd( ymm9, ymm6, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -1219,15 +1195,14 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
|
||||
|
||||
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
|
||||
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
|
||||
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
|
||||
// ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9)
|
||||
// ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15)
|
||||
|
||||
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
@@ -1312,6 +1287,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DRETURN)
|
||||
|
||||
|
||||
@@ -1337,7 +1313,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -1348,7 +1324,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
// Handle edge cases in the m dimension, if they exist.
|
||||
// Handle edge cases in the n dimension, if they exist.
|
||||
if ( n_left )
|
||||
{
|
||||
const dim_t mr_cur = 3;
|
||||
@@ -1363,6 +1339,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
const dim_t nr_cur = 2;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_3x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -1372,23 +1349,25 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
|
||||
}
|
||||
if ( 1 == n_left )
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
const dim_t nr_cur = 1;
|
||||
|
||||
bli_dgemmsup_r_haswell_ref_3x1
|
||||
bli_dgemmsup_rd_haswell_asm_3x1
|
||||
//bli_dgemmsup_r_haswell_ref_3x1
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
beta, cij, rs_c0, cs_c0, data, cntx
|
||||
);
|
||||
#else
|
||||
#else
|
||||
bli_dgemv_ex
|
||||
(
|
||||
BLIS_NO_TRANSPOSE, conjb, mr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
|
||||
beta, cij, rs_c0, cntx, NULL
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1453,11 +1432,12 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
|
||||
|
||||
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
|
||||
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
|
||||
|
||||
|
||||
mov(var(c), r12) // load address of c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
|
||||
|
||||
@@ -1489,18 +1469,19 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
vxorpd(ymm14, ymm14, ymm14)
|
||||
#endif
|
||||
|
||||
|
||||
lea(mem(r12), rcx) // rcx = c_iijj;
|
||||
lea(mem(rdx), rax) // rax = a_ii;
|
||||
lea(mem(r14), rbx) // rbx = b_jj;
|
||||
|
||||
|
||||
#if 1
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
|
||||
#endif
|
||||
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
|
||||
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
|
||||
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
|
||||
|
||||
|
||||
@@ -1517,20 +1498,14 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
|
||||
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
|
||||
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
|
||||
#else
|
||||
#if 1
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1559,7 +1534,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1588,7 +1563,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1618,7 +1593,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1660,7 +1635,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
vmovupd(mem(rax, r8, 1), ymm1)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1704,7 +1679,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1734,23 +1709,21 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
// ymm4 ymm7 ymm10 ymm13
|
||||
// ymm5 ymm8 ymm11 ymm14
|
||||
// ymm6 ymm9 ymm12 ymm15
|
||||
|
||||
vhaddpd( ymm7, ymm4, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
|
||||
vaddpd( xmm0, xmm1, xmm0 )
|
||||
|
||||
vhaddpd( ymm13, ymm10, ymm2 )
|
||||
vextractf128(imm(1), ymm2, xmm1 )
|
||||
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
|
||||
|
||||
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
|
||||
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
|
||||
|
||||
vhaddpd( ymm8, ymm5, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
@@ -1761,14 +1734,14 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
|
||||
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
|
||||
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
|
||||
|
||||
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
|
||||
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
|
||||
|
||||
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
@@ -1777,7 +1750,6 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
vmulpd(ymm0, ymm4, ymm4) // scale by alpha
|
||||
vmulpd(ymm0, ymm5, ymm5)
|
||||
vmulpd(ymm0, ymm6, ymm6)
|
||||
|
||||
|
||||
|
||||
@@ -1846,6 +1818,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DRETURN)
|
||||
|
||||
|
||||
@@ -1871,7 +1844,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -1882,7 +1855,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
// Handle edge cases in the m dimension, if they exist.
|
||||
// Handle edge cases in the n dimension, if they exist.
|
||||
if ( n_left )
|
||||
{
|
||||
const dim_t mr_cur = 2;
|
||||
@@ -1897,6 +1870,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
const dim_t nr_cur = 2;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_2x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -1906,23 +1880,24 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
|
||||
}
|
||||
if ( 1 == n_left )
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
const dim_t nr_cur = 1;
|
||||
|
||||
bli_dgemmsup_r_haswell_ref_2x1
|
||||
bli_dgemmsup_rd_haswell_asm_2x1
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
beta, cij, rs_c0, cs_c0, data, cntx
|
||||
);
|
||||
#else
|
||||
#else
|
||||
bli_dgemv_ex
|
||||
(
|
||||
BLIS_NO_TRANSPOSE, conjb, mr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
|
||||
beta, cij, rs_c0, cntx, NULL
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1987,11 +1962,12 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
|
||||
|
||||
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
|
||||
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
|
||||
|
||||
|
||||
mov(var(c), r12) // load address of c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
|
||||
|
||||
@@ -2019,18 +1995,18 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
vxorpd(ymm13, ymm13, ymm13)
|
||||
#endif
|
||||
|
||||
|
||||
lea(mem(r12), rcx) // rcx = c_iijj;
|
||||
lea(mem(rdx), rax) // rax = a_ii;
|
||||
lea(mem(r14), rbx) // rbx = b_jj;
|
||||
|
||||
|
||||
#if 1
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
|
||||
#endif
|
||||
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
|
||||
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
|
||||
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
|
||||
|
||||
|
||||
@@ -2047,19 +2023,13 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
#if 0
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
|
||||
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
|
||||
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
|
||||
#else
|
||||
#if 1
|
||||
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
|
||||
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2083,7 +2053,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2107,7 +2077,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2132,7 +2102,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
#endif
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2169,7 +2139,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
label(.DLOOPKITER4) // EDGE LOOP (ymm)
|
||||
|
||||
vmovupd(mem(rax ), ymm0)
|
||||
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
|
||||
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
|
||||
|
||||
vmovupd(mem(rbx ), ymm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2208,7 +2178,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
// which would destory intermediate results.
|
||||
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2234,29 +2204,26 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
|
||||
|
||||
label(.DPOSTACCUM)
|
||||
|
||||
|
||||
|
||||
// ymm4 ymm7 ymm10 ymm13
|
||||
// ymm5 ymm8 ymm11 ymm14
|
||||
// ymm6 ymm9 ymm12 ymm15
|
||||
|
||||
vhaddpd( ymm7, ymm4, ymm0 )
|
||||
vextractf128(imm(1), ymm0, xmm1 )
|
||||
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
|
||||
vaddpd( xmm0, xmm1, xmm0 )
|
||||
|
||||
vhaddpd( ymm13, ymm10, ymm2 )
|
||||
vextractf128(imm(1), ymm2, xmm1 )
|
||||
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
|
||||
vaddpd( xmm2, xmm1, xmm2 )
|
||||
|
||||
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
|
||||
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
|
||||
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
|
||||
|
||||
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
|
||||
|
||||
|
||||
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
//mov(var(rs_c), rdi) // load rs_c
|
||||
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
@@ -2325,6 +2292,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
|
||||
|
||||
|
||||
|
||||
label(.DRETURN)
|
||||
|
||||
|
||||
@@ -2350,7 +2318,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -2361,7 +2329,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
|
||||
consider_edge_cases:
|
||||
|
||||
// Handle edge cases in the m dimension, if they exist.
|
||||
// Handle edge cases in the n dimension, if they exist.
|
||||
if ( n_left )
|
||||
{
|
||||
const dim_t mr_cur = 1;
|
||||
@@ -2376,6 +2344,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
const dim_t nr_cur = 2;
|
||||
|
||||
bli_dgemmsup_rd_haswell_asm_1x2
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
@@ -2385,23 +2354,24 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
|
||||
}
|
||||
if ( 1 == n_left )
|
||||
{
|
||||
#if 0
|
||||
#if 1
|
||||
const dim_t nr_cur = 1;
|
||||
|
||||
bli_dgemmsup_r_haswell_ref_1x1
|
||||
bli_dgemmsup_rd_haswell_asm_1x1
|
||||
//bli_dgemmsup_r_haswell_ref
|
||||
(
|
||||
conja, conjb, mr_cur, nr_cur, k0,
|
||||
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
|
||||
beta, cij, rs_c0, cs_c0, data, cntx
|
||||
);
|
||||
#else
|
||||
#else
|
||||
bli_ddotxv_ex
|
||||
(
|
||||
conja, conjb, k0,
|
||||
alpha, ai, cs_a0, bj, rs_b0,
|
||||
beta, cij, cntx, NULL
|
||||
);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
158
kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c
Normal file
158
kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c
Normal file
@@ -0,0 +1,158 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2019, Advanced Micro Devices, Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
#define BLIS_ASM_SYNTAX_ATT
|
||||
#include "bli_x86_asm_macros.h"
|
||||
|
||||
/*
|
||||
rrr:
|
||||
-------- ------ --------
|
||||
-------- ------ --------
|
||||
-------- += ------ ... --------
|
||||
-------- ------ --------
|
||||
-------- ------ :
|
||||
-------- ------ :
|
||||
|
||||
rcr:
|
||||
-------- | | | | --------
|
||||
-------- | | | | --------
|
||||
-------- += | | | | ... --------
|
||||
-------- | | | | --------
|
||||
-------- | | | | :
|
||||
-------- | | | | :
|
||||
|
||||
Assumptions:
|
||||
- B is row-stored;
|
||||
- A is row- or column-stored;
|
||||
- m0 and n0 are at most MR and NR, respectively.
|
||||
Therefore, this (r)ow-preferential kernel is well-suited for contiguous
|
||||
(v)ector loads on B and single-element broadcasts from A.
|
||||
|
||||
NOTE: These kernels explicitly support column-oriented IO, implemented
|
||||
via an in-register transpose. And thus they also support the crr and
|
||||
ccr cases, though only crr is ever utilized (because ccr is handled by
|
||||
transposing the operation and executing rcr, which does not incur the
|
||||
cost of the in-register transpose).
|
||||
|
||||
crr:
|
||||
| | | | | | | | ------ --------
|
||||
| | | | | | | | ------ --------
|
||||
| | | | | | | | += ------ ... --------
|
||||
| | | | | | | | ------ --------
|
||||
| | | | | | | | ------ :
|
||||
| | | | | | | | ------ :
|
||||
*/
|
||||
|
||||
// Prototype reference microkernels.
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
|
||||
|
||||
|
||||
// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel.
|
||||
// However, at least one other subconfiguration (zen) uses this kernel set, so
|
||||
// we need to be able to call a set of "?x1" kernels that we know will actually
|
||||
// exist regardless of which subconfiguration these kernels were used by. Thus,
|
||||
// the compromise employed here is to inline the reference kernel so it gets
|
||||
// compiled as part of the haswell kernel set, and hence can unconditionally be
|
||||
// called by other kernels within that kernel set.
|
||||
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC( ctype, ch, opname, mdim ) \
|
||||
\
|
||||
void PASTEMAC(ch,opname) \
|
||||
( \
|
||||
conj_t conja, \
|
||||
conj_t conjb, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, inc_t rs_a, inc_t cs_a, \
|
||||
ctype* restrict b, inc_t rs_b, inc_t cs_b, \
|
||||
ctype* restrict beta, \
|
||||
ctype* restrict c, inc_t rs_c, inc_t cs_c, \
|
||||
auxinfo_t* restrict data, \
|
||||
cntx_t* restrict cntx \
|
||||
) \
|
||||
{ \
|
||||
for ( dim_t i = 0; i < mdim; ++i ) \
|
||||
{ \
|
||||
ctype* restrict ci = &c[ i*rs_c ]; \
|
||||
ctype* restrict ai = &a[ i*rs_a ]; \
|
||||
\
|
||||
/* for ( dim_t j = 0; j < 1; ++j ) */ \
|
||||
{ \
|
||||
ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \
|
||||
ctype* restrict bj = b /*[ j*cs_b ]*/ ; \
|
||||
ctype ab; \
|
||||
\
|
||||
PASTEMAC(ch,set0s)( ab ); \
|
||||
\
|
||||
/* Perform a dot product to update the (i,j) element of c. */ \
|
||||
for ( dim_t l = 0; l < k; ++l ) \
|
||||
{ \
|
||||
ctype* restrict aij = &ai[ l*cs_a ]; \
|
||||
ctype* restrict bij = &bj[ l*rs_b ]; \
|
||||
\
|
||||
PASTEMAC(ch,dots)( *aij, *bij, ab ); \
|
||||
} \
|
||||
\
|
||||
/* If beta is one, add ab into c. If beta is zero, overwrite c
|
||||
with the result in ab. Otherwise, scale by beta and accumulate
|
||||
ab to c. */ \
|
||||
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
||||
{ \
|
||||
PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \
|
||||
} \
|
||||
else if ( PASTEMAC(d,eq0)( *beta ) ) \
|
||||
{ \
|
||||
PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
GENTFUNC( double, d, gemmsup_r_haswell_ref_6x1, 6 )
|
||||
GENTFUNC( double, d, gemmsup_r_haswell_ref_5x1, 5 )
|
||||
GENTFUNC( double, d, gemmsup_r_haswell_ref_4x1, 4 )
|
||||
GENTFUNC( double, d, gemmsup_r_haswell_ref_3x1, 3 )
|
||||
GENTFUNC( double, d, gemmsup_r_haswell_ref_2x1, 2 )
|
||||
GENTFUNC( double, d, gemmsup_r_haswell_ref_1x1, 1 )
|
||||
|
||||
1698
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c
Normal file
1698
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c
Normal file
File diff suppressed because it is too large
Load Diff
1794
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c
Normal file
1794
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c
Normal file
File diff suppressed because it is too large
Load Diff
1450
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c
Normal file
1450
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c
Normal file
File diff suppressed because it is too large
Load Diff
1617
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c
Normal file
1617
kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c
Normal file
File diff suppressed because it is too large
Load Diff
2496
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c
Normal file
2496
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c
Normal file
File diff suppressed because it is too large
Load Diff
2600
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c
Normal file
2600
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c
Normal file
File diff suppressed because it is too large
Load Diff
3095
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c
Normal file
3095
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c
Normal file
File diff suppressed because it is too large
Load Diff
3260
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c
Normal file
3260
kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c
Normal file
File diff suppressed because it is too large
Load Diff
@@ -219,7 +219,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8
|
||||
mov(var(c), r12) // load address of c
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
|
||||
lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c;
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
@@ -487,7 +487,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
vmovsd(mem(rax, r8, 2), xmm2)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -788,6 +788,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8
|
||||
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
|
||||
|
||||
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
|
||||
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
|
||||
|
||||
|
||||
mov(var(c), r12) // load address of c
|
||||
@@ -826,7 +827,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8
|
||||
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
|
||||
lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c;
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
@@ -1018,7 +1019,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8
|
||||
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1274,7 +1275,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8
|
||||
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
|
||||
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
|
||||
lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c;
|
||||
|
||||
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
|
||||
@@ -1438,7 +1439,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8
|
||||
// which would destroy intermediate results.
|
||||
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -1908,7 +1909,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
vmovsd(mem(rax, r8, 2), xmm2)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2398,7 +2399,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4
|
||||
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
vmovsd(mem(rax, r8, 1), xmm1)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -2783,7 +2784,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4
|
||||
// which would destroy intermediate results.
|
||||
|
||||
vmovsd(mem(rax ), xmm0)
|
||||
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
|
||||
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
|
||||
|
||||
vmovsd(mem(rbx ), xmm3)
|
||||
vfmadd231pd(ymm0, ymm3, ymm4)
|
||||
@@ -257,6 +257,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8
|
||||
prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c
|
||||
prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c
|
||||
prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c
|
||||
prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c
|
||||
|
||||
jmp(.DPOSTPFETCH) // jump to end of prefetching c
|
||||
label(.DCOLPFETCH) // column-stored prefetching c
|
||||
@@ -65,6 +65,17 @@ GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_haswell_asm_6x8 )
|
||||
|
||||
// -- level-3 sup --------------------------------------------------------------
|
||||
|
||||
// -- double real --
|
||||
|
||||
// gemmsup_r
|
||||
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 )
|
||||
|
||||
// gemmsup_rv
|
||||
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8 )
|
||||
@@ -95,13 +106,6 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x2 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x2 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x2 )
|
||||
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 )
|
||||
|
||||
// gemmsup_rv (mkernel in m dim)
|
||||
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m )
|
||||
@@ -133,6 +137,11 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x2 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x2 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x2 )
|
||||
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x1 )
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 )
|
||||
|
||||
// gemmsup_rd (mkernel in m dim)
|
||||
|
||||
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m )
|
||||
|
||||
@@ -418,7 +418,8 @@ void GENBARNAME(cntx_init)
|
||||
gen_func_init( &funcs[ BLIS_TRSM_L_UKR ], trsm_l_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_TRSM_U_UKR ], trsm_u_ukr_name );
|
||||
|
||||
bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE );
|
||||
// s d c z
|
||||
bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_GEMMTRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE );
|
||||
bli_mbool_init( &mbools[ BLIS_GEMMTRSM_U_UKR ], FALSE, FALSE, FALSE, FALSE );
|
||||
bli_mbool_init( &mbools[ BLIS_TRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE );
|
||||
@@ -427,11 +428,14 @@ void GENBARNAME(cntx_init)
|
||||
|
||||
// -- Set level-3 small/unpacked thresholds --------------------------------
|
||||
|
||||
// NOTE: The default thresholds are set very low so that the sup framework
|
||||
// only actives for exceedingly small dimensions. If a sub-configuration
|
||||
// registers optimized sup kernels, then that sub-configuration should also
|
||||
// register new (probably larger) thresholds that are almost surely more
|
||||
// appropriate than these default values.
|
||||
// NOTE: The default thresholds are set to zero so that the sup framework
|
||||
// does not activate by default. Note that the semantic meaning of the
|
||||
// thresholds is that the sup code path is executed if a dimension is
|
||||
// strictly less than its corresponding threshold. So actually, the
|
||||
// thresholds specify the minimum dimension size that will still dispatch
|
||||
// the non-sup/large code path. This "strictly less than" behavior was
|
||||
// chosen over "less than or equal to" so that threshold values of 0 would
|
||||
// effectively disable sup (even for matrix dimensions of 0).
|
||||
// s d c z
|
||||
bli_blksz_init_easy( &thresh[ BLIS_MT ], 0, 0, 0, 0 );
|
||||
bli_blksz_init_easy( &thresh[ BLIS_NT ], 0, 0, 0, 0 );
|
||||
@@ -486,10 +490,10 @@ void GENBARNAME(cntx_init)
|
||||
gen_func_init( &funcs[ BLIS_RRC ], gemmsup_rv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_RCC ], gemmsup_rv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CRR ], gemmsup_cv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CRC ], gemmsup_cv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CCR ], gemmsup_cv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CRR ], gemmsup_rv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CRC ], gemmsup_rv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CCR ], gemmsup_rv_ukr_name );
|
||||
gen_func_init( &funcs[ BLIS_CCC ], gemmsup_rv_ukr_name );
|
||||
|
||||
// Register the general-stride/generic ukernel to the "catch-all" slot
|
||||
// associated with the BLIS_XXX enum value. This slot will be queried if
|
||||
@@ -498,16 +502,17 @@ void GENBARNAME(cntx_init)
|
||||
|
||||
|
||||
// Set the l3 sup ukernel storage preferences.
|
||||
bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_CRR ], FALSE, FALSE, FALSE, FALSE );
|
||||
bli_mbool_init( &mbools[ BLIS_CRC ], FALSE, FALSE, FALSE, FALSE );
|
||||
bli_mbool_init( &mbools[ BLIS_CCR ], FALSE, FALSE, FALSE, FALSE );
|
||||
bli_mbool_init( &mbools[ BLIS_CCC ], FALSE, FALSE, FALSE, FALSE );
|
||||
// s d c z
|
||||
bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_CRR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_CRC ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_CCR ], TRUE, TRUE, TRUE, TRUE );
|
||||
bli_mbool_init( &mbools[ BLIS_CCC ], TRUE, TRUE, TRUE, TRUE );
|
||||
|
||||
bli_mbool_init( &mbools[ BLIS_XXX ], FALSE, FALSE, FALSE, FALSE );
|
||||
bli_mbool_init( &mbools[ BLIS_XXX ], TRUE, TRUE, TRUE, TRUE );
|
||||
|
||||
|
||||
// -- Set level-1f kernels -------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user