Obliterated usage of rbp register in SUP gemm kernel

- Mx4 edge kernels were overwriting rbp
registers for prefetches.
- Since rbp along with rsp defines stack frame,
it resulted in stack overflow issue.
- Replaced rbp with rdx register for prefetches.

AMD-Internal: [CPUPL-2987]
Change-Id: I4e52cf691b70be5ab63f562d7630d640b29e1cfd
This commit is contained in:
Harsh Dave
2023-02-06 11:49:45 -06:00
parent 299bed3fa8
commit 222e00e840

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2019, Advanced Micro Devices, Inc.
Copyright (C) 2019-23, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -115,9 +115,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4
// -------------------------------------------------------------------------
begin_asm()
vzeroall() // zero all xmm/ymm registers.
mov(var(a), rax) // load address of a.
mov(var(rs_a), r8) // load rs_a
mov(var(cs_a), r9) // load cs_a
@@ -132,7 +132,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4
//mov(var(cs_b), r11) // load cs_b
lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
//lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
// NOTE: We cannot pre-load elements of a or b
// because it could eventually, in the last
// unrolled iter or the cleanup loop, result
@@ -163,38 +163,38 @@ void bli_dgemmsup_rv_haswell_asm_6x4
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)
lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c;
lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c;
prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c
prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c
prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rcx, rdx, 1, 5*8)) // prefetch c + 3*cs_c
label(.DPOSTPFETCH) // done prefetching c
#if 1
lea(mem(rax, r9, 8), rdx) //
lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a;
#endif
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.DCONSIDKLEFT) // if i == 0, jump to code that
// contains the k_left loop.
label(.DLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
#if 1
prefetch(0, mem(rdx, 5*8))
#endif
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -202,19 +202,19 @@ void bli_dgemmsup_rv_haswell_asm_6x4
vbroadcastsd(mem(rax, r8, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm0, ymm3, ymm6)
vbroadcastsd(mem(rax, r8, 2), ymm2)
vbroadcastsd(mem(rax, r13, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm0, ymm3, ymm10)
vbroadcastsd(mem(rax, r8, 4), ymm2)
vbroadcastsd(mem(rax, r15, 1), ymm3)
add(r9, rax) // a += cs_a;
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm0, ymm3, ymm14)
// ---------------------------------- iteration 1
#if 0
@@ -228,25 +228,25 @@ void bli_dgemmsup_rv_haswell_asm_6x4
vbroadcastsd(mem(rax, r8, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm0, ymm3, ymm6)
vbroadcastsd(mem(rax, r8, 2), ymm2)
vbroadcastsd(mem(rax, r13, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm0, ymm3, ymm10)
vbroadcastsd(mem(rax, r8, 4), ymm2)
vbroadcastsd(mem(rax, r15, 1), ymm3)
add(r9, rax) // a += cs_a;
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm0, ymm3, ymm14)
// ---------------------------------- iteration 2
#if 1
prefetch(0, mem(rdx, r9, 2, 5*8))
#endif
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -254,18 +254,18 @@ void bli_dgemmsup_rv_haswell_asm_6x4
vbroadcastsd(mem(rax, r8, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm0, ymm3, ymm6)
vbroadcastsd(mem(rax, r8, 2), ymm2)
vbroadcastsd(mem(rax, r13, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm0, ymm3, ymm10)
vbroadcastsd(mem(rax, r8, 4), ymm2)
vbroadcastsd(mem(rax, r15, 1), ymm3)
add(r9, rax) // a += cs_a;
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm0, ymm3, ymm14)
// ---------------------------------- iteration 3
@@ -280,43 +280,43 @@ void bli_dgemmsup_rv_haswell_asm_6x4
vbroadcastsd(mem(rax, r8, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm0, ymm3, ymm6)
vbroadcastsd(mem(rax, r8, 2), ymm2)
vbroadcastsd(mem(rax, r13, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm0, ymm3, ymm10)
vbroadcastsd(mem(rax, r8, 4), ymm2)
vbroadcastsd(mem(rax, r15, 1), ymm3)
add(r9, rax) // a += cs_a;
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm0, ymm3, ymm14)
dec(rsi) // i -= 1;
jne(.DLOOPKITER) // iterate again if i != 0.
label(.DCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.DLOOPKLEFT) // EDGE LOOP
#if 0
prefetch(0, mem(rdx, 5*8))
add(r9, rdx)
#endif
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -324,57 +324,57 @@ void bli_dgemmsup_rv_haswell_asm_6x4
vbroadcastsd(mem(rax, r8, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm0, ymm3, ymm6)
vbroadcastsd(mem(rax, r8, 2), ymm2)
vbroadcastsd(mem(rax, r13, 1), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm0, ymm3, ymm10)
vbroadcastsd(mem(rax, r8, 4), ymm2)
vbroadcastsd(mem(rax, r15, 1), ymm3)
add(r9, rax) // a += cs_a;
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm0, ymm3, ymm14)
dec(rsi) // i -= 1;
jne(.DLOOPKLEFT) // iterate again if i != 0.
label(.DPOSTACCUM)
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate
vmulpd(ymm0, ymm4, ymm4) // scale by alpha
vmulpd(ymm0, ymm6, ymm6)
vmulpd(ymm0, ymm8, ymm8)
vmulpd(ymm0, ymm10, ymm10)
vmulpd(ymm0, ymm12, ymm12)
vmulpd(ymm0, ymm14, ymm14)
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double)
//lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c;
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vucomisd(xmm0, xmm3) // set ZF if beta == 0.
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case
@@ -383,42 +383,42 @@ void bli_dgemmsup_rv_haswell_asm_6x4
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
jz(.DCOLSTORED) // jump to column storage case
label(.DROWSTORED)
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4)
vmovupd(ymm4, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6)
vmovupd(ymm6, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8)
vmovupd(ymm8, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10)
vmovupd(ymm10, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12)
vmovupd(ymm12, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14)
vmovupd(ymm14, mem(rcx, 0*32))
//add(rdi, rcx)
jmp(.DDONE) // jump to end.
@@ -466,45 +466,45 @@ void bli_dgemmsup_rv_haswell_asm_6x4
jmp(.DDONE) // jump to end.
label(.DBETAZERO)
cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
jz(.DCOLSTORBZ) // jump to column storage case
label(.DROWSTORBZ)
vmovupd(ymm4, mem(rcx, 0*32))
add(rdi, rcx)
vmovupd(ymm6, mem(rcx, 0*32))
add(rdi, rcx)
vmovupd(ymm8, mem(rcx, 0*32))
add(rdi, rcx)
vmovupd(ymm10, mem(rcx, 0*32))
add(rdi, rcx)
vmovupd(ymm12, mem(rcx, 0*32))
add(rdi, rcx)
vmovupd(ymm14, mem(rcx, 0*32))
//add(rdi, rcx)
jmp(.DDONE) // jump to end.
@@ -539,13 +539,13 @@ void bli_dgemmsup_rv_haswell_asm_6x4
vmovupd(xmm4, mem(rdx, rax, 1))
//lea(mem(rdx, rsi, 4), rdx)
label(.DDONE)
end_asm(
: // output operands (none)
@@ -566,7 +566,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -657,11 +657,11 @@ void bli_dgemmsup_rv_haswell_asm_5x4
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)
lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c;
lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c;
prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c
prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c
prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rcx, rdx, 1, 4*8)) // prefetch c + 3*cs_c
label(.DPOSTPFETCH) // done prefetching c
@@ -1037,7 +1037,7 @@ void bli_dgemmsup_rv_haswell_asm_5x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1127,11 +1127,11 @@ void bli_dgemmsup_rv_haswell_asm_4x4
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)
lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c;
lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c;
prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c
prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c
prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rcx, rdx, 1, 3*8)) // prefetch c + 3*cs_c
label(.DPOSTPFETCH) // done prefetching c
@@ -1457,7 +1457,7 @@ void bli_dgemmsup_rv_haswell_asm_4x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1546,11 +1546,11 @@ void bli_dgemmsup_rv_haswell_asm_3x4
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)
lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c;
lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c;
prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c
prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c
prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rcx, rdx, 1, 2*8)) // prefetch c + 3*cs_c
label(.DPOSTPFETCH) // done prefetching c
@@ -1884,7 +1884,7 @@ void bli_dgemmsup_rv_haswell_asm_3x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1972,11 +1972,11 @@ void bli_dgemmsup_rv_haswell_asm_2x4
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)
lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c;
lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c;
prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c
prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c
prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rcx, rdx, 1, 1*8)) // prefetch c + 3*cs_c
label(.DPOSTPFETCH) // done prefetching c
@@ -2247,7 +2247,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -2334,11 +2334,11 @@ void bli_dgemmsup_rv_haswell_asm_1x4
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)
lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c;
lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c;
prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c
prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c
prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c
prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c
prefetch(0, mem(rcx, rdx, 1, 0*8)) // prefetch c + 3*cs_c
label(.DPOSTPFETCH) // done prefetching c
@@ -2588,7 +2588,7 @@ void bli_dgemmsup_rv_haswell_asm_1x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",