mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Fixed functionality failure of DGEMM pack kernel. (#31)
* Fixed functionality failure of DGEMM pack kernel.
  - Corrected the mask preparation needed for load/store in the edge kernel where m = 18.
  - Corrected the vector registers used when storing data back to the buffer in the edge kernels.
  AMD-Internal: [CPUPL-6773]
* Fixed functionality failure of DGEMM pack kernel.
  - Corrected the mask preparation needed for load/store in the edge kernel where m = 18.
  - Corrected the vector registers used when storing data back to the buffer in the edge kernels.
  AMD-Internal: [CPUPL-6773]
* Update bli_packm_zen4_asm_d24xk.c
---------
Co-authored-by: Harsh Dave <harsdave@amd.com>
This commit is contained in:
@@ -460,14 +460,14 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
const uint64_t k_iter = k0 / 8;
|
||||
const uint64_t k_iter = (uint64_t)(k0 / 8);
|
||||
|
||||
/**
|
||||
* prepares mask for k_left, since we are computing in multiple of 8,
|
||||
* for edge cases mask is initialized for loading and storing only
|
||||
* left over elements.
|
||||
*/
|
||||
const uint64_t k_left = k0 % 8;
|
||||
const uint64_t k_left = (uint64_t)(k0 % 8);
|
||||
uint8_t mask = 0xff >> (0x8 - (k_left & 7));
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
@@ -479,9 +479,9 @@ void bli_dpackm_zen4_asm_24xk
|
||||
// where elements of each column of the packed matrix P are contiguous.
|
||||
// (This packm kernel can still be used to pack micropanels of matrix B
|
||||
// in a gemm operation.)
|
||||
const uint64_t inca = inca0;
|
||||
const uint64_t lda = lda0;
|
||||
const uint64_t ldp = ldp0;
|
||||
const uint64_t inca = (uint64_t)inca0;
|
||||
const uint64_t lda = (uint64_t)lda0;
|
||||
const uint64_t ldp = (uint64_t)ldp0;
|
||||
|
||||
const bool gs = ( inca0 != 1 && lda0 != 1 );
|
||||
|
||||
@@ -1583,7 +1583,7 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
LABEL(.UPDATEDONEL2)
|
||||
|
||||
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
|
||||
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
|
||||
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
|
||||
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
|
||||
vmulpd(zmm8, zmm17, zmm8)
|
||||
@@ -1624,58 +1624,58 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
LABEL(.UPDATE7L3)
|
||||
//Update 8x7 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE6L3)
|
||||
//Update 8x6 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE5L3)
|
||||
//Update 8x5 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE4L3)
|
||||
//Update 8x4 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE3L3)
|
||||
//Update 8x3 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE2L3)
|
||||
//Update 8x2 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE1L3)
|
||||
//Update 8x1 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATEDONEL3)
|
||||
@@ -2118,7 +2118,7 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
LABEL(.UPDATEDONEL2)
|
||||
|
||||
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
|
||||
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
|
||||
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
|
||||
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
|
||||
vmulpd(zmm8, zmm17, zmm8)
|
||||
@@ -2158,58 +2158,58 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
LABEL(.UPDATE7L3)
|
||||
//Update 8x7 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm3, mem(rbx, 6*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE6L3)
|
||||
//Update 8x6 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm5, mem(rbx, 5*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE5L3)
|
||||
//Update 8x5 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm1, mem(rbx, 4*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE4L3)
|
||||
//Update 8x4 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm6, mem(rbx, 3*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE3L3)
|
||||
//Update 8x3 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm2, mem(rbx, 2*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE2L3)
|
||||
//Update 8x2 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(xmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm4, mem(rbx, 1*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATE1L3)
|
||||
//Update 8x1 tile to destination buffer
|
||||
vmovupd(xmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
vmovupd(zmm0, mem(rbx, 0*192 + 128) MASK_(k(3)))
|
||||
jmp(.UPDATEDONEL3)
|
||||
|
||||
LABEL(.UPDATEDONEL3)
|
||||
@@ -3715,7 +3715,7 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
LABEL(.UPDATEDONEL2)
|
||||
|
||||
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
|
||||
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
|
||||
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
|
||||
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
|
||||
vmulpd(zmm8, zmm17, zmm8)
|
||||
@@ -3847,9 +3847,12 @@ void bli_dpackm_zen4_asm_24xk
|
||||
}
|
||||
case 18:
|
||||
{
|
||||
uint8_t mmask = ((1 << 2) - 1);
|
||||
begin_asm()
|
||||
mov(var(mask), rdx) // rdx <- mask value
|
||||
kmovw(edx, k(2)) // k(2) <- mask
|
||||
mov(var(mmask), rdx) // rdx <- m mask value
|
||||
kmovw(edx, k(3)) // k(3) <- mmask
|
||||
|
||||
mov(var(kappa), r10) // r10 <- kappa
|
||||
vbroadcastsd(mem(r10), zmm17) // zmm17 <- [kappa, kappa, ..., kappa]
|
||||
@@ -4237,9 +4240,9 @@ void bli_dpackm_zen4_asm_24xk
|
||||
|
||||
LABEL(.UPDATEDONEL2)
|
||||
|
||||
vmovupd(mem(rax, 0), xmm6 MASK_KZ(2))
|
||||
vmovupd(mem(rax, 0), zmm6 MASK_KZ(2))
|
||||
vmulpd(zmm6, zmm17, zmm6) // scale by kappa
|
||||
vmovupd(mem(rax, r8, 1, 0), xmm8 MASK_KZ(2))
|
||||
vmovupd(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
|
||||
vmulpd(zmm8, zmm17, zmm8)
|
||||
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
@@ -4344,6 +4347,7 @@ void bli_dpackm_zen4_asm_24xk
|
||||
: // output operands (none)
|
||||
: // input operands
|
||||
[mask] "m" (mask),
|
||||
[mmask] "m" (mmask),
|
||||
[k_iter] "m" (k_iter),
|
||||
[k_left] "m" (k_left),
|
||||
[a] "m" (a),
|
||||
|
||||
Reference in New Issue
Block a user