Fixed out-of-bounds read in haswell gemmsup kernels.

Details:
- Fixed memory access bugs in the bli_sgemmsup_rv_haswell_asm_Mx2()
  kernels, where M = {1,2,3,4,5,6}. The bugs were caused by loading four
  single-precision elements of C, via instructions such as:

	vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)

  in situations where only two elements are guaranteed to exist. (These
  bugs may not have manifested in earlier tests due to the leading
  dimension alignment that BLIS employs by default.) The issue was fixed
  by replacing lines like the one above with:

	vmovsd(mem(rcx), xmm0)
	vfmadd231ps(xmm0, xmm3, xmm4)

  Thus, we use vmovsd (a 64-bit load, i.e., two single-precision values)
  to explicitly load only two elements of C into a register, and then
  operate on those values using register addressing.
  Thanks to Daniël de Kok for reporting these bugs in #635, and to
  Bhaskar Nallani for proposing the fix.
- CREDITS file update.

Change-Id: Ib525c36bcbf20b2bbbe380da3d74d142b338fe9b
This commit is contained in:
Field G. Van Zee
2022-07-14 17:55:34 -05:00
committed by Nallani Bhaskar
parent 1d31386c02
commit 4b1663213c
2 changed files with 79 additions and 63 deletions

View File

@@ -23,6 +23,7 @@ but many others have contributed code and feedback, including
Dilyn Corner @dilyn-corner
Mat Cross @matcross (NAG)
@decandia50
Daniël de Kok @danieldk (Explosion)
Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany)
Jeff Diamond (Oracle)
Johannes Dieterich @iotamudelta

View File

@@ -387,34 +387,39 @@ void bli_sgemmsup_rv_haswell_asm_6x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovsd(xmm10, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm12)
vmovsd(xmm12, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm14)
vmovsd(xmm14, mem(rcx, 0*32))
//add(rdi, rcx)
@@ -846,29 +851,33 @@ void bli_sgemmsup_rv_haswell_asm_5x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovsd(xmm10, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm12)
vmovsd(xmm12, mem(rcx, 0*32))
//add(rdi, rcx)
@@ -1286,24 +1295,27 @@ void bli_sgemmsup_rv_haswell_asm_4x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovsd(xmm10, mem(rcx, 0*32))
//add(rdi, rcx)
@@ -1681,19 +1693,21 @@ void bli_sgemmsup_rv_haswell_asm_3x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx, 0*32))
//add(rdi, rcx)
@@ -2064,14 +2078,15 @@ void bli_sgemmsup_rv_haswell_asm_2x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
add(rdi, rcx)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx, 0*32))
//add(rdi, rcx)
@@ -2402,9 +2417,9 @@ void bli_sgemmsup_rv_haswell_asm_1x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx, 0*32))
//add(rdi, rcx)