Fixed out of bound memory access in sgemmsup zen rv kernels

Details:

1. In the sgemmsup_zen_rv_?x2 kernels the "vmovups" instruction
   is used to load the B matrix in the k loop and the k-left
   (edge) loop, loading 128 bits into xmm instead of the
   expected 64 bits.

2. Changed the vmovups instruction to the vmovsd instruction,
   which loads only 64 bits into the xmm register.

3. Avoided direct C memory access by the vfma instruction when
   scaling with beta in the corner cases, since the memory
   operand required a 128-bit access and led to out-of-bounds
   reads. Replaced it with a vmovq to first load only 64 bits
   of data, then performed the vfma on the xmm register, in
   rv_6x8m and rv_6x4m.

   AMD-Internal: [CPUPL-2472]

Change-Id: Iad397f8f5b5cc607b4278b603b1e0ea3f6b082f2
This commit is contained in:
Nallani Bhaskar
2022-08-29 23:06:28 +05:30
parent 40c71dd2e1
commit ea79efa915
2 changed files with 50 additions and 40 deletions

View File

@@ -6853,7 +6853,7 @@ void bli_sgemmsup_rv_zen_asm_6x2
// ---------------------------------- iteration 0
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -6874,7 +6874,7 @@ void bli_sgemmsup_rv_zen_asm_6x2
// ---------------------------------- iteration 1
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -6895,7 +6895,7 @@ void bli_sgemmsup_rv_zen_asm_6x2
// ---------------------------------- iteration 2
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -6916,7 +6916,7 @@ void bli_sgemmsup_rv_zen_asm_6x2
// ---------------------------------- iteration 3
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -6949,7 +6949,7 @@ void bli_sgemmsup_rv_zen_asm_6x2
label(.SLOOPKLEFT) // EDGE LOOP
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -7237,7 +7237,7 @@ void bli_sgemmsup_rv_zen_asm_5x2
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -7256,7 +7256,7 @@ void bli_sgemmsup_rv_zen_asm_5x2
// ---------------------------------- iteration 1
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -7275,7 +7275,7 @@ void bli_sgemmsup_rv_zen_asm_5x2
// ---------------------------------- iteration 2
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -7294,7 +7294,7 @@ void bli_sgemmsup_rv_zen_asm_5x2
// ---------------------------------- iteration 3
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -7325,7 +7325,7 @@ void bli_sgemmsup_rv_zen_asm_5x2
label(.SLOOPKLEFT) // EDGE LOOP
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)
@@ -7612,7 +7612,7 @@ void bli_sgemmsup_rv_zen_asm_4x2
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7627,7 +7627,7 @@ void bli_sgemmsup_rv_zen_asm_4x2
vfmadd231ps(xmm0, xmm3, xmm10)
// ---------------------------------- iteration 1
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7642,7 +7642,7 @@ void bli_sgemmsup_rv_zen_asm_4x2
vfmadd231ps(xmm0, xmm3, xmm10)
// ---------------------------------- iteration 2
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7657,7 +7657,7 @@ void bli_sgemmsup_rv_zen_asm_4x2
vfmadd231ps(xmm0, xmm3, xmm10)
// ---------------------------------- iteration 3
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7685,7 +7685,7 @@ void bli_sgemmsup_rv_zen_asm_4x2
label(.SLOOPKLEFT) // EDGE LOOP
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7940,7 +7940,7 @@ void bli_sgemmsup_rv_zen_asm_3x2
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7953,7 +7953,7 @@ void bli_sgemmsup_rv_zen_asm_3x2
vfmadd231ps(xmm0, xmm2, xmm8)
// ---------------------------------- iteration 1
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7966,7 +7966,7 @@ void bli_sgemmsup_rv_zen_asm_3x2
vfmadd231ps(xmm0, xmm2, xmm8)
// ---------------------------------- iteration 2
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -7979,7 +7979,7 @@ void bli_sgemmsup_rv_zen_asm_3x2
vfmadd231ps(xmm0, xmm2, xmm8)
// ---------------------------------- iteration 3
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8005,7 +8005,7 @@ void bli_sgemmsup_rv_zen_asm_3x2
label(.SLOOPKLEFT) // EDGE LOOP
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8245,7 +8245,7 @@ void bli_sgemmsup_rv_zen_asm_2x2
// ---------------------------------- iteration 0
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
vbroadcastss(mem(rax, r8, 1), ymm3)
@@ -8254,7 +8254,7 @@ void bli_sgemmsup_rv_zen_asm_2x2
vfmadd231ps(xmm0, xmm3, xmm6)
// ---------------------------------- iteration 1
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
vbroadcastss(mem(rax, r8, 1), ymm3)
@@ -8263,7 +8263,7 @@ void bli_sgemmsup_rv_zen_asm_2x2
vfmadd231ps(xmm0, xmm3, xmm6)
// ---------------------------------- iteration 2
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8273,7 +8273,7 @@ void bli_sgemmsup_rv_zen_asm_2x2
vfmadd231ps(xmm0, xmm3, xmm6)
// ---------------------------------- iteration 3
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8294,7 +8294,7 @@ void bli_sgemmsup_rv_zen_asm_2x2
label(.SLOOPKLEFT) // EDGE LOOP
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8503,7 +8503,7 @@ void bli_sgemmsup_rv_zen_asm_1x2
label(.SLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8511,7 +8511,7 @@ void bli_sgemmsup_rv_zen_asm_1x2
vfmadd231ps(xmm0, xmm2, xmm4)
// ---------------------------------- iteration 1
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8520,7 +8520,7 @@ void bli_sgemmsup_rv_zen_asm_1x2
// ---------------------------------- iteration 2
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8528,7 +8528,7 @@ void bli_sgemmsup_rv_zen_asm_1x2
vfmadd231ps(xmm0, xmm2, xmm4)
// ---------------------------------- iteration 3
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), ymm2)
@@ -8548,7 +8548,7 @@ void bli_sgemmsup_rv_zen_asm_1x2
label(.SLOOPKLEFT) // EDGE LOOP
vmovups(mem(rbx, 0*32), xmm0)
vmovsd(mem(rbx, 0*32), xmm0)
add(r10, rbx) // b += rs_b;
vbroadcastss(mem(rax ), xmm2)

View File

@@ -1291,13 +1291,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m
vextractf128(imm(0x1), ymm0, xmm2)
vpermilps(imm(0xe),xmm0,xmm5)
vpermilps(imm(0xe),xmm2,xmm6)
vfmadd231ps(mem(rdx), xmm3, xmm0)
vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2)
vmovq(mem(rdx),xmm4)
vmovq(mem(rdx, rsi, 4),xmm1)
vfmadd231ps(xmm4, xmm3, xmm0)
vfmadd231ps(xmm1, xmm3, xmm2)
vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 )
vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 )
lea(mem(rdx, rsi, 1), rdx)
vfmadd231ps(mem(rdx), xmm3, xmm5)
vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6)
vmovq(mem(rdx),xmm4)
vmovq(mem(rdx, rsi, 4),xmm1)
vfmadd231ps(xmm4, xmm3, xmm5)
vfmadd231ps(xmm1, xmm3, xmm6)
vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 )
vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 )
lea(mem(rdx, rsi, 1), rdx)
@@ -1306,13 +1310,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m
vextractf128(imm(0x1), ymm0, xmm2)
vpermilps(imm(0xe),xmm0,xmm5)
vpermilps(imm(0xe),xmm2,xmm6)
vfmadd231ps(mem(rdx), xmm3, xmm0)
vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2)
vmovq(mem(rdx),xmm4)
vmovq(mem(rdx, rsi, 4),xmm1)
vfmadd231ps(xmm4, xmm3, xmm0)
vfmadd231ps(xmm1, xmm3, xmm2)
vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 )
vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 )
lea(mem(rdx, rsi, 1), rdx)
vfmadd231ps(mem(rdx), xmm3, xmm5)
vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6)
vmovq(mem(rdx),xmm4)
vmovq(mem(rdx, rsi, 4),xmm1)
vfmadd231ps(xmm4, xmm3, xmm5)
vfmadd231ps(xmm1, xmm3, xmm6)
vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 )
vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 )
@@ -1810,11 +1818,13 @@ void bli_sgemmsup_rv_zen_asm_6x4m
lea(mem(rdx, rsi, 1), rdx)
vunpckhps(xmm14, xmm12, xmm0)
vpermilps(imm(0x4e), xmm0, xmm5)
vfmadd231ps(mem(rdx), xmm3, xmm0)
vmovq(mem(rdx),xmm4)
vfmadd231ps(xmm4, xmm3, xmm0)
vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 )
lea(mem(rdx, rsi, 1), rdx)
vfmadd231ps(mem(rdx), xmm3, xmm5)
vmovq(mem(rdx),xmm4)
vfmadd231ps(xmm4, xmm3, xmm5)
vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 )
jmp(.SDONE) // jump to end.