Fixed a few out-of-bounds memory reads in sgemmsup kernels

Details:
 Fixed memory access bugs in the bli_sgemmsup_rd_zen_asm_1x16()
  kernel (and related sgemmsup kernels). The bugs were caused by
  loading four single-precision elements of C via instructions such as:

	vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4)

        or

        vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4)

  in situations where only two elements are guaranteed to exist. (These
  bugs may not have manifested in earlier tests due to the leading
  dimension alignment that BLIS employs by default.) The issue was fixed
  by replacing lines like the one above with:

	vmovsd(mem(rcx), xmm0)
	vfmadd231ps(xmm0, xmm3, xmm4)

  Thus, we use vmovsd to explicitly load only two elements of C into
  registers, and then operate on those values using register addressing.

  AMD-Internal: [CPUPL-2279]

Change-Id: Ic39290d651f5218b2e548351a87ac5e4b5b79c68
This commit is contained in:
Nallani Bhaskar
2022-07-29 07:31:58 +05:30
parent fde812015f
commit 1d31386c02
3 changed files with 29 additions and 16 deletions

View File

@@ -3,7 +3,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
@@ -516,7 +516,8 @@ void bli_sgemmsup_rd_zen_asm_1x16
je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
label(.SROWSTORED)
vfmadd231ps(mem(rcx), ymm3, ymm4)
vmovups(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovups(xmm4, mem(rcx))
jmp(.SDONE) // jump to end.

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020-2022, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -8048,15 +8048,18 @@ void bli_sgemmsup_rv_zen_asm_3x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)////a0a1
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)////a0a1
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)////a0a1
vfmadd231ps(xmm0, xmm3, xmm8)
vmovsd(xmm8, mem(rcx))
jmp(.SDONE) // jump to end.
@@ -8329,11 +8332,13 @@ void bli_sgemmsup_rv_zen_asm_2x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)////a0a1
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)////a0a1
vfmadd231ps(xmm0, xmm3, xmm6)
vmovsd(xmm6, mem(rcx))
jmp(.SDONE) // jump to end.
@@ -8577,7 +8582,8 @@ void bli_sgemmsup_rv_zen_asm_1x2
label(.SROWSTORED)
vfmadd231ps(mem(rcx), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovsd(xmm4, mem(rcx))
jmp(.SDONE) // jump to end.

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020-2022, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -2231,22 +2231,28 @@ void bli_sgemmsup_rv_zen_asm_6x2m
label(.SROWSTORED)
vfmadd231ps(mem(rcx), xmm3, xmm4)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm4)
vmovlpd(xmm4, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm6)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm6)
vmovlpd(xmm6, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm8)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm8)
vmovlpd(xmm8, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm10)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm10)
vmovlpd(xmm10, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm12)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm12)
vmovlpd(xmm12, mem(rcx))
add(rdi, rcx)
vfmadd231ps(mem(rcx), xmm3, xmm14)
vmovsd(mem(rcx), xmm0)
vfmadd231ps(xmm0, xmm3, xmm14)
vmovlpd(xmm14, mem(rcx))
jmp(.SDONE) // jump to end.