Fixed functionality failure of DGEMM pack kernel. (#31)

* Fixed functionality failure of DGEMM pack kernel.

- Corrected the mask preparation needed for load/store
in the edge kernel where m = 18.

- Corrected the usage of the correct vector registers while
storing data back to the buffer in edge kernels.

AMD-Internal: [CPUPL-6773]

* Fixed functionality failure of DGEMM pack kernel.

- Corrected the mask preparation needed for load/store
in the edge kernel where m = 18.

- Corrected the usage of the correct vector registers while
storing data back to the buffer in edge kernels.

AMD-Internal: [CPUPL-6773]

* Update bli_packm_zen4_asm_d24xk.c

---------

Co-authored-by: Harsh Dave <harsdave@amd.com>
This commit is contained in:
Dave, Harsh
2025-06-03 17:33:16 +05:30
committed by GitHub
parent dcf72968cf
commit 3c8b7895f7

View File

@@ -460,14 +460,14 @@ void bli_dpackm_zen4_asm_24xk
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
const uint64_t k_iter = k0 / 8;
const uint64_t k_iter = (uint64_t)(k0 / 8);
/**
* prepares mask for k_left, since we are computing in multiple of 8,
* for edge cases mask is initialized for loading and storing only
* left over elements.
*/
const uint64_t k_left = k0 % 8;
const uint64_t k_left = (uint64_t)(k0 % 8);
uint8_t mask = 0xff >> (0x8 - (k_left & 7));
if (mask == 0) mask = 0xff;
@@ -479,9 +479,9 @@ void bli_dpackm_zen4_asm_24xk
// where elements of each column of the packed matrix P are contiguous.
// (This packm kernel can still be used to pack micropanels of matrix B
// in a gemm operation.)
const uint64_t inca = inca0;
const uint64_t lda = lda0;
const uint64_t ldp = ldp0;
const uint64_t inca = (uint64_t)inca0;
const uint64_t lda = (uint64_t)lda0;
const uint64_t ldp = (uint64_t)ldp0;
const bool gs = ( inca0 != 1 && lda0 != 1 );
@@ -1583,7 +1583,7 @@ void bli_dpackm_zen4_asm_24xk
LABEL(.UPDATEDONEL2)
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmulpd(zmm8, zmm17, zmm8)
@@ -1624,58 +1624,58 @@ void bli_dpackm_zen4_asm_24xk
LABEL(.UPDATE7L3)
//Update 8x7 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
vmovupd(xmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
vmovupd(zmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE6L3)
//Update 8x6 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE5L3)
//Update 8x5 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE4L3)
//Update 8x4 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE3L3)
//Update 8x3 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE2L3)
//Update 8x2 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE1L3)
//Update 8x1 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATEDONEL3)
@@ -2118,7 +2118,7 @@ void bli_dpackm_zen4_asm_24xk
LABEL(.UPDATEDONEL2)
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmulpd(zmm8, zmm17, zmm8)
@@ -2158,58 +2158,58 @@ void bli_dpackm_zen4_asm_24xk
LABEL(.UPDATE7L3)
//Update 8x7 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
vmovupd(xmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
vmovupd(zmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE6L3)
//Update 8x6 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE5L3)
//Update 8x5 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE4L3)
//Update 8x4 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE3L3)
//Update 8x3 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE2L3)
//Update 8x2 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATE1L3)
//Update 8x1 tile to destination buffer
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
jmp(.UPDATEDONEL3)
LABEL(.UPDATEDONEL3)
@@ -3715,7 +3715,7 @@ void bli_dpackm_zen4_asm_24xk
LABEL(.UPDATEDONEL2)
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmulpd(zmm8, zmm17, zmm8)
@@ -3847,9 +3847,12 @@ void bli_dpackm_zen4_asm_24xk
}
case 18:
{
uint8_t mmask = ((1 << 2) - 1);
begin_asm()
mov(var(mask), rdx) // rdx <- mask value
kmovw(edx, k(2)) // k(2) <- mask
mov(var(mmask), rdx) // rdx <- m mask value
kmovw(edx, k(3)) // k(3) <- mmask
mov(var(kappa), r10) // r10 <- kappa
vbroadcastsd(mem(r10), zmm17) // zmm17 <- [kappa, kappa, ..., kappa]
@@ -4237,9 +4240,9 @@ void bli_dpackm_zen4_asm_24xk
LABEL(.UPDATEDONEL2)
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
vmovupd(mem(rax, r8, 1, 0), xmm8 MASK_KZ(2))
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmulpd(zmm8, zmm17, zmm8)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
@@ -4344,6 +4347,7 @@ void bli_dpackm_zen4_asm_24xk
: // output operands (none)
: // input operands
[mask] "m" (mask),
[mmask] "m" (mmask),
[k_iter] "m" (k_iter),
[k_left] "m" (k_left),
[a] "m" (a),