Added row-storage support for the C matrix.

- Added in-register transpose support for the C matrix to
support a row-stored C matrix for dgemm sup.
- Support is added for all edge-case kernels.
- FMAs are made independent of each other for faster
computation while storing data back to the C matrix.

AMD-Internal: [CPUPL-2966]
Change-Id: I1d13af99a17ee66adbf5f537a4664ade489a7cad
This commit is contained in:
Harsh Dave
2023-03-12 00:45:28 -06:00
parent 31a4203c32
commit c1766e312a
10 changed files with 7694 additions and 118 deletions

View File

@@ -118,13 +118,6 @@ err_t bli_gemmsup
if((bli_arch_query_id() == BLIS_ARCH_ZEN4) && (bli_obj_dt(a) == BLIS_DOUBLE))
{
// This check will be removed once transpose and store of C matrix inside
// the kernel is supported.
if((stor_id == BLIS_RCC || stor_id == BLIS_CRR || stor_id == BLIS_RRC))
{
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Unsuppported storage type for dgemm.");
return BLIS_FAILURE;
}
// override the existing blocksizes with 24x8 specific ones.
// This can be removed when we use same blocksizes and function pointers
// for all level-3 SUP routines.

File diff suppressed because it is too large Load Diff

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3
/**
 * Shuffles 128-bit lanes (pairs of double-precision elements),
 * selected by the immediate, from S1/S2 into D1/D2 and likewise
 * from S3/S4 into D3/D4 (two independent source pairs per call).
 * Example for one source pair:
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))
/**
 * Unpacks and interleaves the low halves and the high halves of
 * each 128-bit lane: S1/S2 go to D1 (lows) and D2 (highs), and the
 * second, independent pair S3/S4 goes to D3/D4 the same way.
 * Example for one source pair:
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta-is-zero store (used on the .DROWSTORBZ path): writes 8
 * transposed result rows straight to C without reading the old C
 * values. k(3) masks the active columns. Caller sets row strides:
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c, r14 = 8*rs_c;
 * rcx is advanced by 8 rows at the end.
 */
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
 * Beta-is-zero store of 7 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns. rcx not advanced.
 */
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
 * Beta-is-zero store of 6 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
 * Beta-is-zero store of 5 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
 * Beta-is-zero store of 4 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
 * Beta-is-zero store of 3 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
 * Beta-is-zero store of 2 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
 * Beta-is-zero store of 1 transposed result row to C (no load of
 * the old C value); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
 * Loads each row of C only in the lanes whose corresponding bits in
 * mask register k(3) are set (other lanes zeroed), computes
 * acc += beta * c_row (zmm31 holds beta, broadcast by the caller),
 * then stores the row back under the same mask. Handles 8 rows;
 * row strides rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c;
 * rcx is advanced by 8 rows (r14 = 8*rs_c) at the end.
 */
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
 * Loads each of 7 C rows only in the lanes whose corresponding bits
 * in mask register k(3) are set, computes acc += beta * c_row
 * (beta in zmm31), then stores the rows back under the same mask.
 */
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
 * Loads each of 6 C rows only in the lanes whose corresponding bits
 * in mask register k(3) are set, computes acc += beta * c_row
 * (beta in zmm31), then stores the rows back under the same mask.
 */
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
 * Loads each of 5 C rows only in the lanes whose corresponding bits
 * in mask register k(3) are set, computes acc += beta * c_row
 * (beta in zmm31), then stores the rows back under the same mask.
 */
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
 * Loads each of 4 C rows only in the lanes whose corresponding bits
 * in mask register k(3) are set, computes acc += beta * c_row
 * (beta in zmm31), then stores the rows back under the same mask.
 */
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
 * Loads each of 3 C rows only in the lanes whose corresponding bits
 * in mask register k(3) are set, computes acc += beta * c_row
 * (beta in zmm31), then stores the rows back under the same mask.
 */
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
 * Loads each of 2 C rows only in the lanes whose corresponding bits
 * in mask register k(3) are set, computes acc += beta * c_row
 * (beta in zmm31), then stores the rows back under the same mask.
 */
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
 * Loads 1 C row only in the lanes whose corresponding bits in mask
 * register k(3) are set, computes acc += beta * c_row (beta in
 * zmm31), then stores the row back under the same mask.
 */
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels assume that the A matrix is in col-major order.
 * B matrix can be col/row-major.
 * C matrix can be col/row-major.
 * Prefetch for C is done assuming that C is col-stored.
 * Prefetch of B is done assuming that the matrix is col-stored.
 * Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,10 @@ void bli_dgemmsup_rv_zen4_asm_24x1
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +454,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -489,8 +840,102 @@ void bli_dgemmsup_rv_zen4_asm_24x1
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x1 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x1 tile updated
vunpcklpd( zmm29, zmm28, zmm0)
vunpckhpd( zmm29, zmm28, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 8x1 tile updated
jmp(.DDONE) // jump to end.
@@ -507,8 +952,100 @@ void bli_dgemmsup_rv_zen4_asm_24x1
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x1 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x1 tile updated
vunpcklpd( zmm29, zmm28, zmm0)
vunpckhpd( zmm29, zmm28, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -533,7 +1070,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -594,6 +1132,11 @@ void bli_dgemmsup_rv_zen4_asm_16x1
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -606,6 +1149,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -934,8 +1479,90 @@ void bli_dgemmsup_rv_zen4_asm_16x1
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x1 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//8x1 tile updated
jmp(.DDONE) // jump to end.
@@ -951,8 +1578,88 @@ void bli_dgemmsup_rv_zen4_asm_16x1
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x1 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -977,7 +1684,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1038,6 +1746,11 @@ void bli_dgemmsup_rv_zen4_asm_8x1
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1050,6 +1763,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1322,8 +2037,78 @@ void bli_dgemmsup_rv_zen4_asm_8x1
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//8x1 tile updated
jmp(.DDONE) // jump to end.
@@ -1338,8 +2123,75 @@ void bli_dgemmsup_rv_zen4_asm_8x1
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1364,7 +2216,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3
/**
 * Shuffles 128-bit lanes (pairs of double-precision elements),
 * selected by the immediate, from S1/S2 into D1/D2 and likewise
 * from S3/S4 into D3/D4 (two independent source pairs per call).
 * Example for one source pair:
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))
/**
 * Unpacks and interleaves the low halves and the high halves of
 * each 128-bit lane: S1/S2 go to D1 (lows) and D2 (highs), and the
 * second, independent pair S3/S4 goes to D3/D4 the same way.
 * Example for one source pair:
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta-is-zero store (used on the .DROWSTORBZ path): writes 8
 * transposed result rows straight to C without reading the old C
 * values. k(3) masks the active columns. Caller sets row strides:
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c, r14 = 8*rs_c;
 * rcx is advanced by 8 rows at the end.
 */
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
 * Beta-is-zero store of 7 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns. rcx not advanced.
 */
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
 * Beta-is-zero store of 6 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
 * Beta-is-zero store of 5 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
 * Beta-is-zero store of 4 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
 * Beta-is-zero store of 3 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
 * Beta-is-zero store of 2 transposed result rows to C (no load of
 * the old C values); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
 * Beta-is-zero store of 1 transposed result row to C (no load of
 * the old C value); k(3) masks the active columns.
 */
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
* Loads elements from C row only if correspondnig bits in
* mask register is set, Scales it with Beta and adds FMA result to it.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
* Loads elements from C row only if correspondnig bits in
* mask register is set, Scales it with Beta and adds FMA result to it.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
* Loads elements from C row only if correspondnig bits in
* mask register is set, Scales it with Beta and adds FMA result to it.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
 * Loads 5 rows of C (rcx + {0,1,2,3,4}*rs_c) using mask k(3) so
 * only the valid trailing columns are read (masked-off lanes are
 * zeroed), scales the loaded C values by beta (broadcast in zmm31)
 * while accumulating into the GEMM results via FMA, then stores the
 * 5 updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes);
 * r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
 * Loads 4 rows of C (rcx + {0,1,2,3}*rs_c) using mask k(3) so only
 * the valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 4
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes);
 * r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
 * Loads 3 rows of C (rcx + {0,1,2}*rs_c) using mask k(3) so only
 * the valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 3
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes).
 */
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
 * Loads 2 rows of C (rcx and rcx + rs_c) using mask k(3) so only
 * the valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 2
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes).
 */
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
 * Loads the single remaining row of C (at rcx) using mask k(3) so
 * only the valid trailing columns are read (masked-off lanes are
 * zeroed), scales it by beta (broadcast in zmm31) while accumulating
 * into the GEMM result via FMA, then stores the updated row back to
 * C under the same mask.
 */
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels assume that the A matrix is stored in col-major order.
 * B matrix can be col/row-major.
 * C matrix can be col/row-major.
 * Prefetch for C is done assuming that C is col-stored.
 * Prefetch of B is done assuming that the matrix is col-stored.
 * Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x2
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -607,8 +959,102 @@ void bli_dgemmsup_rv_zen4_asm_24x2
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x2 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x2 tile updated
vunpcklpd( zmm29, zmm28, zmm0)
vunpckhpd( zmm29, zmm28, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 8x2 tile updated
jmp(.DDONE) // jump to end.
@@ -628,8 +1074,100 @@ void bli_dgemmsup_rv_zen4_asm_24x2
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x2 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x2 tile updated
vunpcklpd( zmm29, zmm28, zmm0)
vunpckhpd( zmm29, zmm28, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -654,7 +1192,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -715,6 +1254,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -727,6 +1271,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1143,8 +1689,90 @@ void bli_dgemmsup_rv_zen4_asm_16x2
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x2 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Second 8x2 tile updated
jmp(.DDONE) // jump to end.
@@ -1162,8 +1790,88 @@ void bli_dgemmsup_rv_zen4_asm_16x2
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x2 tile updated
vunpcklpd( zmm9, zmm7, zmm0)
vunpckhpd( zmm9, zmm7, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1188,7 +1896,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1249,6 +1958,11 @@ void bli_dgemmsup_rv_zen4_asm_8x2
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1261,6 +1975,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1591,8 +2307,78 @@ void bli_dgemmsup_rv_zen4_asm_8x2
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//8x2 tile updated
jmp(.DDONE) // jump to end.
@@ -1608,8 +2394,75 @@ void bli_dgemmsup_rv_zen4_asm_8x2
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
vunpcklpd( zmm8, zmm6, zmm0)
vunpckhpd( zmm8, zmm6, zmm1)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1634,7 +2487,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3
/**
 * Shuffles 128-bit lanes (pairs of doubles) between two zmm sources,
 * for two independent source pairs per invocation:
 *   D1 = lanes {0,2} of S1/S2 (imm 0x88), D2 = lanes {1,3} (imm 0xDD),
 * and likewise D3/D4 from S3/S4.
 * Example for one pair:
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))
/**
 * Unpacks and interleaves the low halves and the high halves of each
 * 128-bit lane of S1 and S2 into D1 and D2 respectively; the same is
 * done for the second pair S3/S4 into D3/D4.
 * Example for one pair:
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta == 0 path: stores 8 accumulator rows straight to C at
 * rcx + {0..7}*rs_c under mask k(3) (only lanes whose mask bits are
 * set are written); C is not read.  Advances rcx past the stored
 * panel via r14 (= 8*rs_c) for the next 8-row tile.
 * Register usage: rsi = rs_c (bytes); r12 = 3*rs_c; r13 = 5*rs_c;
 * rdx = 7*rs_c.
 */
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
 * Beta == 0 path: stores 7 accumulator rows straight to C at
 * rcx + {0..6}*rs_c under mask k(3) (only lanes whose mask bits are
 * set are written); C is not read.
 * Register usage: rsi = rs_c (bytes); r12 = 3*rs_c; r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
 * Beta == 0 path: stores 6 accumulator rows straight to C at
 * rcx + {0..5}*rs_c under mask k(3) (only lanes whose mask bits are
 * set are written); C is not read.
 * Register usage: rsi = rs_c (bytes); r12 = 3*rs_c; r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
 * Beta == 0 path: stores 5 accumulator rows straight to C at
 * rcx + {0..4}*rs_c under mask k(3) (only lanes whose mask bits are
 * set are written); C is not read.
 * Register usage: rsi = rs_c (bytes); r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
 * Beta == 0 path: stores 4 accumulator rows straight to C at
 * rcx + {0..3}*rs_c under mask k(3) (only lanes whose mask bits are
 * set are written); C is not read.
 * Register usage: rsi = rs_c (bytes); r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
 * Beta == 0 path: stores 3 accumulator rows straight to C at
 * rcx + {0,1,2}*rs_c under mask k(3) (only lanes whose mask bits
 * are set are written); C is not read.  rsi = rs_c (bytes).
 */
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
 * Beta == 0 path: stores 2 accumulator rows straight to C at rcx and
 * rcx + rs_c under mask k(3) (only lanes whose mask bits are set are
 * written); C is not read.  rsi = rs_c (bytes).
 */
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
 * Beta == 0 path: stores the single remaining accumulator row
 * straight to C at rcx under mask k(3) (only lanes whose mask bits
 * are set are written); C is not read.
 */
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
 * Loads 8 rows of C (rcx + {0..7}*rs_c) using mask k(3) so only the
 * valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 8
 * updated rows back to C under the same mask.  Advances rcx past
 * the stored panel via r14 (= 8*rs_c) for the next 8-row tile.
 * Register usage: rsi = rs_c (bytes); r12 = 3*rs_c; r13 = 5*rs_c;
 * rdx = 7*rs_c.
 */
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
 * Loads 7 rows of C (rcx + {0..6}*rs_c) using mask k(3) so only the
 * valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 7
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes);
 * r12 = 3*rs_c; r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
 * Loads 6 rows of C (rcx + {0..5}*rs_c) using mask k(3) so only the
 * valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 6
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes);
 * r12 = 3*rs_c; r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
 * Loads 5 rows of C (rcx + {0..4}*rs_c) using mask k(3) so only the
 * valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 5
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes);
 * r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
 * Loads 4 rows of C (rcx + {0..3}*rs_c) using mask k(3) so only the
 * valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 4
 * updated rows back to C under the same mask.
 * Register usage: rcx = base of current C rows; rsi = rs_c (bytes);
 * r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
 * Loads 3 rows of C (rcx + {0,1,2}*rs_c) using mask k(3) so only
 * the valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 3
 * updated rows back to C under the same mask.  rsi = rs_c (bytes).
 */
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
 * Loads 2 rows of C (rcx and rcx + rs_c) using mask k(3) so only
 * the valid trailing columns are read (masked-off lanes are zeroed),
 * scales the loaded C values by beta (broadcast in zmm31) while
 * accumulating into the GEMM results via FMA, then stores the 2
 * updated rows back to C under the same mask.  rsi = rs_c (bytes).
 */
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
 * Loads the single remaining row of C (at rcx) using mask k(3) so
 * only the valid trailing columns are read (masked-off lanes are
 * zeroed), scales it by beta (broadcast in zmm31) while accumulating
 * into the GEMM result via FMA, then stores the updated row back to
 * C under the same mask.
 */
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels assume that the A matrix is stored in col-major order.
 * B matrix can be col/row-major.
 * C matrix can be col/row-major.
 * Prefetch for C is done assuming that C is col-stored.
 * Prefetch of B is done assuming that the matrix is col-stored.
 * Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,12 @@ void bli_dgemmsup_rv_zen4_asm_24x3
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +456,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -725,8 +1078,99 @@ void bli_dgemmsup_rv_zen4_asm_24x3
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x3 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x3 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 8x3 tile updated
jmp(.DDONE) // jump to end.
@@ -749,8 +1193,97 @@ void bli_dgemmsup_rv_zen4_asm_24x3
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x3 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x3 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -775,7 +1308,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -836,6 +1370,11 @@ void bli_dgemmsup_rv_zen4_asm_16x3
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -848,6 +1387,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1352,8 +1893,88 @@ void bli_dgemmsup_rv_zen4_asm_16x3
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x3 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Second 8x3 tile updated
jmp(.DDONE) // jump to end.
@@ -1373,8 +1994,86 @@ void bli_dgemmsup_rv_zen4_asm_16x3
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x3 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1399,7 +2098,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1460,7 +2160,12 @@ void bli_dgemmsup_rv_zen4_asm_8x3
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
// -------------------------------------------------------------------------
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
mov(var(a), rax) // load address of a
@@ -1472,6 +2177,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1860,8 +2567,77 @@ void bli_dgemmsup_rv_zen4_asm_8x3
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//8x3 tile updated
jmp(.DDONE) // jump to end.
@@ -1878,8 +2654,75 @@ void bli_dgemmsup_rv_zen4_asm_8x3
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
//8x3 tile updated
label(.DDONE)
@@ -1904,7 +2747,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3 // k-loop tail iterations handled outside the main unrolled loop — NOTE(review): usage site not visible in this chunk; confirm
/**
 * For each source pair (S1,S2) and (S3,S4), shuffles 128-bit (2 x double)
 * lanes between the two sources: imm 0x88 gathers the even lanes into
 * D1/D3, imm 0xDD gathers the odd lanes into D2/D4.
 * Example for one pair (values shown per 64-bit element):
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 * NOTE(review): operand order follows the BLIS asm-macro (AT&T-style)
 * convention from bli_x86_asm_macros.h — confirm when reading.
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
/**
 * Unpacks and interleaves the low halves and the high halves of each
 * 128-bit lane of two source pairs: (S1,S2) -> (D1,D2) and
 * (S3,S4) -> (D3,D4).
 * Example for one pair (values shown per 64-bit element):
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta == 0 path: stores eight computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes the caller set
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c, r14 = 8*rs_c.
 * Advances rcx by 8 rows for the next tile.
 */
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
 * Beta == 0 path: stores seven computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes the caller set
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
 * Beta == 0 path: stores six computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes the caller set
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
 * Beta == 0 path: stores five computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes the caller set
 * rsi = rs_c, r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
 * Beta == 0 path: stores four computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes the caller set
 * rsi = rs_c, r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
 * Beta == 0 path: stores three computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes rsi = rs_c.
 */
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
 * Beta == 0 path: stores two computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes rsi = rs_c.
 */
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
 * Beta == 0 path: stores one computed result row straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here).
 */
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
 * Loads each of eight C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c,
 * r14 = 8*rs_c. Advances rcx by 8 rows for the next tile.
 */
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
 * Loads each of seven C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
 * Loads each of six C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c.
 */
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
 * Loads each of five C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c, r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
 * Loads each of four C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c, r12 = 3*rs_c.
 */
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
 * Loads each of three C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c.
 */
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
 * Loads each of two C rows (only elements whose corresponding bits in
 * mask register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated rows back to C under the same mask.
 * Assumes rsi = rs_c.
 */
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
 * Loads one C row (only elements whose corresponding bits in mask
 * register k3 are set; masked-off lanes are zeroed), computes
 * result += beta * C — beta is assumed to be pre-broadcast into zmm31 by
 * the caller — then stores the updated row back to C under the same mask.
 */
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels Assume that A matrix needs to be in col-major order
* B matrix can be col/row-major
* C matrix can be col/row-major
* Prefetch for C is done assuming that C is col-stored.
* Prefetch of B is done assuming that the matrix is col-stored.
* Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x4
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -845,8 +1197,96 @@ void bli_dgemmsup_rv_zen4_asm_24x4
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x4 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x4 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 8x4 tile updated
jmp(.DDONE) // jump to end.
@@ -872,8 +1312,94 @@ void bli_dgemmsup_rv_zen4_asm_24x4
label(.DROWSTORBZ)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x4 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x4 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -898,7 +1424,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -959,6 +1486,11 @@ void bli_dgemmsup_rv_zen4_asm_16x4
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -971,6 +1503,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1565,8 +2099,88 @@ void bli_dgemmsup_rv_zen4_asm_16x4
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x4 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Second 8x4 tile updated
jmp(.DDONE) // jump to end.
@@ -1588,8 +2202,86 @@ void bli_dgemmsup_rv_zen4_asm_16x4
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x4 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1614,7 +2306,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1675,6 +2368,11 @@ void bli_dgemmsup_rv_zen4_asm_8x4
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1687,6 +2385,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -2135,8 +2835,77 @@ void bli_dgemmsup_rv_zen4_asm_8x4
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//8x4 tile updated
jmp(.DDONE) // jump to end.
@@ -2154,8 +2923,74 @@ void bli_dgemmsup_rv_zen4_asm_8x4
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -2180,7 +3015,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3 // k-loop tail iterations handled outside the main unrolled loop — NOTE(review): usage site not visible in this chunk; confirm
/**
 * For each source pair (S1,S2) and (S3,S4), shuffles 128-bit (2 x double)
 * lanes between the two sources: imm 0x88 gathers the even lanes into
 * D1/D3, imm 0xDD gathers the odd lanes into D2/D4.
 * Example for one pair (values shown per 64-bit element):
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 * NOTE(review): operand order follows the BLIS asm-macro (AT&T-style)
 * convention from bli_x86_asm_macros.h — confirm when reading.
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
/**
 * Unpacks and interleaves the low halves and the high halves of each
 * 128-bit lane of two source pairs: (S1,S2) -> (D1,D2) and
 * (S3,S4) -> (D3,D4).
 * Example for one pair (values shown per 64-bit element):
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta == 0 path: stores eight computed result rows straight to C, writing
 * only the elements whose corresponding bits in mask register k3 are set.
 * C is never loaded or scaled (no FMA here). Assumes the caller set
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c, r14 = 8*rs_c.
 * Advances rcx by 8 rows for the next tile.
 */
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
* Beta == 0 case: stores the accumulated results directly to seven
* rows of C. Elements are written only where the corresponding bits
* of mask register k3 are set.
*/
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated results directly to six
* rows of C. Elements are written only where the corresponding bits
* of mask register k3 are set.
*/
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated results directly to five
* rows of C. Elements are written only where the corresponding bits
* of mask register k3 are set.
*/
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated results directly to four
* rows of C. Elements are written only where the corresponding bits
* of mask register k3 are set.
*/
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
* Beta == 0 case: stores the accumulated results directly to three
* rows of C. Elements are written only where the corresponding bits
* of mask register k3 are set.
*/
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
* Beta == 0 case: stores the accumulated results directly to two
* rows of C. Elements are written only where the corresponding bits
* of mask register k3 are set.
*/
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated result directly to one row
* of C. Elements are written only where the corresponding bits of
* mask register k3 are set.
*/
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
* Loads eight rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask. Finally advances rcx by r14,
* which callers set to 8 * rs_c, to the next block of rows.
*/
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
* Loads seven rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask.
*/
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
* Loads six rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask.
*/
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
* Loads five rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask.
*/
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
* Loads four rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask.
*/
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
* Loads three rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask.
*/
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
* Loads two rows from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales each
* row by beta (zmm31) and adds it to the FMA result, then stores the
* rows back to C under the same mask.
*/
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
* Loads one row from C only where the corresponding bits of mask
* register k3 are set (masked-off elements are zeroed), scales it by
* beta (zmm31) and adds it to the FMA result, then stores the row
* back to C under the same mask.
*/
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels assume that the A matrix is in col-major order.
* B matrix can be col/row-major.
* C matrix can be col/row-major.
* Prefetch for C is done assuming that C is col-stored.
* Prefetch of B is done assuming that the matrix is col-stored.
* Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x5
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -998,8 +1350,103 @@ void bli_dgemmsup_rv_zen4_asm_24x5
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// In-register transpose of the accumulated tile continues below.
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x5 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x5 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm25, zmm24, zmm0)
vunpckhpd(zmm25, zmm24, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 8x8 tile updated
jmp(.DDONE) // jump to end.
@@ -1028,8 +1475,101 @@ void bli_dgemmsup_rv_zen4_asm_24x5
label(.DROWSTORBZ)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// In-register transpose of the accumulated tile continues below.
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x5 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x5 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm25, zmm24, zmm0)
vunpckhpd(zmm25, zmm24, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1054,7 +1594,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1115,6 +1656,11 @@ void bli_dgemmsup_rv_zen4_asm_16x5
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1127,6 +1673,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1844,8 +2392,92 @@ void bli_dgemmsup_rv_zen4_asm_16x5
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// In-register transpose of the accumulated tile continues below.
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x5 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Second 8x5 tile updated
jmp(.DDONE) // jump to end.
@@ -1869,8 +2501,90 @@ void bli_dgemmsup_rv_zen4_asm_16x5
label(.DROWSTORBZ)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// In-register transpose of the accumulated tile continues below.
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x5 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1895,7 +2609,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1956,6 +2671,11 @@ void bli_dgemmsup_rv_zen4_asm_8x5
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1968,6 +2688,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -2509,8 +3231,79 @@ void bli_dgemmsup_rv_zen4_asm_8x5
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// In-register transpose of the accumulated tile continues below.
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//8x5 tile updated
jmp(.DDONE) // jump to end.
@@ -2529,8 +3322,76 @@ void bli_dgemmsup_rv_zen4_asm_8x5
label(.DROWSTORBZ)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// In-register transpose of the accumulated tile continues below.
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -2555,7 +3416,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3
/**
* Shuffles 128-bit lanes (pairs of doubles) between two source pairs:
* (S1, S2) -> (D1, D2) and (S3, S4) -> (D3, D4).
* Example for one pair:
* S1 : 1 9 3 11 5 13 7 15
* S2 : 2 10 4 12 6 14 8 16
* D1 : 1 9 5 13 2 10 6 14
* D2 : 3 11 7 15 4 12 8 16
*/
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))
/**
* Unpacks and interleaves the low halves and the high halves of each
* 128-bit lane of two source pairs: (S1, S2) -> (D1, D2) and
* (S3, S4) -> (D3, D4).
* Example for one pair:
* S1 : 1 2 3 4 5 6 7 8
* S2 : 9 10 11 12 13 14 15 16
* D1 : 1 9 3 11 5 13 7 15
* D2 : 2 10 4 12 6 14 8 16
*/
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
* Beta == 0 case: stores the accumulated results directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
* Beta == 0 case: stores the accumulated result directly to C, only
* where the corresponding bits of mask register k3 are set.
*/
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
* Loads elements from a C row only if the corresponding bits in the
* mask register are set, scales them by beta and adds the FMA result.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
* Loads elements from C row only if correspondnig bits in
* mask register is set, Scales it with Beta and adds FMA result to it.
* Stores back the C row.
*/
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels Assume that A matrix needs to be in col-major order
* B matrix can be col/row-major
* C matrix can be col/row-major
* Prefetch for C is done assuming that C is col-stored.
* Prefetch of B is done assuming that the matrix is col-stored.
* Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x6
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1116,8 +1468,102 @@ void bli_dgemmsup_rv_zen4_asm_24x6
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x6 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x6 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm25, zmm24, zmm0)
vunpckhpd(zmm25, zmm24, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 7x6 tile updated
jmp(.DDONE) // jump to end.
@@ -1149,8 +1595,100 @@ void bli_dgemmsup_rv_zen4_asm_24x6
label(.DROWSTORBZ)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x6 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x6 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm25, zmm24, zmm0)
vunpckhpd(zmm25, zmm24, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1175,7 +1713,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1236,6 +1775,11 @@ void bli_dgemmsup_rv_zen4_asm_16x6
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1248,6 +1792,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -2053,8 +2599,92 @@ void bli_dgemmsup_rv_zen4_asm_16x6
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x6 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Second 7x6 tile updated
jmp(.DDONE) // jump to end.
@@ -2080,8 +2710,90 @@ void bli_dgemmsup_rv_zen4_asm_16x6
label(.DROWSTORBZ)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x6 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
vunpcklpd(zmm17, zmm15, zmm0)
vunpckhpd(zmm17, zmm15, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -2106,7 +2818,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -2167,6 +2880,11 @@ void bli_dgemmsup_rv_zen4_asm_8x6
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -2179,6 +2897,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -2778,8 +3498,79 @@ void bli_dgemmsup_rv_zen4_asm_8x6
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//7x6 tile updated
jmp(.DDONE) // jump to end.
@@ -2799,8 +3590,76 @@ void bli_dgemmsup_rv_zen4_asm_8x6
label(.DROWSTORBZ)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
vunpcklpd(zmm16, zmm14, zmm0)
vunpckhpd(zmm16, zmm14, zmm1)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -2825,7 +3684,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,355 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3
/**
 * Uses VSHUFF64X2 to shuffle 128-bit lanes (pairs of doubles) selected
 * by imm8 from two source pairs: S1/S2 into D1/D2, and S3/S4 into
 * D3/D4. imm 0x88 gathers the even lanes; imm 0xDD gathers the odd
 * lanes. Per-pair data flow:
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
/**
 * Unpacks and interleaves the low halves and high halves of each
 * 128-bit lane of two source pairs: S1/S2 into D1/D2, and S3/S4
 * into D3/D4. Per-pair data flow:
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta == 0 case (used on the .DROWSTORBZ path): stores the
 * accumulated FMA results straight into 8 rows of C without loading C
 * first. Elements are written only where the corresponding bits of
 * mask register k3 are set. Row offsets: rsi = rs_c, r12 = 3*rs_c,
 * r13 = 5*rs_c, rdx = 7*rs_c; rcx is then advanced by 8*rs_c (r14)
 * for the next row block.
 */
#define UPDATE_MASKED_C_8_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
\
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
add(r14, rcx)
/**
 * Beta == 0 case: masked (k3) stores of the FMA results into 7 rows
 * of C without loading C first (rows rcx .. rcx + 6*rs_c).
 */
#define UPDATE_MASKED_C_7_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
\
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
/**
 * Beta == 0 case: masked (k3) stores of the FMA results into 6 rows
 * of C without loading C first (rows rcx .. rcx + 5*rs_c).
 */
#define UPDATE_MASKED_C_6_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
\
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
/**
 * Beta == 0 case: masked (k3) stores of the FMA results into 5 rows
 * of C without loading C first (rows rcx .. rcx + 4*rs_c).
 */
#define UPDATE_MASKED_C_5_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
\
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
/**
 * Beta == 0 case: masked (k3) stores of the FMA results into 4 rows
 * of C without loading C first (rows rcx .. rcx + 3*rs_c).
 */
#define UPDATE_MASKED_C_4_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
\
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
/**
 * Beta == 0 case: masked (k3) stores of the FMA results into 3 rows
 * of C without loading C first (rows rcx .. rcx + 2*rs_c).
 */
#define UPDATE_MASKED_C_3_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
\
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
/**
 * Beta == 0 case: masked (k3) stores of the FMA results into 2 rows
 * of C without loading C first (rows rcx and rcx + rs_c).
 */
#define UPDATE_MASKED_C_2_BZ \
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
\
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
/**
 * Beta == 0 case: masked (k3) store of the FMA result into a single
 * row of C without loading C first.
 */
#define UPDATE_MASKED_C_1_BZ \
\
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
/**
 * Loads elements from each of 8 C rows only where the corresponding
 * bits of mask register k3 are set (masked-off lanes zeroed by
 * MASK_(z)), scales the loaded values by beta (zmm31) while adding the
 * accumulated FMA results, then stores the updated rows back to C.
 * Row offsets: rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c;
 * rcx is then advanced by 8*rs_c (r14) for the next row block.
 */
#define UPDATE_MASKED_C_8 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm8 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
add(r14, rcx)
/**
 * Loads elements from each of 7 C rows only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA results, then stores the rows back to C
 * (rows rcx .. rcx + 6*rs_c).
 */
#define UPDATE_MASKED_C_7 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm3 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
/**
 * Loads elements from each of 6 C rows only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA results, then stores the rows back to C
 * (rows rcx .. rcx + 5*rs_c).
 */
#define UPDATE_MASKED_C_6 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm18,zmm5 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
/**
 * Loads elements from each of 5 C rows only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA results, then stores the rows back to C
 * (rows rcx .. rcx + 4*rs_c).
 */
#define UPDATE_MASKED_C_5 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm14,zmm1 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
/**
 * Loads elements from each of 4 C rows only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA results, then stores the rows back to C
 * (rows rcx .. rcx + 3*rs_c).
 */
#define UPDATE_MASKED_C_4 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm16,zmm6 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
/**
 * Loads elements from each of 3 C rows only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA results, then stores the rows back to C
 * (rows rcx .. rcx + 2*rs_c).
 */
#define UPDATE_MASKED_C_3 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm12,zmm2 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
/**
 * Loads elements from each of 2 C rows only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA results, then stores the rows back to C
 * (rows rcx and rcx + rs_c).
 */
#define UPDATE_MASKED_C_2 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm10,zmm4 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
/**
 * Loads elements from a single C row only where the corresponding
 * bits of mask register k3 are set, scales them with beta (zmm31) and
 * adds the FMA result, then stores the row back to C.
 */
#define UPDATE_MASKED_C_1 \
\
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
vfmadd231pd( zmm31,zmm30,zmm0 ) \
\
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
/* These kernels Assume that A matrix needs to be in col-major order
* B matrix can be col/row-major
* C matrix can be col/row-major
* Prefetch for C is done assuming that C is col-stored.
* Prefetch of B is done assuming that the matrix is col-stored.
* Prefetch for B and C matrices when row-stored is yet to be added.
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x7
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -1234,8 +1586,100 @@ void bli_dgemmsup_rv_zen4_asm_24x7
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x7 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8
//Second 8x7 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 7x8 tile updated
jmp(.DDONE) // jump to end.
@@ -1270,8 +1714,98 @@ void bli_dgemmsup_rv_zen4_asm_24x7
label(.DROWSTORBZ)
lea(mem(rsi, rsi, 2), r12)
lea(mem(r12, rsi, 2), r13)
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x7 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_MASKED_C_8_BZ
//Second 8x7 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -1296,7 +1830,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -1357,6 +1892,11 @@ void bli_dgemmsup_rv_zen4_asm_16x7
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -1369,6 +1909,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -2263,8 +2805,90 @@ void bli_dgemmsup_rv_zen4_asm_16x7
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// rdx = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// rdx = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// yet to be implemented
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_MASKED_C_8
//First 8x7 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 7x8 tile updated
jmp(.DDONE) // jump to end.
@@ -2292,8 +2916,88 @@ void bli_dgemmsup_rv_zen4_asm_16x7
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_MASKED_C_8_BZ
//First 8x7 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -2318,7 +3022,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
@@ -2379,6 +3084,11 @@ void bli_dgemmsup_rv_zen4_asm_8x7
// So, mask becomes 0xff(11111111)
if (mask == 0) mask = 0xff;
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
// So, mask becomes 0xff(11111111)
if (mask_n0 == 0) mask_n0 = 0xff;
// -------------------------------------------------------------------------
begin_asm()
@@ -2391,6 +3101,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7
mov(var(cs_c), rdi) // load cs_c
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(mask_n0), rdx) // load mask
kmovw(edx, k(3)) // move mask to k3 register
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
@@ -3048,8 +3760,78 @@ void bli_dgemmsup_rv_zen4_asm_8x7
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_MASKED_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_MASKED_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_MASKED_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_MASKED_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_MASKED_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_MASKED_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_MASKED_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_MASKED_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 7x7 tile updated
jmp(.DDONE) // jump to end.
@@ -3070,8 +3852,75 @@ void bli_dgemmsup_rv_zen4_asm_8x7
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_MASKED_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_MASKED_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_MASKED_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_MASKED_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_MASKED_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_MASKED_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_MASKED_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_MASKED_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -3096,7 +3945,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7
[cs_c] "m" (cs_c),
[n0] "m" (n0),
[m0] "m" (m0),
[mask] "m" (mask)
[mask] "m" (mask),
[mask_n0] "m" (mask_n0)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",

View File

@@ -38,10 +38,318 @@
#include "bli_x86_asm_macros.h"
#define TAIL_NITER 3
/**
 * Interleaves the 128-bit lanes of two source pairs:
 * (S1, S2) -> (D1, D2) and (S3, S4) -> (D3, D4).
 * imm8 0x88 gathers the even 128-bit lanes (0, 2) of each source;
 * imm8 0xDD gathers the odd 128-bit lanes (1, 3). The low half of each
 * destination comes from the first source, the high half from the second.
 * Element view for one pair:
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))
/**
 * Unpacks and interleaves the low halves and the high halves of each
 * 128-bit lane of two source pairs: (S1, S2) -> (D1, D2) and
 * (S3, S4) -> (D3, D4). D1/D3 receive the interleaved low qwords
 * (vunpcklpd), D2/D4 the interleaved high qwords (vunpckhpd).
 * Element view for one pair:
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
/**
 * Beta == 0 path: store 8 transposed rows of the accumulated result
 * straight to C (the old C values are not loaded or scaled).
 * Address registers, computed by the calling kernel (all byte-scaled):
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c, r14 = 8*rs_c.
 * rcx points at the current C row block and is advanced by 8 rows.
 */
#define UPDATE_C_8_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) ) \
\
vmovupd( zmm2, (rcx, rsi, 2) ) \
\
vmovupd( zmm6, (rcx, r12, 1) ) \
\
vmovupd( zmm1, (rcx, rsi, 4) ) \
\
vmovupd( zmm5, (rcx, r13, 1) ) \
\
vmovupd( zmm3, (rcx, r12, 2) ) \
\
vmovupd( zmm8, (rcx, rdx, 1) ) \
add(r14, rcx)
/**
 * Beta == 0 path: store 7 transposed rows of the result to C.
 * Same register contract as UPDATE_C_8_BZ; rcx is not advanced
 * (this is the final edge-case tile).
 */
#define UPDATE_C_7_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) ) \
\
vmovupd( zmm2, (rcx, rsi, 2) ) \
\
vmovupd( zmm6, (rcx, r12, 1) ) \
\
vmovupd( zmm1, (rcx, rsi, 4) ) \
\
vmovupd( zmm5, (rcx, r13, 1) ) \
\
vmovupd( zmm3, (rcx, r12, 2) )
/**
 * Beta == 0 path: store 6 transposed rows of the result to C.
 * Same register contract as UPDATE_C_8_BZ.
 */
#define UPDATE_C_6_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) ) \
\
vmovupd( zmm2, (rcx, rsi, 2) ) \
\
vmovupd( zmm6, (rcx, r12, 1) ) \
\
vmovupd( zmm1, (rcx, rsi, 4) ) \
\
vmovupd( zmm5, (rcx, r13, 1) )
/**
 * Beta == 0 path: store 5 transposed rows of the result to C.
 * Same register contract as UPDATE_C_8_BZ.
 */
#define UPDATE_C_5_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) ) \
\
vmovupd( zmm2, (rcx, rsi, 2) ) \
\
vmovupd( zmm6, (rcx, r12, 1) ) \
\
vmovupd( zmm1, (rcx, rsi, 4) )
/**
 * Beta == 0 path: store 4 transposed rows of the result to C.
 * Same register contract as UPDATE_C_8_BZ.
 */
#define UPDATE_C_4_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) ) \
\
vmovupd( zmm2, (rcx, rsi, 2) ) \
\
vmovupd( zmm6, (rcx, r12, 1) )
/**
 * Beta == 0 path: store 3 transposed rows of the result to C.
 * Same register contract as UPDATE_C_8_BZ.
 */
#define UPDATE_C_3_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) ) \
\
vmovupd( zmm2, (rcx, rsi, 2) )
/**
 * Beta == 0 path: store 2 transposed rows of the result to C.
 * Same register contract as UPDATE_C_8_BZ.
 */
#define UPDATE_C_2_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
\
vmovupd( zmm4, (rcx, rsi, 1) )
/**
 * Beta == 0 path: store the single remaining transposed row of the
 * accumulated result to C (the old C row is not loaded or scaled).
 * rcx points at the target row; it is not advanced.
 */
#define UPDATE_C_1_BZ \
\
vmovupd( zmm0, (rcx) ) /*Stores back to C*/
/**
 * Beta != 0 path: for each of 8 transposed result rows, load the C row,
 * multiply it by beta (broadcast in zmm31 by the caller) and add the
 * accumulated result, then store the row back.
 * Address registers, computed by the calling kernel (all byte-scaled):
 * rsi = rs_c, r12 = 3*rs_c, r13 = 5*rs_c, rdx = 7*rs_c, r14 = 8*rs_c.
 * rcx points at the current C row block and is advanced by 8 rows.
 */
#define UPDATE_C_8 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
vmovupd( zmm2, (rcx, rsi, 2) )\
\
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
vmovupd( zmm6, (rcx, r12, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
vmovupd( zmm1, (rcx, rsi, 4) )\
\
vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \
vmovupd( zmm5, (rcx, r13, 1) )\
\
vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \
vmovupd( zmm3, (rcx, r12, 2) )\
\
vfmadd231pd( mem(rcx, rdx, 1),zmm31,zmm8 ) \
vmovupd( zmm8, (rcx, rdx, 1) )\
add(r14, rcx)
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate 7 C rows,
 * then store them back. Same register contract as UPDATE_C_8;
 * rcx is not advanced (final edge-case tile).
 */
#define UPDATE_C_7 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
vmovupd( zmm2, (rcx, rsi, 2) )\
\
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
vmovupd( zmm6, (rcx, r12, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
vmovupd( zmm1, (rcx, rsi, 4) )\
\
vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \
vmovupd( zmm5, (rcx, r13, 1) )\
\
vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \
vmovupd( zmm3, (rcx, r12, 2) )
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate 6 C rows,
 * then store them back. Same register contract as UPDATE_C_8.
 */
#define UPDATE_C_6 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
vmovupd( zmm2, (rcx, rsi, 2) )\
\
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
vmovupd( zmm6, (rcx, r12, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
vmovupd( zmm1, (rcx, rsi, 4) )\
\
vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \
vmovupd( zmm5, (rcx, r13, 1) )
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate 5 C rows,
 * then store them back. Same register contract as UPDATE_C_8.
 */
#define UPDATE_C_5 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
vmovupd( zmm2, (rcx, rsi, 2) )\
\
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
vmovupd( zmm6, (rcx, r12, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
vmovupd( zmm1, (rcx, rsi, 4) )
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate 4 C rows,
 * then store them back. Same register contract as UPDATE_C_8.
 */
#define UPDATE_C_4 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
vmovupd( zmm2, (rcx, rsi, 2) )\
\
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
vmovupd( zmm6, (rcx, r12, 1) )
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate 3 C rows,
 * then store them back. Same register contract as UPDATE_C_8.
 */
#define UPDATE_C_3 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )\
\
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
vmovupd( zmm2, (rcx, rsi, 2) )
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate 2 C rows,
 * then store them back. Same register contract as UPDATE_C_8.
 */
#define UPDATE_C_2 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
\
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
vmovupd( zmm4, (rcx, rsi, 1) )
/**
 * Beta != 0 path: load, beta-scale (zmm31) and accumulate the single
 * remaining C row, then store it back. rcx is not advanced.
 */
#define UPDATE_C_1 \
\
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
vmovupd( zmm0, (rcx) ) /*Stores back to C*/
/* These kernels assume that the A matrix is in col-major order.
 * B matrix can be col/row-major.
 * C matrix can be col/row-major.
* Prefetch for C is done assuming that C is col-stored.
* Prefetch of B is done assuming that the matrix is col-stored.
* Prefetch for B and C matrices when row-stored is yet to be added.
@@ -1352,8 +1660,102 @@ void bli_dgemmsup_rv_zen4_asm_24x8
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_C_8
//First 8x8 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_C_8
//Second 8x8 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//Third 7x8 tile updated
jmp(.DDONE) // jump to end.
@@ -1391,8 +1793,100 @@ void bli_dgemmsup_rv_zen4_asm_24x8
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_C_8_BZ
//First 8x8 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
UPDATE_C_8_BZ
//Second 8x8 tile updated
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(16), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -2473,8 +2967,90 @@ void bli_dgemmsup_rv_zen4_asm_16x8
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
UPDATE_C_8
//First 8x8 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//7x8 tile updated
jmp(.DDONE) // jump to end.
@@ -2504,8 +3080,88 @@ void bli_dgemmsup_rv_zen4_asm_16x8
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
UPDATE_C_8_BZ
//First 8x8 tile updated
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
mov(var(m0), rdi)
sub(imm(8), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)
@@ -3318,8 +3974,78 @@ void bli_dgemmsup_rv_zen4_asm_8x8
jmp(.DDONE) // jump to end.
label(.DROWSTORED)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
vbroadcastsd(mem(rax), zmm31)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8)
cmp(imm(7), rdi)
JZ(.UPDATE7)
cmp(imm(6), rdi)
JZ(.UPDATE6)
cmp(imm(5), rdi)
JZ(.UPDATE5)
cmp(imm(4), rdi)
JZ(.UPDATE4)
cmp(imm(3), rdi)
JZ(.UPDATE3)
cmp(imm(2), rdi)
JZ(.UPDATE2)
cmp(imm(1), rdi)
JZ(.UPDATE1)
cmp(imm(0), rdi)
JZ(.UPDATE0)
LABEL(.UPDATE8)
UPDATE_C_8
jmp(.DDONE)
LABEL(.UPDATE7)
UPDATE_C_7
jmp(.DDONE)
LABEL(.UPDATE6)
UPDATE_C_6
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5)
UPDATE_C_5
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4)
UPDATE_C_4
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3)
UPDATE_C_3
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2)
UPDATE_C_2
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1)
UPDATE_C_1
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0)
//7x8 tile updated
jmp(.DDONE) // jump to end.
@@ -3341,8 +4067,75 @@ void bli_dgemmsup_rv_zen4_asm_8x8
label(.DROWSTORBZ)
// r12 = 3*rs_c
lea(mem(rsi, rsi, 2), r12)
// r13 = 5*rs_c
lea(mem(r12, rsi, 2), r13)
// rdx = 7*rs_c
lea(mem(r12, rsi, 4), rdx)
lea(mem( , rsi, 8), r14)
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
// in-register transpose of the accumulated tile
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
mov(var(m0), rdi)
cmp(imm(8), rdi)
JZ(.UPDATE8BZ)
cmp(imm(7), rdi)
JZ(.UPDATE7BZ)
cmp(imm(6), rdi)
JZ(.UPDATE6BZ)
cmp(imm(5), rdi)
JZ(.UPDATE5BZ)
cmp(imm(4), rdi)
JZ(.UPDATE4BZ)
cmp(imm(3), rdi)
JZ(.UPDATE3BZ)
cmp(imm(2), rdi)
JZ(.UPDATE2BZ)
cmp(imm(1), rdi)
JZ(.UPDATE1BZ)
cmp(imm(0), rdi)
JZ(.UPDATE0BZ)
LABEL(.UPDATE8BZ)
UPDATE_C_8_BZ
jmp(.DDONE)
LABEL(.UPDATE7BZ)
UPDATE_C_7_BZ
jmp(.DDONE)
LABEL(.UPDATE6BZ)
UPDATE_C_6_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE5BZ)
UPDATE_C_5_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE4BZ)
UPDATE_C_4_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE3BZ)
UPDATE_C_3_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE2BZ)
UPDATE_C_2_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE1BZ)
UPDATE_C_1_BZ
jmp(.DDONE) // jump to end.
LABEL(.UPDATE0BZ)
label(.DDONE)