mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
Added in row storage support for C matrix.
- Added in-register transpose support for c matrix to support row stored C matrix for dgemm sup. - Support is added for all edge case kernels. - FMA are made independent of each other, for faster computation while storing data back to C matrix. AMD-Internal: [CPUPL-2966] Change-Id: I1d13af99a17ee66adbf5f537a4664ade489a7cad
This commit is contained in:
@@ -118,13 +118,6 @@ err_t bli_gemmsup
|
||||
|
||||
if((bli_arch_query_id() == BLIS_ARCH_ZEN4) && (bli_obj_dt(a) == BLIS_DOUBLE))
|
||||
{
|
||||
// This check will be removed once transpose and store of C matrix inside
|
||||
// the kernel is supported.
|
||||
if((stor_id == BLIS_RCC || stor_id == BLIS_CRR || stor_id == BLIS_RRC))
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Unsuppported storage type for dgemm.");
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
// override the existing blocksizes with 24x8 specific ones.
|
||||
// This can be removed when we use same blocksizes and function pointers
|
||||
// for all level-3 SUP routines.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,10 @@ void bli_dgemmsup_rv_zen4_asm_24x1
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +454,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -489,8 +840,102 @@ void bli_dgemmsup_rv_zen4_asm_24x1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x1 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x1 tile updated
|
||||
|
||||
vunpcklpd( zmm29, zmm28, zmm0)
|
||||
vunpckhpd( zmm29, zmm28, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 8x1 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -507,8 +952,100 @@ void bli_dgemmsup_rv_zen4_asm_24x1
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x1 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x1 tile updated
|
||||
|
||||
vunpcklpd( zmm29, zmm28, zmm0)
|
||||
vunpckhpd( zmm29, zmm28, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -533,7 +1070,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -594,6 +1132,11 @@ void bli_dgemmsup_rv_zen4_asm_16x1
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -606,6 +1149,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -934,8 +1479,90 @@ void bli_dgemmsup_rv_zen4_asm_16x1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x1 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//8x1 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -951,8 +1578,88 @@ void bli_dgemmsup_rv_zen4_asm_16x1
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x1 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -977,7 +1684,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1038,6 +1746,11 @@ void bli_dgemmsup_rv_zen4_asm_8x1
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1050,6 +1763,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1322,8 +2037,78 @@ void bli_dgemmsup_rv_zen4_asm_8x1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//8x1 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1338,8 +2123,75 @@ void bli_dgemmsup_rv_zen4_asm_8x1
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1364,7 +2216,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x2
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -607,8 +959,102 @@ void bli_dgemmsup_rv_zen4_asm_24x2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x2 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x2 tile updated
|
||||
|
||||
vunpcklpd( zmm29, zmm28, zmm0)
|
||||
vunpckhpd( zmm29, zmm28, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 8x2 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -628,8 +1074,100 @@ void bli_dgemmsup_rv_zen4_asm_24x2
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x2 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x2 tile updated
|
||||
|
||||
vunpcklpd( zmm29, zmm28, zmm0)
|
||||
vunpckhpd( zmm29, zmm28, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -654,7 +1192,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -715,6 +1254,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -727,6 +1271,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1143,8 +1689,90 @@ void bli_dgemmsup_rv_zen4_asm_16x2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x2 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Second 8x2 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1162,8 +1790,88 @@ void bli_dgemmsup_rv_zen4_asm_16x2
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x2 tile updated
|
||||
|
||||
vunpcklpd( zmm9, zmm7, zmm0)
|
||||
vunpckhpd( zmm9, zmm7, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1188,7 +1896,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1249,6 +1958,11 @@ void bli_dgemmsup_rv_zen4_asm_8x2
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1261,6 +1975,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1591,8 +2307,78 @@ void bli_dgemmsup_rv_zen4_asm_8x2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//8x2 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1608,8 +2394,75 @@ void bli_dgemmsup_rv_zen4_asm_8x2
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
vunpcklpd( zmm8, zmm6, zmm0)
|
||||
vunpckhpd( zmm8, zmm6, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1634,7 +2487,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,12 @@ void bli_dgemmsup_rv_zen4_asm_24x3
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +456,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -725,8 +1078,99 @@ void bli_dgemmsup_rv_zen4_asm_24x3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x3 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x3 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 8x3 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -749,8 +1193,97 @@ void bli_dgemmsup_rv_zen4_asm_24x3
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x3 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x3 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -775,7 +1308,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -836,6 +1370,11 @@ void bli_dgemmsup_rv_zen4_asm_16x3
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -848,6 +1387,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1352,8 +1893,88 @@ void bli_dgemmsup_rv_zen4_asm_16x3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x3 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Second 8x3 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1373,8 +1994,86 @@ void bli_dgemmsup_rv_zen4_asm_16x3
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x3 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1399,7 +2098,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1460,7 +2160,12 @@ void bli_dgemmsup_rv_zen4_asm_8x3
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
mov(var(a), rax) // load address of a
|
||||
@@ -1472,6 +2177,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1860,8 +2567,77 @@ void bli_dgemmsup_rv_zen4_asm_8x3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//8x3 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1878,8 +2654,75 @@ void bli_dgemmsup_rv_zen4_asm_8x3
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
//8x3 tile updated
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1904,7 +2747,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x4
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -845,8 +1197,96 @@ void bli_dgemmsup_rv_zen4_asm_24x4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x4 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x4 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 8x4 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -872,8 +1312,94 @@ void bli_dgemmsup_rv_zen4_asm_24x4
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x4 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x4 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -898,7 +1424,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -959,6 +1486,11 @@ void bli_dgemmsup_rv_zen4_asm_16x4
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -971,6 +1503,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1565,8 +2099,88 @@ void bli_dgemmsup_rv_zen4_asm_16x4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x4 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Second 8x4 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1588,8 +2202,86 @@ void bli_dgemmsup_rv_zen4_asm_16x4
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x4 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1614,7 +2306,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1675,6 +2368,11 @@ void bli_dgemmsup_rv_zen4_asm_8x4
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1687,6 +2385,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -2135,8 +2835,77 @@ void bli_dgemmsup_rv_zen4_asm_8x4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//8x4 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -2154,8 +2923,74 @@ void bli_dgemmsup_rv_zen4_asm_8x4
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -2180,7 +3015,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x5
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -998,8 +1350,103 @@ void bli_dgemmsup_rv_zen4_asm_24x5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x5 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x5 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm25, zmm24, zmm0)
|
||||
vunpckhpd(zmm25, zmm24, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 8x8 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1028,8 +1475,101 @@ void bli_dgemmsup_rv_zen4_asm_24x5
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x5 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x5 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm25, zmm24, zmm0)
|
||||
vunpckhpd(zmm25, zmm24, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1054,7 +1594,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1115,6 +1656,11 @@ void bli_dgemmsup_rv_zen4_asm_16x5
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1127,6 +1673,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1844,8 +2392,92 @@ void bli_dgemmsup_rv_zen4_asm_16x5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// r12 = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// r13 = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x5 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Second 8x5 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1869,8 +2501,90 @@ void bli_dgemmsup_rv_zen4_asm_16x5
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x5 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1895,7 +2609,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1956,6 +2671,11 @@ void bli_dgemmsup_rv_zen4_asm_8x5
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1968,6 +2688,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -2509,8 +3231,79 @@ void bli_dgemmsup_rv_zen4_asm_8x5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//8x5 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -2529,8 +3322,76 @@ void bli_dgemmsup_rv_zen4_asm_8x5
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -2555,7 +3416,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x6
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1116,8 +1468,102 @@ void bli_dgemmsup_rv_zen4_asm_24x6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x6 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x6 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm25, zmm24, zmm0)
|
||||
vunpckhpd(zmm25, zmm24, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 7x6 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1149,8 +1595,100 @@ void bli_dgemmsup_rv_zen4_asm_24x6
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x6 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x6 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm25, zmm24, zmm0)
|
||||
vunpckhpd(zmm25, zmm24, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1175,7 +1713,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1236,6 +1775,11 @@ void bli_dgemmsup_rv_zen4_asm_16x6
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1248,6 +1792,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -2053,8 +2599,92 @@ void bli_dgemmsup_rv_zen4_asm_16x6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x6 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Second 7x6 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -2080,8 +2710,90 @@ void bli_dgemmsup_rv_zen4_asm_16x6
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x6 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
vunpcklpd(zmm17, zmm15, zmm0)
|
||||
vunpckhpd(zmm17, zmm15, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -2106,7 +2818,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -2167,6 +2880,11 @@ void bli_dgemmsup_rv_zen4_asm_8x6
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -2179,6 +2897,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -2778,8 +3498,79 @@ void bli_dgemmsup_rv_zen4_asm_8x6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//7x6 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -2799,8 +3590,76 @@ void bli_dgemmsup_rv_zen4_asm_8x6
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
vunpcklpd(zmm16, zmm14, zmm0)
|
||||
vunpckhpd(zmm16, zmm14, zmm1)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -2825,7 +3684,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,355 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
|
||||
\
|
||||
vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2_BZ \
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
|
||||
\
|
||||
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* mask register is set, stores the fma result back to C
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, mem(rcx) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_8 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm8 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_7 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm3 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_6 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm18,zmm5 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
|
||||
vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_5 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm14,zmm1 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_4 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm16,zmm6 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
|
||||
vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_3 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm12,zmm2 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_2 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm10,zmm4 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))
|
||||
|
||||
/**
|
||||
* Loads elements from C row only if correspondnig bits in
|
||||
* mask register is set, Scales it with Beta and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_MASKED_C_1 \
|
||||
\
|
||||
vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
|
||||
vfmadd231pd( zmm31,zmm30,zmm0 ) \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x7
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -1234,8 +1586,100 @@ void bli_dgemmsup_rv_zen4_asm_24x7
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x7 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8
|
||||
//Second 8x7 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 7x8 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1270,8 +1714,98 @@ void bli_dgemmsup_rv_zen4_asm_24x7
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x7 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//Second 8x7 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -1296,7 +1830,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -1357,6 +1892,11 @@ void bli_dgemmsup_rv_zen4_asm_16x7
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -1369,6 +1909,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -2263,8 +2805,90 @@ void bli_dgemmsup_rv_zen4_asm_16x7
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_MASKED_C_8
|
||||
//First 8x7 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 7x8 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -2292,8 +2916,88 @@ void bli_dgemmsup_rv_zen4_asm_16x7
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
//First 8x7 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -2318,7 +3022,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
@@ -2379,6 +3084,11 @@ void bli_dgemmsup_rv_zen4_asm_8x7
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask == 0) mask = 0xff;
|
||||
|
||||
uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left
|
||||
// For special cases where n_left = 8, all 8 elements have to be loaded/stored
|
||||
// So, mask becomes 0xff(11111111)
|
||||
if (mask_n0 == 0) mask_n0 = 0xff;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
begin_asm()
|
||||
|
||||
@@ -2391,6 +3101,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7
|
||||
mov(var(cs_c), rdi) // load cs_c
|
||||
mov(var(mask), rdx) // load mask
|
||||
kmovw(edx, k(2)) // move mask to k2 register
|
||||
mov(var(mask_n0), rdx) // load mask
|
||||
kmovw(edx, k(3)) // move mask to k3 register
|
||||
lea(mem(, r8, 8), r8) // rs_b *= sizeof(double)
|
||||
lea(mem(, r9, 8), r9) // cs_b *= sizeof(double)
|
||||
lea(mem(, r10, 8), r10) // cs_a *= sizeof(double)
|
||||
@@ -3048,8 +3760,78 @@ void bli_dgemmsup_rv_zen4_asm_8x7
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_MASKED_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_MASKED_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_MASKED_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_MASKED_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_MASKED_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_MASKED_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_MASKED_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_MASKED_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 7x7 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -3070,8 +3852,75 @@ void bli_dgemmsup_rv_zen4_asm_8x7
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_MASKED_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_MASKED_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_MASKED_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_MASKED_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_MASKED_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_MASKED_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_MASKED_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_MASKED_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -3096,7 +3945,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7
|
||||
[cs_c] "m" (cs_c),
|
||||
[n0] "m" (n0),
|
||||
[m0] "m" (m0),
|
||||
[mask] "m" (mask)
|
||||
[mask] "m" (mask),
|
||||
[mask_n0] "m" (mask_n0)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
|
||||
@@ -38,10 +38,318 @@
|
||||
#include "bli_x86_asm_macros.h"
|
||||
#define TAIL_NITER 3
|
||||
|
||||
/**
|
||||
* Shuffle 2 double-precision elements selected by imm8 from S1 and S2,
|
||||
* and store the results in D1.
|
||||
* S1 : 1 9 3 11 5 13 7 15
|
||||
* S2 : 2 10 4 12 6 14 8 16
|
||||
* D1 : 1 9 5 13 2 10 6 14
|
||||
* D2 : 3 11 7 15 4 12 8 16
|
||||
*/
|
||||
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
|
||||
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
|
||||
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
|
||||
|
||||
/**
|
||||
* Unpacks and interleave low half and high half of each
|
||||
* 128-bit lane in S1 and S2 and store into D1 and D2
|
||||
* respectively.
|
||||
* S1 : 1 2 3 4 5 6 7 8
|
||||
* S2 : 9 10 11 12 13 14 15 16
|
||||
* D1 : 1 9 3 11 5 13 7 15
|
||||
* D2 : 2 10 4 12 6 14 8 16
|
||||
*/
|
||||
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
|
||||
\
|
||||
vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \
|
||||
vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \
|
||||
vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \
|
||||
vunpckhpd( zmm(S3), zmm(S4), zmm(D4))
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_8_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) ) \
|
||||
\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) ) \
|
||||
\
|
||||
vmovupd( zmm6, (rcx, r12, 1) ) \
|
||||
\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) ) \
|
||||
\
|
||||
vmovupd( zmm5, (rcx, r13, 1) ) \
|
||||
\
|
||||
vmovupd( zmm3, (rcx, r12, 2) ) \
|
||||
\
|
||||
vmovupd( zmm8, (rcx, rdx, 1) ) \
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_7_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) ) \
|
||||
\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) ) \
|
||||
\
|
||||
vmovupd( zmm6, (rcx, r12, 1) ) \
|
||||
\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) ) \
|
||||
\
|
||||
vmovupd( zmm5, (rcx, r13, 1) ) \
|
||||
\
|
||||
vmovupd( zmm3, (rcx, r12, 2) )
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_6_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) ) \
|
||||
\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) ) \
|
||||
\
|
||||
vmovupd( zmm6, (rcx, r12, 1) ) \
|
||||
\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) ) \
|
||||
\
|
||||
vmovupd( zmm5, (rcx, r13, 1) )
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_5_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) ) \
|
||||
\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) ) \
|
||||
\
|
||||
vmovupd( zmm6, (rcx, r12, 1) ) \
|
||||
\
|
||||
vmovupd( zmm1, (rcx, rsi, 4) )
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_4_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) ) \
|
||||
\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) ) \
|
||||
\
|
||||
vmovupd( zmm6, (rcx, r12, 1) )
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_3_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) ) \
|
||||
\
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_2_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
\
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )
|
||||
|
||||
/**
|
||||
* Stores fma result back to C
|
||||
*/
|
||||
#define UPDATE_C_1_BZ \
|
||||
\
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_8 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
|
||||
vmovupd( zmm6, (rcx, r12, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
|
||||
vmovupd( zmm1, (rcx, rsi, 4) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \
|
||||
vmovupd( zmm5, (rcx, r13, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \
|
||||
vmovupd( zmm3, (rcx, r12, 2) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rdx, 1),zmm31,zmm8 ) \
|
||||
vmovupd( zmm8, (rcx, rdx, 1) )\
|
||||
add(r14, rcx)
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_7 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
|
||||
vmovupd( zmm6, (rcx, r12, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
|
||||
vmovupd( zmm1, (rcx, rsi, 4) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \
|
||||
vmovupd( zmm5, (rcx, r13, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \
|
||||
vmovupd( zmm3, (rcx, r12, 2) )
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_6 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
|
||||
vmovupd( zmm6, (rcx, r12, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
|
||||
vmovupd( zmm1, (rcx, rsi, 4) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \
|
||||
vmovupd( zmm5, (rcx, r13, 1) )
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_5 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
|
||||
vmovupd( zmm6, (rcx, r12, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \
|
||||
vmovupd( zmm1, (rcx, rsi, 4) )
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_4 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \
|
||||
vmovupd( zmm6, (rcx, r12, 1) )
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_3 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \
|
||||
vmovupd( zmm2, (rcx, rsi, 2) )
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_2 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/\
|
||||
\
|
||||
vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \
|
||||
vmovupd( zmm4, (rcx, rsi, 1) )
|
||||
|
||||
/**
|
||||
* Loads elements from C row, Scales it with Beta
|
||||
* and adds FMA result to it.
|
||||
* Stores back the C row.
|
||||
*/
|
||||
#define UPDATE_C_1 \
|
||||
\
|
||||
vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \
|
||||
vmovupd( zmm0, (rcx) ) /*Stores back to C*/
|
||||
|
||||
/* These kernels Assume that A matrix needs to be in col-major order
|
||||
* B matrix can be col/row-major
|
||||
* C matrix can be col/row-major though support for row-major order will
|
||||
* be added by a separate commit.
|
||||
* C matrix can be col/row-major
|
||||
* Prefetch for C is done assuming that C is col-stored.
|
||||
* Prefetch of B is done assuming that the matrix is col-stored.
|
||||
* Prefetch for B and C matrices when row-stored is yet to be added.
|
||||
@@ -1352,8 +1660,102 @@ void bli_dgemmsup_rv_zen4_asm_24x8
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_C_8
|
||||
//First 8x8 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_C_8
|
||||
//Second 8x8 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//Third 7x8 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -1391,8 +1793,100 @@ void bli_dgemmsup_rv_zen4_asm_24x8
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_C_8_BZ
|
||||
//First 8x8 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
UPDATE_C_8_BZ
|
||||
//Second 8x8 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(16), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -2473,8 +2967,90 @@ void bli_dgemmsup_rv_zen4_asm_16x8
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
UPDATE_C_8
|
||||
//First 8x8 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//7x8 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -2504,8 +3080,88 @@ void bli_dgemmsup_rv_zen4_asm_16x8
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
UPDATE_C_8_BZ
|
||||
//First 8x8 tile updated
|
||||
|
||||
UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9)
|
||||
|
||||
UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
sub(imm(8), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
@@ -3318,8 +3974,78 @@ void bli_dgemmsup_rv_zen4_asm_8x8
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
label(.DROWSTORED)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
vbroadcastsd(mem(rax), zmm31)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0)
|
||||
|
||||
LABEL(.UPDATE8)
|
||||
UPDATE_C_8
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7)
|
||||
UPDATE_C_7
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6)
|
||||
UPDATE_C_6
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5)
|
||||
UPDATE_C_5
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4)
|
||||
UPDATE_C_4
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3)
|
||||
UPDATE_C_3
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2)
|
||||
UPDATE_C_2
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1)
|
||||
UPDATE_C_1
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0)
|
||||
//7x8 tile updated
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
|
||||
@@ -3341,8 +4067,75 @@ void bli_dgemmsup_rv_zen4_asm_8x8
|
||||
|
||||
|
||||
label(.DROWSTORBZ)
|
||||
// rdx = 3*rs_c
|
||||
lea(mem(rsi, rsi, 2), r12)
|
||||
// rdx = 5*rs_c
|
||||
lea(mem(r12, rsi, 2), r13)
|
||||
// rdx = 7*rs_c
|
||||
lea(mem(r12, rsi, 4), rdx)
|
||||
lea(mem( , rsi, 8), r14)
|
||||
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
|
||||
|
||||
// yet to be implemented
|
||||
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
|
||||
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
|
||||
|
||||
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
|
||||
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
|
||||
|
||||
mov(var(m0), rdi)
|
||||
cmp(imm(8), rdi)
|
||||
JZ(.UPDATE8BZ)
|
||||
cmp(imm(7), rdi)
|
||||
JZ(.UPDATE7BZ)
|
||||
cmp(imm(6), rdi)
|
||||
JZ(.UPDATE6BZ)
|
||||
cmp(imm(5), rdi)
|
||||
JZ(.UPDATE5BZ)
|
||||
cmp(imm(4), rdi)
|
||||
JZ(.UPDATE4BZ)
|
||||
cmp(imm(3), rdi)
|
||||
JZ(.UPDATE3BZ)
|
||||
cmp(imm(2), rdi)
|
||||
JZ(.UPDATE2BZ)
|
||||
cmp(imm(1), rdi)
|
||||
JZ(.UPDATE1BZ)
|
||||
cmp(imm(0), rdi)
|
||||
JZ(.UPDATE0BZ)
|
||||
|
||||
LABEL(.UPDATE8BZ)
|
||||
UPDATE_C_8_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE7BZ)
|
||||
UPDATE_C_7_BZ
|
||||
jmp(.DDONE)
|
||||
|
||||
LABEL(.UPDATE6BZ)
|
||||
UPDATE_C_6_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE5BZ)
|
||||
UPDATE_C_5_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE4BZ)
|
||||
UPDATE_C_4_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE3BZ)
|
||||
UPDATE_C_3_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE2BZ)
|
||||
UPDATE_C_2_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE1BZ)
|
||||
UPDATE_C_1_BZ
|
||||
jmp(.DDONE) // jump to end.
|
||||
|
||||
LABEL(.UPDATE0BZ)
|
||||
label(.DDONE)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user