Bugfixes, cleanup of sup dgemm ukernels.

Details:
- Fixed a few not-really-bugs:
  - Previously, the d6x8m kernels were still prefetching the next upanel
    of A using MR*rs_a instead of ps_a (same for prefetching of next
    upanel of B in d6x8n kernels using NR*cs_b instead of ps_b). Given
    that the upanels might be packed, using ps_a or ps_b is the correct
    way to compute the prefetch address.
  - Fixed an obscure bug in the rd_d6x8m kernel that, by dumb luck,
    executed as intended even though it was based on a faulty pointer
    management. Basically, in the rd_d6x8m kernel, the pointer for B
    (stored in rdx) was loaded only once, outside of the jj loop, and in
    the second iteration its new position was calculated by incrementing
    rdx by the *absolute* offset (four columns), which happened to be the
    same as the relative offset (also four columns) that was needed. It
    worked only because that loop only executed twice. A similar issue
    was fixed in the rd_d6x8n kernels.
- Various cleanups and additions, including:
  - Factored out the loading of rs_c into rdi in rd_d6x8[mn] kernels so
    that it is loaded only once outside of the loops rather than
    multiple times inside the loops.
  - Changed outer loop in rd kernels so that the jump/comparison and
    loop bounds more closely mimic what you'd see in higher-level source
    code. That is, something like:
      for( i = 0; i < 6; i+=3 )
    rather than something like:
      for( i = 0; i <= 3; i+=3 )
  - Switched row-based IO to use byte offsets instead of byte column
    strides (e.g. via rsi register), which were known to be 8 anyway
    since otherwise that conditional branch wouldn't have executed.
  - Cleaned up and homogenized prefetching a bit.
  - Updated the comments that show the before and after of the
    in-register transpositions.
  - Added comments to column-based IO cases to indicate which columns
    are being accessed/updated.
  - Added rbp register to clobber lists.
  - Removed some dead (commented out) code.
  - Fixed some copy-paste typos in comments in the rv_6x8n kernels.
  - Cleaned up whitespace (including leading ws -> tabs).
  - Moved edge case (non-milli) kernels to their own directory, d6x8,
    and split them into separate files based on the "NR" value of the
    kernels (Mx8, Mx4, Mx2, etc.).
  - Moved config-specific reference Mx1 kernels into their own file
    (e.g. bli_gemmsup_r_haswell_ref_dMx1.c) inside the d6x8 directory.
  - Added rd_dMx1 assembly kernels, which seem marginally faster than
    the corresponding reference kernels.
  - Updated comments in ref_kernels/bli_cntx_ref.c and changed to using
    the row-oriented reference kernels for all storage combos.
This commit is contained in:
Field G. Van Zee
2020-06-04 17:21:08 -05:00
parent 943a21def0
commit 1c719c91a3
17 changed files with 19158 additions and 1394 deletions

View File

@@ -65,23 +65,6 @@
// Prototype reference microkernels.
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
#if 0
// Define parameters and variables for edge case kernel map.
#define NUM_MR 4
#define NUM_NR 4
#define FUNCPTR_T dgemmsup_ker_ft
static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
{ /* 8 4 2 1 */
/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8m, bli_dgemmsup_rd_haswell_asm_6x4m, bli_dgemmsup_rd_haswell_asm_6x2m, bli_dgemmsup_r_haswell_ref_6x1 },
/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8m, bli_dgemmsup_rd_haswell_asm_3x4m, bli_dgemmsup_rd_haswell_asm_3x2m, bli_dgemmsup_r_haswell_ref_3x1 },
/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8m, bli_dgemmsup_rd_haswell_asm_2x4m, bli_dgemmsup_rd_haswell_asm_2x2m, bli_dgemmsup_r_haswell_ref_2x1 },
/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8m, bli_dgemmsup_rd_haswell_asm_1x4m, bli_dgemmsup_rd_haswell_asm_1x2m, bli_dgemmsup_r_haswell_ref_1x1 }
};
#endif
void bli_dgemmsup_rd_haswell_asm_6x8m
(
@@ -135,7 +118,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
}
if ( 1 == n_left )
{
#if 0
#if 0
const dim_t nr_cur = 1;
bli_dgemmsup_r_haswell_ref
@@ -144,14 +127,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
beta, cij, rs_c0, cs_c0, data, cntx
);
#else
#else
bli_dgemv_ex
(
BLIS_NO_TRANSPOSE, conjb, m0, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
beta, cij, rs_c0, cntx, NULL
);
#endif
#endif
}
return;
}
@@ -193,7 +176,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
//lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
mov(var(b), rdx) // load address of b.
//mov(var(b), rdx) // load address of b.
//mov(var(rs_b), r10) // load rs_b
mov(var(cs_b), r11) // load cs_b
//lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
@@ -204,8 +187,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
//mov(var(c), r12) // load address of c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
@@ -222,10 +205,11 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
mov(var(a), r14) // load address of a
mov(var(b), rdx) // load address of b
mov(var(c), r12) // load address of c
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c;
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
@@ -266,14 +250,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
lea(mem(rdx), rbx) // rbx = b_jj;
#if 0
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
#if 1
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
#endif
lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a
lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
@@ -290,15 +274,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
// ---------------------------------- iteration 0
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -327,7 +311,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -354,15 +338,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
// ---------------------------------- iteration 2
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -391,7 +375,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -436,15 +420,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
label(.DLOOPKITER4) // EDGE LOOP (ymm)
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -493,7 +477,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
vmovsd(mem(rax, r8, 2), xmm2)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -527,8 +511,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
label(.DPOSTACCUM)
// ymm4 ymm7 ymm10 ymm13
// ymm5 ymm8 ymm11 ymm14
@@ -536,14 +518,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
vhaddpd( ymm7, ymm4, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
vaddpd( xmm0, xmm1, xmm0 )
vhaddpd( ymm13, ymm10, ymm2 )
vextractf128(imm(1), ymm2, xmm1 )
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
// xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7)
// xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13)
vhaddpd( ymm8, ymm5, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -554,6 +537,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
// xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8)
// xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14)
vhaddpd( ymm9, ymm6, ymm0 )
@@ -565,15 +550,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
// xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9)
// xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15)
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
@@ -661,8 +645,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
add(imm(4), r15) // jj += 4;
cmp(imm(4), r15) // compare jj to 4
jle(.DLOOP3X4J) // if jj <= 4, jump to beginning
cmp(imm(8), r15) // compare jj to 8
jl(.DLOOP3X4J) // if jj < 8, jump to beginning
// of jj loop; otherwise, loop ends.
@@ -692,7 +676,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -803,8 +787,8 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
mov(var(c), r12) // load address of c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
@@ -813,13 +797,13 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
// rdx = rbx = b
// r9 = m dim index ii
// r15 = n dim index jj
// r10 = unused
mov(var(m_iter), r9) // ii = m_iter;
label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter .. 1 0 ]
#if 0
vzeroall() // zero all xmm/ymm registers.
#else
@@ -846,14 +830,14 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
lea(mem(rdx), rbx) // rbx = b;
#if 0
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
#if 1
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
#endif
lea(mem(r8, r8, 4), rdi) // rdi = 5*rs_a
lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
@@ -870,15 +854,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
// ---------------------------------- iteration 0
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -907,7 +891,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -934,15 +918,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
// ---------------------------------- iteration 2
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a
prefetch(0, mem(rax, rdi, 1, 0*8)) // prefetch rax + 5*cs_a
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -971,7 +955,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1014,11 +998,17 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
label(.DLOOPKITER4) // EDGE LOOP (ymm)
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1067,7 +1057,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
vmovsd(mem(rax, r8, 2), xmm2)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1101,8 +1091,6 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
label(.DPOSTACCUM)
// ymm4 ymm7 ymm10 ymm13
// ymm5 ymm8 ymm11 ymm14
@@ -1110,14 +1098,15 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
vhaddpd( ymm7, ymm4, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
vaddpd( xmm0, xmm1, xmm0 )
vhaddpd( ymm13, ymm10, ymm2 )
vextractf128(imm(1), ymm2, xmm1 )
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
vhaddpd( ymm8, ymm5, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -1128,7 +1117,8 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
vhaddpd( ymm9, ymm6, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -1139,15 +1129,14 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
// ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9)
// ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15)
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
@@ -1259,7 +1248,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1366,6 +1355,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
//lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
mov(var(c), r12) // load address of c
@@ -1375,41 +1365,44 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
// r12 = rcx = c
// r14 = rax = a
// rdx = rbx = b
// r9 = m dim index ii
// r14 = rax = a
// rdx = rbx = b
// r9 = m dim index ii
mov(var(m_iter), r9) // ii = m_iter;
label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
label(.DLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
#if 0
vzeroall() // zero all xmm/ymm registers.
vzeroall() // zero all xmm/ymm registers.
#else
// skylake can execute 3 vxorpd ipc with
// a latency of 1 cycle, while vzeroall
// has a latency of 12 cycles.
vxorpd(ymm4, ymm4, ymm4)
vxorpd(ymm5, ymm5, ymm5)
vxorpd(ymm6, ymm6, ymm6)
vxorpd(ymm7, ymm7, ymm7)
vxorpd(ymm8, ymm8, ymm8)
vxorpd(ymm9, ymm9, ymm9)
vxorpd(ymm10, ymm10, ymm10)
vxorpd(ymm11, ymm11, ymm11)
vxorpd(ymm12, ymm12, ymm12)
vxorpd(ymm13, ymm13, ymm13)
vxorpd(ymm14, ymm14, ymm14)
vxorpd(ymm15, ymm15, ymm15)
// skylake can execute 3 vxorpd ipc with
// a latency of 1 cycle, while vzeroall
// has a latency of 12 cycles.
vxorpd(ymm4, ymm4, ymm4)
vxorpd(ymm5, ymm5, ymm5)
vxorpd(ymm6, ymm6, ymm6)
vxorpd(ymm7, ymm7, ymm7)
vxorpd(ymm8, ymm8, ymm8)
vxorpd(ymm9, ymm9, ymm9)
vxorpd(ymm10, ymm10, ymm10)
vxorpd(ymm11, ymm11, ymm11)
vxorpd(ymm12, ymm12, ymm12)
vxorpd(ymm13, ymm13, ymm13)
vxorpd(ymm14, ymm14, ymm14)
vxorpd(ymm15, ymm15, ymm15)
#endif
lea(mem(r12), rcx) // rcx = c + 6*ii*rs_c;
lea(mem(r14), rax) // rax = a + 6*ii*rs_a;
lea(mem(rdx), rbx) // rbx = b;
lea(mem(r12), rcx) // rcx = c_ii;
lea(mem(r14), rax) // rax = a_ii;
lea(mem(rdx), rbx) // rbx = b;
#if 1
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
lea(mem(rcx, rdi, 2), r10) //
lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c;
prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c
@@ -1418,6 +1411,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
prefetch(0, mem(r10, 1*8)) // prefetch c + 3*rs_c
prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c
prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c
#endif
@@ -1433,6 +1427,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
// ---------------------------------- iteration 0
#if 0
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rbx ), ymm0)
vmovupd(mem(rbx, r11, 1), ymm1)
add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
@@ -1497,6 +1497,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
// ---------------------------------- iteration 2
#if 0
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rbx ), ymm0)
vmovupd(mem(rbx, r11, 1), ymm1)
add(imm(4*8), rbx) // b += 4*rs_b = 4*8;
@@ -1578,6 +1584,12 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
label(.DLOOPKITER4) // EDGE LOOP (ymm)
#if 0
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rbx ), ymm0)
vmovupd(mem(rbx, r11, 1), ymm1)
@@ -1703,15 +1715,18 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm14 )
// xmm4 = sum(ymm4) sum(ymm5)
// xmm6 = sum(ymm6) sum(ymm7)
// xmm8 = sum(ymm8) sum(ymm9)
// xmm10 = sum(ymm10) sum(ymm11)
// xmm12 = sum(ymm12) sum(ymm13)
// xmm14 = sum(ymm14) sum(ymm15)
// xmm4[0:1] = sum(ymm4) sum(ymm5)
// xmm6[0:1] = sum(ymm6) sum(ymm7)
// xmm8[0:1] = sum(ymm8) sum(ymm9)
// xmm10[0:1] = sum(ymm10) sum(ymm11)
// xmm12[0:1] = sum(ymm12) sum(ymm13)
// xmm14[0:1] = sum(ymm14) sum(ymm15)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
@@ -1810,13 +1825,13 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
lea(mem(r12, rdi, 4), r12) //
lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c
lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c
lea(mem(r14, r8, 4), r14) //
lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a
lea(mem(r14, r8, 4), r14) //
lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a
dec(r9) // ii -= 1;
jne(.DLOOP3X4I) // iterate again if ii != 0.
dec(r9) // ii -= 1;
jne(.DLOOP3X4I) // iterate again if ii != 0.
@@ -1846,7 +1861,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1872,6 +1887,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
const dim_t mr_cur = 3;
bli_dgemmsup_rd_haswell_asm_3x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -1884,6 +1900,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
const dim_t mr_cur = 2;
bli_dgemmsup_rd_haswell_asm_2x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -1896,6 +1913,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m
const dim_t mr_cur = 1;
bli_dgemmsup_rd_haswell_asm_1x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,

View File

@@ -65,23 +65,6 @@
// Prototype reference microkernels.
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
#if 0
// Define parameters and variables for edge case kernel map.
#define NUM_MR 4
#define NUM_NR 4
#define FUNCPTR_T dgemmsup_ker_ft
static dim_t mrs[NUM_MR] = { 6, 3, 2, 1 };
static dim_t nrs[NUM_NR] = { 8, 4, 2, 1 };
static FUNCPTR_T kmap[NUM_MR][NUM_NR] =
{ /* 8 4 2 1 */
/* 6 */ { bli_dgemmsup_rd_haswell_asm_6x8n, bli_dgemmsup_rd_haswell_asm_6x4n, bli_dgemmsup_rd_haswell_asm_6x2n, bli_dgemmsup_r_haswell_ref_6x1 },
/* 3 */ { bli_dgemmsup_rd_haswell_asm_3x8n, bli_dgemmsup_rd_haswell_asm_3x4n, bli_dgemmsup_rd_haswell_asm_3x2n, bli_dgemmsup_r_haswell_ref_3x1 },
/* 2 */ { bli_dgemmsup_rd_haswell_asm_2x8n, bli_dgemmsup_rd_haswell_asm_2x4n, bli_dgemmsup_rd_haswell_asm_2x2n, bli_dgemmsup_r_haswell_ref_2x1 },
/* 1 */ { bli_dgemmsup_rd_haswell_asm_1x8n, bli_dgemmsup_rd_haswell_asm_1x4n, bli_dgemmsup_rd_haswell_asm_1x2n, bli_dgemmsup_r_haswell_ref_1x1 }
};
#endif
void bli_dgemmsup_rd_haswell_asm_6x8n
(
@@ -161,6 +144,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
const dim_t mr_cur = 3;
bli_dgemmsup_rd_haswell_asm_3x8n
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, n0, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -173,6 +157,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
const dim_t mr_cur = 2;
bli_dgemmsup_rd_haswell_asm_2x8n
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, n0, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -182,23 +167,24 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
}
if ( 1 == m_left )
{
#if 0
#if 1
const dim_t mr_cur = 1;
bli_dgemmsup_rd_haswell_asm_1x8n
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, n0, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
beta, cij, rs_c0, cs_c0, data, cntx
);
#else
#else
bli_dgemv_ex
(
BLIS_TRANSPOSE, conja, k0, n0,
alpha, bj, rs_b0, cs_b0, ai, cs_a0,
beta, cij, cs_c0, cntx, NULL
);
#endif
#endif
}
return;
}
@@ -231,7 +217,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
//vzeroall() // zero all xmm/ymm registers.
mov(var(a), rdx) // load address of a.
//mov(var(a), rdx) // load address of a.
mov(var(rs_a), r8) // load rs_a
//mov(var(cs_a), r9) // load cs_a
lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
@@ -247,11 +233,12 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
//mov(var(c), r12) // load address of c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
@@ -267,12 +254,10 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
mov(var(a), rdx) // load address of a
mov(var(b), r14) // load address of b
mov(var(c), r12) // load address of c
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
lea(mem( , r9, 1), rsi) // rsi = r9 = 3*ii;
imul(rdi, rsi) // rsi *= rs_c
lea(mem(r12, rsi, 1), r12) // r12 = c + 3*ii*rs_c;
@@ -309,19 +294,20 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vxorpd(ymm15, ymm15, ymm15)
#endif
lea(mem(r12), rcx) // rcx = c_iijj;
lea(mem(rdx), rax) // rax = a_ii;
lea(mem(r14), rbx) // rbx = b_jj;
#if 1
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
#endif
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
@@ -338,13 +324,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
// ---------------------------------- iteration 0
#if 0
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
#else
#if 1
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
#endif
@@ -352,7 +332,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -386,7 +366,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -420,7 +400,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -455,7 +435,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -502,7 +482,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -551,7 +531,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
vmovsd(mem(rax, r8, 2), xmm2)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -585,8 +565,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
label(.DPOSTACCUM)
// ymm4 ymm7 ymm10 ymm13
// ymm5 ymm8 ymm11 ymm14
@@ -594,14 +572,15 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vhaddpd( ymm7, ymm4, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
vaddpd( xmm0, xmm1, xmm0 )
vhaddpd( ymm13, ymm10, ymm2 )
vextractf128(imm(1), ymm2, xmm1 )
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
vhaddpd( ymm8, ymm5, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -612,7 +591,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
vhaddpd( ymm9, ymm6, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -623,15 +603,14 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
// ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9)
// ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15)
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
@@ -716,8 +695,8 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
add(imm(3), r9) // ii += 3;
cmp(imm(3), r9) // compare ii to 3
jle(.DLOOP3X4I) // if ii <= 3, jump to beginning
cmp(imm(6), r9) // compare ii to 6
jl(.DLOOP3X4I) // if ii < 6, jump to beginning
// of ii loop; otherwise, loop ends.
@@ -748,7 +727,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -759,7 +738,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
consider_edge_cases:
// Handle edge cases in the m dimension, if they exist.
// Handle edge cases in the n dimension, if they exist.
if ( n_left )
{
const dim_t mr_cur = 6;
@@ -774,6 +753,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
const dim_t nr_cur = 2;
bli_dgemmsup_rd_haswell_asm_6x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -783,24 +763,24 @@ void bli_dgemmsup_rd_haswell_asm_6x8n
}
if ( 1 == n_left )
{
#if 0
#if 1
const dim_t nr_cur = 1;
//bli_dgemmsup_rd_haswell_asm_6x1n
bli_dgemmsup_r_haswell_ref
bli_dgemmsup_rd_haswell_asm_6x1
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
beta, cij, rs_c0, cs_c0, data, cntx
);
#else
#else
bli_dgemv_ex
(
BLIS_NO_TRANSPOSE, conjb, mr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
beta, cij, rs_c0, cntx, NULL
);
#endif
#endif
}
}
}
@@ -865,11 +845,12 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
mov(var(c), r12) // load address of c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
@@ -905,19 +886,20 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vxorpd(ymm15, ymm15, ymm15)
#endif
lea(mem(r12), rcx) // rcx = c_iijj;
lea(mem(rdx), rax) // rax = a_ii;
lea(mem(r14), rbx) // rbx = b_jj;
#if 1
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
#endif
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
@@ -934,13 +916,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
// ---------------------------------- iteration 0
#if 0
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
#else
#if 1
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
#endif
@@ -948,7 +924,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -982,7 +958,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1016,7 +992,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1051,7 +1027,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1098,7 +1074,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
vmovupd(mem(rax, r8, 2), ymm2)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1147,7 +1123,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
vmovsd(mem(rax, r8, 2), xmm2)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1181,8 +1157,6 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
label(.DPOSTACCUM)
// ymm4 ymm7 ymm10 ymm13
// ymm5 ymm8 ymm11 ymm14
@@ -1190,14 +1164,15 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vhaddpd( ymm7, ymm4, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
vaddpd( xmm0, xmm1, xmm0 )
vhaddpd( ymm13, ymm10, ymm2 )
vextractf128(imm(1), ymm2, xmm1 )
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
vhaddpd( ymm8, ymm5, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -1208,7 +1183,8 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
vhaddpd( ymm9, ymm6, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -1219,15 +1195,14 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm6 )
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
// ymm6 = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15)
// ymm6[0] = sum(ymm6); ymm6[1] = sum(ymm9)
// ymm6[2] = sum(ymm12); ymm6[3] = sum(ymm15)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
@@ -1312,6 +1287,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
label(.DRETURN)
@@ -1337,7 +1313,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1348,7 +1324,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
consider_edge_cases:
// Handle edge cases in the m dimension, if they exist.
// Handle edge cases in the n dimension, if they exist.
if ( n_left )
{
const dim_t mr_cur = 3;
@@ -1363,6 +1339,7 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
const dim_t nr_cur = 2;
bli_dgemmsup_rd_haswell_asm_3x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -1372,23 +1349,25 @@ void bli_dgemmsup_rd_haswell_asm_3x8n
}
if ( 1 == n_left )
{
#if 0
#if 1
const dim_t nr_cur = 1;
bli_dgemmsup_r_haswell_ref_3x1
bli_dgemmsup_rd_haswell_asm_3x1
//bli_dgemmsup_r_haswell_ref_3x1
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
beta, cij, rs_c0, cs_c0, data, cntx
);
#else
#else
bli_dgemv_ex
(
BLIS_NO_TRANSPOSE, conjb, mr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
beta, cij, rs_c0, cntx, NULL
);
#endif
#endif
}
}
}
@@ -1453,11 +1432,12 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
mov(var(c), r12) // load address of c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
@@ -1489,18 +1469,19 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vxorpd(ymm14, ymm14, ymm14)
#endif
lea(mem(r12), rcx) // rcx = c_iijj;
lea(mem(rdx), rax) // rax = a_ii;
lea(mem(r14), rbx) // rbx = b_jj;
#if 1
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
#endif
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
@@ -1517,20 +1498,14 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
// ---------------------------------- iteration 0
#if 0
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
#else
#if 1
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
#endif
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1559,7 +1534,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1588,7 +1563,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1618,7 +1593,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1660,7 +1635,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vmovupd(mem(rax ), ymm0)
vmovupd(mem(rax, r8, 1), ymm1)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1704,7 +1679,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1734,23 +1709,21 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
label(.DPOSTACCUM)
// ymm4 ymm7 ymm10 ymm13
// ymm5 ymm8 ymm11 ymm14
// ymm6 ymm9 ymm12 ymm15
vhaddpd( ymm7, ymm4, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
vaddpd( xmm0, xmm1, xmm0 )
vhaddpd( ymm13, ymm10, ymm2 )
vextractf128(imm(1), ymm2, xmm1 )
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
vhaddpd( ymm8, ymm5, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
@@ -1761,14 +1734,14 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm5 )
// ymm5[0] = sum(ymm5); ymm5[1] = sum(ymm8)
// ymm5[2] = sum(ymm11); ymm5[3] = sum(ymm14)
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
// ymm5 = sum(ymm5) sum(ymm8) sum(ymm11) sum(ymm14)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
@@ -1777,7 +1750,6 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
vmulpd(ymm0, ymm4, ymm4) // scale by alpha
vmulpd(ymm0, ymm5, ymm5)
vmulpd(ymm0, ymm6, ymm6)
@@ -1846,6 +1818,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
label(.DRETURN)
@@ -1871,7 +1844,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1882,7 +1855,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
consider_edge_cases:
// Handle edge cases in the m dimension, if they exist.
// Handle edge cases in the n dimension, if they exist.
if ( n_left )
{
const dim_t mr_cur = 2;
@@ -1897,6 +1870,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
const dim_t nr_cur = 2;
bli_dgemmsup_rd_haswell_asm_2x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -1906,23 +1880,24 @@ void bli_dgemmsup_rd_haswell_asm_2x8n
}
if ( 1 == n_left )
{
#if 0
#if 1
const dim_t nr_cur = 1;
bli_dgemmsup_r_haswell_ref_2x1
bli_dgemmsup_rd_haswell_asm_2x1
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
beta, cij, rs_c0, cs_c0, data, cntx
);
#else
#else
bli_dgemv_ex
(
BLIS_NO_TRANSPOSE, conjb, mr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0,
beta, cij, rs_c0, cntx, NULL
);
#endif
#endif
}
}
}
@@ -1987,11 +1962,12 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
mov(var(c), r12) // load address of c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
@@ -2019,18 +1995,18 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
vxorpd(ymm13, ymm13, ymm13)
#endif
lea(mem(r12), rcx) // rcx = c_iijj;
lea(mem(rdx), rax) // rax = a_ii;
lea(mem(r14), rbx) // rbx = b_jj;
#if 1
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
#endif
lea(mem(r11, r11, 2), rdi) // rdi = 3*cs_b
//lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
lea(mem(rbx, r11, 4), r10) // r10 = rbx + 4*cs_b
@@ -2047,19 +2023,13 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
// ---------------------------------- iteration 0
#if 0
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
prefetch(0, mem(r10, r11, 2, 0*8)) // prefetch rbx + 6*cs_b
prefetch(0, mem(r10, r13, 1, 0*8)) // prefetch rbx + 7*cs_b
add(imm(8*8), r10) // r10 += 8*rs_b = 8*8;
#else
#if 1
prefetch(0, mem(r10, 0*8)) // prefetch rbx + 4*cs_b
prefetch(0, mem(r10, r11, 1, 0*8)) // prefetch rbx + 5*cs_b
#endif
vmovupd(mem(rax ), ymm0)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2083,7 +2053,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
#endif
vmovupd(mem(rax ), ymm0)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2107,7 +2077,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
#endif
vmovupd(mem(rax ), ymm0)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2132,7 +2102,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
#endif
vmovupd(mem(rax ), ymm0)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2169,7 +2139,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
label(.DLOOPKITER4) // EDGE LOOP (ymm)
vmovupd(mem(rax ), ymm0)
add(imm(4*8), rax) // a += 4*cs_b = 4*8;
add(imm(4*8), rax) // a += 4*cs_a = 4*8;
vmovupd(mem(rbx ), ymm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2208,7 +2178,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
// which would destroy intermediate results.
vmovsd(mem(rax ), xmm0)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2234,29 +2204,26 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
label(.DPOSTACCUM)
// ymm4 ymm7 ymm10 ymm13
// ymm5 ymm8 ymm11 ymm14
// ymm6 ymm9 ymm12 ymm15
vhaddpd( ymm7, ymm4, ymm0 )
vextractf128(imm(1), ymm0, xmm1 )
vaddpd( xmm0, xmm1, xmm0 ) // xmm0[0] = sum(ymm4); xmm0[1] = sum(ymm7)
vaddpd( xmm0, xmm1, xmm0 )
vhaddpd( ymm13, ymm10, ymm2 )
vextractf128(imm(1), ymm2, xmm1 )
vaddpd( xmm2, xmm1, xmm2 ) // xmm2[0] = sum(ymm10); xmm2[1] = sum(ymm13)
vaddpd( xmm2, xmm1, xmm2 )
vperm2f128(imm(0x20), ymm2, ymm0, ymm4 )
// ymm4[0] = sum(ymm4); ymm4[1] = sum(ymm7)
// ymm4[2] = sum(ymm10); ymm4[3] = sum(ymm13)
// ymm4 = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13)
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
@@ -2325,6 +2292,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
label(.DRETURN)
@@ -2350,7 +2318,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -2361,7 +2329,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
consider_edge_cases:
// Handle edge cases in the m dimension, if they exist.
// Handle edge cases in the n dimension, if they exist.
if ( n_left )
{
const dim_t mr_cur = 1;
@@ -2376,6 +2344,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
const dim_t nr_cur = 2;
bli_dgemmsup_rd_haswell_asm_1x2
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
@@ -2385,23 +2354,24 @@ void bli_dgemmsup_rd_haswell_asm_1x8n
}
if ( 1 == n_left )
{
#if 0
#if 1
const dim_t nr_cur = 1;
bli_dgemmsup_r_haswell_ref_1x1
bli_dgemmsup_rd_haswell_asm_1x1
//bli_dgemmsup_r_haswell_ref
(
conja, conjb, mr_cur, nr_cur, k0,
alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
beta, cij, rs_c0, cs_c0, data, cntx
);
#else
#else
bli_ddotxv_ex
(
conja, conjb, k0,
alpha, ai, cs_a0, bj, rs_b0,
beta, cij, cntx, NULL
);
#endif
#endif
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,158 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2019, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
/*
rrr:
-------- ------ --------
-------- ------ --------
-------- += ------ ... --------
-------- ------ --------
-------- ------ :
-------- ------ :
rcr:
-------- | | | | --------
-------- | | | | --------
-------- += | | | | ... --------
-------- | | | | --------
-------- | | | | :
-------- | | | | :
Assumptions:
- B is row-stored;
- A is row- or column-stored;
- m0 and n0 are at most MR and NR, respectively.
Therefore, this (r)ow-preferential kernel is well-suited for contiguous
(v)ector loads on B and single-element broadcasts from A.
NOTE: These kernels explicitly support column-oriented IO, implemented
via an in-register transpose. And thus they also support the crr and
ccr cases, though only crr is ever utilized (because ccr is handled by
transposing the operation and executing rcr, which does not incur the
cost of the in-register transpose).
crr:
| | | | | | | | ------ --------
| | | | | | | | ------ --------
| | | | | | | | += ------ ... --------
| | | | | | | | ------ --------
| | | | | | | | ------ :
| | | | | | | | ------ :
*/
// Prototype reference microkernels.
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref )
// NOTE: Normally, for any "?x1" kernel, we would call the reference kernel.
// However, at least one other subconfiguration (zen) uses this kernel set, so
// we need to be able to call a set of "?x1" kernels that we know will actually
// exist regardless of which subconfiguration these kernels were used by. Thus,
// the compromise employed here is to inline the reference kernel so it gets
// compiled as part of the haswell kernel set, and hence can unconditionally be
// called by other kernels within that kernel set.
// Redefine the type-generic GENTFUNC macro to emit inlined "?x1" reference
// microkernels, one per value of mdim (the number of rows MR). Each generated
// function computes
//   c := beta * c + alpha * a * b
// where c is mdim x 1, a is mdim x k, and b is k x 1, via one dot product per
// row of c. These are inlined here (rather than calling the generic reference
// kernel) so that other subconfigurations (e.g. zen) that share this kernel
// set can call them unconditionally.
#undef  GENTFUNC
#define GENTFUNC( ctype, ch, opname, mdim ) \
\
void PASTEMAC(ch,opname) \
     ( \
       conj_t              conja, \
       conj_t              conjb, \
       dim_t               m, \
       dim_t               n, \
       dim_t               k, \
       ctype*     restrict alpha, \
       ctype*     restrict a, inc_t rs_a, inc_t cs_a, \
       ctype*     restrict b, inc_t rs_b, inc_t cs_b, \
       ctype*     restrict beta, \
       ctype*     restrict c, inc_t rs_c, inc_t cs_c, \
       auxinfo_t* restrict data, \
       cntx_t*    restrict cntx \
     ) \
{ \
	for ( dim_t i = 0; i < mdim; ++i ) \
	{ \
		ctype* restrict ci = &c[ i*rs_c ]; \
		ctype* restrict ai = &a[ i*rs_a ]; \
\
		/* n is always 1 for these kernels, so the j loop is degenerate. */ \
		/* for ( dim_t j = 0; j < 1; ++j ) */ \
		{ \
			ctype* restrict cij = ci /*[ j*cs_c ]*/ ; \
			ctype* restrict bj  = b  /*[ j*cs_b ]*/ ; \
			ctype           ab; \
\
			PASTEMAC(ch,set0s)( ab ); \
\
			/* Perform a dot product to update the (i,j) element of c. */ \
			for ( dim_t l = 0; l < k; ++l ) \
			{ \
				ctype* restrict aij = &ai[ l*cs_a ]; \
				ctype* restrict bij = &bj[ l*rs_b ]; \
\
				PASTEMAC(ch,dots)( *aij, *bij, ab ); \
			} \
\
			/* If beta is one, add ab into c. If beta is zero, overwrite c
			   with the result in ab. Otherwise, scale by beta and accumulate
			   ab to c. */ \
			if ( PASTEMAC(ch,eq1)( *beta ) ) \
			{ \
				PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \
			} \
			else if ( PASTEMAC(ch,eq0)( *beta ) ) /* was PASTEMAC(d,eq0): use ch for genericity */ \
			{ \
				PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \
			} \
			else \
			{ \
				PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \
			} \
		} \
	} \
}

// Instantiate the "?x1" reference microkernels for double real, one per
// row count from 6 down to 1.
GENTFUNC( double, d, gemmsup_r_haswell_ref_6x1, 6 )
GENTFUNC( double, d, gemmsup_r_haswell_ref_5x1, 5 )
GENTFUNC( double, d, gemmsup_r_haswell_ref_4x1, 4 )
GENTFUNC( double, d, gemmsup_r_haswell_ref_3x1, 3 )
GENTFUNC( double, d, gemmsup_r_haswell_ref_2x1, 2 )
GENTFUNC( double, d, gemmsup_r_haswell_ref_1x1, 1 )

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -219,7 +219,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8
mov(var(c), r12) // load address of c
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c;
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
@@ -487,7 +487,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
vmovsd(mem(rax, r8, 2), xmm2)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -788,6 +788,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8
lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
//lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
mov(var(c), r12) // load address of c
@@ -826,7 +827,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c;
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
@@ -1018,7 +1019,7 @@ void bli_dgemmsup_rd_haswell_asm_2x8
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1274,7 +1275,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
imul(imm(1*8), rsi) // rsi *= cs_c = 1*8
imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
lea(mem(r12, rsi, 1), rcx) // rcx = c + 4*jj*cs_c;
lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
@@ -1438,7 +1439,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8
// which would destroy intermediate results.
vmovsd(mem(rax ), xmm0)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -1908,7 +1909,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
vmovsd(mem(rax, r8, 2), xmm2)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2398,7 +2399,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4
vmovsd(mem(rax ), xmm0)
vmovsd(mem(rax, r8, 1), xmm1)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)
@@ -2783,7 +2784,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4
// which would destroy intermediate results.
vmovsd(mem(rax ), xmm0)
add(imm(1*8), rax) // a += 1*cs_b = 1*8;
add(imm(1*8), rax) // a += 1*cs_a = 1*8;
vmovsd(mem(rbx ), xmm3)
vfmadd231pd(ymm0, ymm3, ymm4)

View File

@@ -257,6 +257,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8
prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c
prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c
prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c
prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c
jmp(.DPOSTPFETCH) // jump to end of prefetching c
label(.DCOLPFETCH) // column-stored prefetching c

View File

@@ -65,6 +65,17 @@ GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_haswell_asm_6x8 )
// -- level-3 sup --------------------------------------------------------------
// -- double real --
// gemmsup_r
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 )
// gemmsup_rv
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8 )
@@ -95,13 +106,6 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x2 )
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x2 )
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x2 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_6x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_5x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_4x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_3x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_2x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref_1x1 )
// gemmsup_rv (mkernel in m dim)
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m )
@@ -133,6 +137,11 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x2 )
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x2 )
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x2 )
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_3x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_2x1 )
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 )
// gemmsup_rd (mkernel in m dim)
GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m )

View File

@@ -418,7 +418,8 @@ void GENBARNAME(cntx_init)
gen_func_init( &funcs[ BLIS_TRSM_L_UKR ], trsm_l_ukr_name );
gen_func_init( &funcs[ BLIS_TRSM_U_UKR ], trsm_u_ukr_name );
bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE );
// s d c z
bli_mbool_init( &mbools[ BLIS_GEMM_UKR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_GEMMTRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE );
bli_mbool_init( &mbools[ BLIS_GEMMTRSM_U_UKR ], FALSE, FALSE, FALSE, FALSE );
bli_mbool_init( &mbools[ BLIS_TRSM_L_UKR ], FALSE, FALSE, FALSE, FALSE );
@@ -427,11 +428,14 @@ void GENBARNAME(cntx_init)
// -- Set level-3 small/unpacked thresholds --------------------------------
// NOTE: The default thresholds are set very low so that the sup framework
// only actives for exceedingly small dimensions. If a sub-configuration
// registers optimized sup kernels, then that sub-configuration should also
// register new (probably larger) thresholds that are almost surely more
// appropriate that these default values.
// NOTE: The default thresholds are set to zero so that the sup framework
// does not activate by default. Note that the semantic meaning of the
// thresholds is that the sup code path is executed if a dimension is
// strictly less than its corresponding threshold. So actually, the
// thresholds specify the minimum dimension size that will still dispatch
// the non-sup/large code path. This "strictly less than" behavior was
// chosen over "less than or equal to" so that threshold values of 0 would
// effectively disable sup (even for matrix dimensions of 0).
// s d c z
bli_blksz_init_easy( &thresh[ BLIS_MT ], 0, 0, 0, 0 );
bli_blksz_init_easy( &thresh[ BLIS_NT ], 0, 0, 0, 0 );
@@ -486,10 +490,10 @@ void GENBARNAME(cntx_init)
gen_func_init( &funcs[ BLIS_RRC ], gemmsup_rv_ukr_name );
gen_func_init( &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name );
gen_func_init( &funcs[ BLIS_RCC ], gemmsup_rv_ukr_name );
gen_func_init( &funcs[ BLIS_CRR ], gemmsup_cv_ukr_name );
gen_func_init( &funcs[ BLIS_CRC ], gemmsup_cv_ukr_name );
gen_func_init( &funcs[ BLIS_CCR ], gemmsup_cv_ukr_name );
gen_func_init( &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name );
gen_func_init( &funcs[ BLIS_CRR ], gemmsup_rv_ukr_name );
gen_func_init( &funcs[ BLIS_CRC ], gemmsup_rv_ukr_name );
gen_func_init( &funcs[ BLIS_CCR ], gemmsup_rv_ukr_name );
gen_func_init( &funcs[ BLIS_CCC ], gemmsup_rv_ukr_name );
// Register the general-stride/generic ukernel to the "catch-all" slot
// associated with the BLIS_XXX enum value. This slot will be queried if
@@ -498,16 +502,17 @@ void GENBARNAME(cntx_init)
// Set the l3 sup ukernel storage preferences.
bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_CRR ], FALSE, FALSE, FALSE, FALSE );
bli_mbool_init( &mbools[ BLIS_CRC ], FALSE, FALSE, FALSE, FALSE );
bli_mbool_init( &mbools[ BLIS_CCR ], FALSE, FALSE, FALSE, FALSE );
bli_mbool_init( &mbools[ BLIS_CCC ], FALSE, FALSE, FALSE, FALSE );
// s d c z
bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_CRR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_CRC ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_CCR ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_CCC ], TRUE, TRUE, TRUE, TRUE );
bli_mbool_init( &mbools[ BLIS_XXX ], FALSE, FALSE, FALSE, FALSE );
bli_mbool_init( &mbools[ BLIS_XXX ], TRUE, TRUE, TRUE, TRUE );
// -- Set level-1f kernels -------------------------------------------------