From 4b1663213cafbfc7f975926c8fce9df8d61a5a59 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Thu, 14 Jul 2022 17:55:34 -0500 Subject: [PATCH] Fixed out-of-bounds read in haswell gemmsup kernels. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Details: - Fixed memory access bugs in the bli_sgemmsup_rv_haswell_asm_Mx2() kernels, where M = {1,2,3,4,5,6}. The bugs were caused by loading four single-precision elements of C, via instructions such as: vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) in situations where only two elements are guaranteed to exist. (These bugs may not have manifested in earlier tests due to the leading dimension alignment that BLIS employs by default.) The issue was fixed by replacing lines like the one above with: vmovsd(mem(rcx), xmm0) vfmadd231ps(xmm0, xmm3, xmm4) Thus, we use vmovsd to explicitly load only two elements of C into registers, and then operate on those values using register addressing. Thanks to Daniël de Kok for reporting these bugs in #635, and to Bhaskar Nallani for proposing the fix. - CREDITS file update.
Change-Id: Ib525c36bcbf20b2bbbe380da3d74d142b338fe9b --- CREDITS | 1 + .../s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c | 141 ++++++++++-------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/CREDITS b/CREDITS index c6d5d7151..d68bcca01 100644 --- a/CREDITS +++ b/CREDITS @@ -23,6 +23,7 @@ but many others have contributed code and feedback, including Dilyn Corner @dilyn-corner Mat Cross @matcross (NAG) @decandia50 + Daniël de Kok @danieldk (Explosion) Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany) Jeff Diamond (Oracle) Johannes Dieterich @iotamudelta diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c index 6090f8b0b..3cbb69a50 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -387,34 +387,39 @@ void bli_sgemmsup_rv_haswell_asm_6x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -846,29 +851,33 @@ void bli_sgemmsup_rv_haswell_asm_5x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, 
xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1286,24 +1295,27 @@ void bli_sgemmsup_rv_haswell_asm_4x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1681,19 +1693,21 @@ void bli_sgemmsup_rv_haswell_asm_3x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2064,14 +2078,15 @@ void bli_sgemmsup_rv_haswell_asm_2x2 
label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2402,9 +2417,9 @@ void bli_sgemmsup_rv_haswell_asm_1x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) //add(rdi, rcx)