From ea79efa915af7c2ac1c6c2f3b4a86ec83b6a446d Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Mon, 29 Aug 2022 23:06:28 +0530 Subject: [PATCH] Fixed out of bound memory access in sgemmsup zen rv kernels Details: 1. In sgemmsup_zen_rv_?x2 kernels "vmovps" instruction is used to load B matrix in k loop and k last loop, which is loading 128 bit into xmm rather than 64 bit as expected. 2. Changed vmovps instruction to vmovsd instructions which load only 64 bit in xmm register 3. Avoided C memory access by vfma instruction when multiplying with non-beta at corner cases with required access to 128 bit which leads to out of bound. Replaced with vmovq first to get 64 bit data then performed vfma on xmm register in rv_6x8m and rv_6x4m AMD-Internal: [CPUPL-2472] Change-Id: Iad397f8f5b5cc607b4278b603b1e0ea3f6b082f2 --- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c | 60 +++++++++---------- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c | 30 ++++++---- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 507ff5a71..7befbb69b 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -6853,7 +6853,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6874,7 +6874,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6895,7 +6895,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6916,7 +6916,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // 
---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6949,7 +6949,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7237,7 +7237,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7256,7 +7256,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7275,7 +7275,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7294,7 +7294,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7325,7 +7325,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7612,7 +7612,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7627,7 +7627,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += 
rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7642,7 +7642,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7657,7 +7657,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7685,7 +7685,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7940,7 +7940,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7953,7 +7953,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7966,7 +7966,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7979,7 +7979,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8005,7 +8005,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ 
-8245,7 +8245,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) vbroadcastss(mem(rax, r8, 1), ymm3) @@ -8254,7 +8254,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) vbroadcastss(mem(rax, r8, 1), ymm3) @@ -8263,7 +8263,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8273,7 +8273,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8294,7 +8294,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8503,7 +8503,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8511,7 +8511,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 vfmadd231ps(xmm0, xmm2, xmm4) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8520,7 +8520,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; 
vbroadcastss(mem(rax ), ymm2) @@ -8528,7 +8528,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 vfmadd231ps(xmm0, xmm2, xmm4) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8548,7 +8548,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c index e6ecd47f4..d5e2135a6 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -1291,13 +1291,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vpermilps(imm(0xe),xmm0,xmm5) vpermilps(imm(0xe),xmm2,xmm6) - vfmadd231ps(mem(rdx), xmm3, xmm0) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) lea(mem(rdx, rsi, 1), rdx) @@ -1306,13 +1310,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vpermilps(imm(0xe),xmm0,xmm5) vpermilps(imm(0xe),xmm2,xmm6) - vfmadd231ps(mem(rdx), xmm3, xmm0) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) 
vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) @@ -1810,11 +1818,13 @@ void bli_sgemmsup_rv_zen_asm_6x4m lea(mem(rdx, rsi, 1), rdx) vunpckhps(xmm14, xmm12, xmm0) vpermilps(imm(0x4e), xmm0, xmm5) - vfmadd231ps(mem(rdx), xmm3, xmm0) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) jmp(.SDONE) // jump to end.