From c1766e312ae818617e8b9942a3ec063fb71ff4df Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Sun, 12 Mar 2023 00:45:28 -0600 Subject: [PATCH] Added in row storage support for C matrix. - Added in-register transpose support for c matrix to support row stored C matrix for dgemm sup. - Support is added for all edge case kernels. - FMA are made independent of each other, for faster computation while storing data back to C matrix. AMD-Internal: [CPUPL-2966] Change-Id: I1d13af99a17ee66adbf5f537a4664ade489a7cad --- frame/3/bli_l3_sup.c | 7 - .../3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c | 881 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c | 875 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c | 876 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c | 868 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c | 858 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c | 884 +++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c | 882 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c | 872 ++++++++++++++++- .../sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c | 809 +++++++++++++++- 10 files changed, 7694 insertions(+), 118 deletions(-) diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index 81c0ed9c1..228600c0a 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -118,13 +118,6 @@ err_t bli_gemmsup if((bli_arch_query_id() == BLIS_ARCH_ZEN4) && (bli_obj_dt(a) == BLIS_DOUBLE)) { - // This check will be removed once transpose and store of C matrix inside - // the kernel is supported. - if((stor_id == BLIS_RCC || stor_id == BLIS_CRR || stor_id == BLIS_RRC)) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Unsuppported storage type for dgemm."); - return BLIS_FAILURE; - } // override the existing blocksizes with 24x8 specific ones. // This can be removed when we use same blocksizes and function pointers // for all level-3 SUP routines. 
diff --git a/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c b/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c index 1109d7172..67df513fe 100644 --- a/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c +++ b/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c @@ -37,10 +37,158 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1 + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. + * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_C \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \ + vmovupd( zmm1, (rcx, rsi, 4) )\ +\ + vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \ + vmovupd( zmm5, (rcx, r13, 1) )\ +\ + vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \ + vmovupd( zmm3, (rcx, r12, 2) )\ +\ + vfmadd231pd( mem(rcx, rdx, 1),zmm31,zmm8 ) \ + vmovupd( zmm8, (rcx, rdx, 1) )\ + add(r14, rcx) + + +/** + * stores FMA result to C. +*/ +#define UPDATE_C_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) \ +\ + vmovupd( zmm1, (rcx, rsi, 4) ) \ +\ + vmovupd( zmm5, (rcx, r13, 1) ) \ +\ + vmovupd( zmm3, (rcx, r12, 2) ) \ +\ + vmovupd( zmm8, (rcx, rdx, 1) ) \ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(2)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(2))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(2)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(2)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(2)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(2)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(2)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(2)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(2)))\ + add(r14, rcx) + +/** + * mask register is set, stores FMA result to C. +*/ +#define UPDATE_MASKED_C_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(2))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(2))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(2)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(2)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(2))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(2))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(2))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(2))) \ + add(r14, rcx) + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. 
+ * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -1489,8 +1637,49 @@ void bli_dgemmsup_rv_zen4_asm_24x8m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_C + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C + //Second 8x8 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C + //Third 8x8 tile updated jmp(.DDONE) // jump to end. 
@@ -1528,8 +1717,48 @@ void bli_dgemmsup_rv_zen4_asm_24x8m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_C_BZ + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C_BZ + //Second 8x8 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C_BZ + //Third 8x8 tile updated label(.DDONE) @@ -1644,6 +1873,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -1659,6 +1890,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7m // ------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a mov(var(cs_a), r10) // load cs_a mov(var(b), rbx) // load address of b @@ -2901,8 +3134,49 @@ void bli_dgemmsup_rv_zen4_asm_24x7m jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x7 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x7 tile updated jmp(.DDONE) // jump to end. 
@@ -2937,8 +3211,48 @@ void bli_dgemmsup_rv_zen4_asm_24x7m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x7 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x7 tile updated label(.DDONE) @@ -2962,7 +3276,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -3053,6 +3368,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -3068,6 +3385,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6m // ------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + 
kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a mov(var(cs_a), r10) // load cs_a mov(var(b), rbx) // load address of b @@ -4195,8 +4514,52 @@ void bli_dgemmsup_rv_zen4_asm_24x6m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x6 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x6 tile updated jmp(.DDONE) // jump to end. 
@@ -4228,8 +4591,51 @@ void bli_dgemmsup_rv_zen4_asm_24x6m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x6 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x6 tile updated label(.DDONE) @@ -4253,7 +4659,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -4344,6 +4751,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -4359,6 +4768,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5m // 
------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a mov(var(cs_a), r10) // load cs_a mov(var(b), rbx) // load address of b @@ -5371,8 +5782,52 @@ void bli_dgemmsup_rv_zen4_asm_24x5m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x5 tile updated jmp(.DDONE) // jump to end. 
@@ -5401,8 +5856,52 @@ void bli_dgemmsup_rv_zen4_asm_24x5m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_BZ + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x5 tile updated label(.DDONE) @@ -5426,7 +5925,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -5517,6 +6017,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -5532,6 +6034,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4m // 
------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a mov(var(cs_a), r10) // load cs_a mov(var(b), rbx) // load address of b @@ -6394,8 +6898,46 @@ void bli_dgemmsup_rv_zen4_asm_24x4m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x4 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x4 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x4 tile updated jmp(.DDONE) // jump to end. 
@@ -6421,8 +6963,45 @@ void bli_dgemmsup_rv_zen4_asm_24x4m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x5 tile updated label(.DDONE) @@ -6446,7 +7025,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -6537,6 +7117,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -6552,6 +7134,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3m // ------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a mov(var(cs_a), r10) // load cs_a mov(var(b), rbx) // load 
address of b @@ -7297,8 +7881,46 @@ void bli_dgemmsup_rv_zen4_asm_24x3m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x3 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x3 tile updated jmp(.DDONE) // jump to end. 
@@ -7321,8 +7943,46 @@ void bli_dgemmsup_rv_zen4_asm_24x3m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_BZ + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x3 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x3 tile updated label(.DDONE) @@ -7346,7 +8006,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -7437,6 +8098,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -7452,6 +8115,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2m // ------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a mov(var(cs_a), r10) // 
load cs_a mov(var(b), rbx) // load address of b @@ -8082,8 +8747,47 @@ void bli_dgemmsup_rv_zen4_asm_24x2m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x2 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x2 tile updated jmp(.DDONE) // jump to end. 
@@ -8103,8 +8807,48 @@ void bli_dgemmsup_rv_zen4_asm_24x2m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x2 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x2 tile updated label(.DDONE) @@ -8128,7 +8872,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -8219,6 +8964,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1m uint64_t k_iter = (uint64_t)k0 / 8; uint64_t k_left = (uint64_t)k0 % 8; + uint8_t mask = (0xff >> (0x8 - (n0 & 7))); // calculate mask based on n_left + if ( m_iter == 0 ) goto consider_edge_cases; /* For one iteration of this loop, a block of MRxNR is computed @@ -8234,6 +8981,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1m // ------------------------------------------------------------------------- begin_asm() + mov(var(mask), rdx) // load mask + kmovw(edx, k(2)) // move mask to k2 register mov(var(a), rax) // load address of a 
mov(var(cs_a), r10) // load cs_a mov(var(b), rbx) // load address of b @@ -8749,8 +9498,49 @@ void bli_dgemmsup_rv_zen4_asm_24x1m jmp(.DDONE) // jump to end. label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Second 8x1 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C + //Third 8x1 tile updated jmp(.DDONE) // jump to end. 
@@ -8767,8 +9557,48 @@ void bli_dgemmsup_rv_zen4_asm_24x1m label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_BZ + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Second 8x1 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_BZ + //Third 8x1 tile updated label(.DDONE) @@ -8792,7 +9622,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1m [rs_c] "m" (rs_c), [cs_c] "m" (cs_c), [n0] "m" (n0), - [m0] "m" (m0) + [m0] "m" (m0), + [mask] "m" (mask) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c index 32b443777..47ba85926 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. 
+ * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. + * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + 
vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_4_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_3_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_2_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_1_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/
+#define UPDATE_MASKED_C_8 \
+\
+    vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm30,zmm0 ) \
+\
+    vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm4 ) \
+\
+    vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm12,zmm2 ) \
+\
+    vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm16,zmm6 ) \
+\
+    vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm14,zmm1 ) \
+\
+    vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm18,zmm5 ) \
+\
+    vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm3 ) \
+\
+    vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm12,zmm8 ) \
+\
+    vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
+    vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
+    vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
+    vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
+    vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
+    vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
+    vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
+    vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
+    add(r14, rcx)
+
+/**
+ * Loads elements from C row only if corresponding bits in
+ * mask register are set, scales it with Beta and adds FMA result to it.
+ * Stores back the C row.
+*/
+#define UPDATE_MASKED_C_7 \
+\
+    vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm30,zmm0 ) \
+\
+    vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm4 ) \
+\
+    vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm12,zmm2 ) \
+\
+    vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm16,zmm6 ) \
+\
+    vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm14,zmm1 ) \
+\
+    vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm18,zmm5 ) \
+\
+    vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm3 ) \
+\
+    vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
+    vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
+    vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
+    vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
+    vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
+    vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
+    vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
+
+/**
+ * Loads elements from C row only if corresponding bits in
+ * mask register are set, scales it with Beta and adds FMA result to it.
+ * Stores back the C row.
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,10 @@ void bli_dgemmsup_rv_zen4_asm_24x1 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +454,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -489,8 +840,102 @@ void bli_dgemmsup_rv_zen4_asm_24x1 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x1 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 8x1 tile updated jmp(.DDONE) // jump to end. 
@@ -507,8 +952,100 @@ void bli_dgemmsup_rv_zen4_asm_24x1 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x1 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -533,7 +1070,8 @@ void bli_dgemmsup_rv_zen4_asm_24x1 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -594,6 +1132,11 @@ void bli_dgemmsup_rv_zen4_asm_16x1 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -606,6 +1149,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -934,8 +1479,90 @@ void bli_dgemmsup_rv_zen4_asm_16x1 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //8x1 tile updated jmp(.DDONE) // jump to end. 
@@ -951,8 +1578,88 @@ void bli_dgemmsup_rv_zen4_asm_16x1 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x1 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -977,7 +1684,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1038,6 +1746,11 @@ void bli_dgemmsup_rv_zen4_asm_8x1 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1050,6 +1763,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1322,8 +2037,78 @@ void bli_dgemmsup_rv_zen4_asm_8x1 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //8x1 tile updated jmp(.DDONE) // jump to end. 
@@ -1338,8 +2123,75 @@ void bli_dgemmsup_rv_zen4_asm_8x1 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1364,7 +2216,8 @@ void bli_dgemmsup_rv_zen4_asm_8x1 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c index 898035c4f..bb12cff69 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, 
mem(rcx, r12, 1) MASK_(k(3)) ) \
+\
+    vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3)))
+
+/**
+ * mask register is set, stores the fma result back to C
+*/
+#define UPDATE_MASKED_C_4_BZ \
+    vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
+\
+    vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
+\
+    vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \
+\
+    vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) )
+
+/**
+ * mask register is set, stores the fma result back to C
+*/
+#define UPDATE_MASKED_C_3_BZ \
+    vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
+\
+    vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \
+\
+    vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) )
+
+/**
+ * mask register is set, stores the fma result back to C
+*/
+#define UPDATE_MASKED_C_2_BZ \
+    vmovupd( zmm0, mem(rcx) MASK_(k(3))) \
+\
+    vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3)))
+
+/**
+ * mask register is set, stores the fma result back to C
+*/
+#define UPDATE_MASKED_C_1_BZ \
+\
+    vmovupd( zmm0, mem(rcx) MASK_(k(3)))
+
+/**
+ * Loads elements from C row only if corresponding bits in
+ * mask register are set, scales it with Beta and adds FMA result to it.
+ * Stores back the C row.
+*/ +#define UPDATE_MASKED_C_8 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_7 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x2 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -607,8 +959,102 @@ void bli_dgemmsup_rv_zen4_asm_24x2 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x2 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 8x2 tile updated jmp(.DDONE) // jump to end. 
@@ -628,8 +1074,100 @@ void bli_dgemmsup_rv_zen4_asm_24x2 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x2 tile updated + + vunpcklpd( zmm29, zmm28, zmm0) + vunpckhpd( zmm29, zmm28, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -654,7 +1192,8 @@ void bli_dgemmsup_rv_zen4_asm_24x2 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -715,6 +1254,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -727,6 +1271,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1143,8 +1689,90 @@ void bli_dgemmsup_rv_zen4_asm_16x2 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Second 8x2 tile updated jmp(.DDONE) // jump to end. 
@@ -1162,8 +1790,88 @@ void bli_dgemmsup_rv_zen4_asm_16x2 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x2 tile updated + + vunpcklpd( zmm9, zmm7, zmm0) + vunpckhpd( zmm9, zmm7, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1188,7 +1896,8 @@ void bli_dgemmsup_rv_zen4_asm_16x2 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1249,6 +1958,11 @@ void bli_dgemmsup_rv_zen4_asm_8x2 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1261,6 +1975,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1591,8 +2307,78 @@ void bli_dgemmsup_rv_zen4_asm_8x2 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //8x2 tile updated jmp(.DDONE) // jump to end. 
@@ -1608,8 +2394,75 @@ void bli_dgemmsup_rv_zen4_asm_8x2 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + vunpcklpd( zmm8, zmm6, zmm0) + vunpckhpd( zmm8, zmm6, zmm1) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1634,7 +2487,8 @@ void bli_dgemmsup_rv_zen4_asm_8x2 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c index 4f5466f84..d84e41597 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, 
mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_4_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_3_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_2_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_1_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_8 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_7 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,12 @@ void bli_dgemmsup_rv_zen4_asm_24x3 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + + // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +456,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -725,8 +1078,99 @@ void bli_dgemmsup_rv_zen4_asm_24x3 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x3 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 8x3 tile updated jmp(.DDONE) // jump to end. 
@@ -749,8 +1193,97 @@ void bli_dgemmsup_rv_zen4_asm_24x3 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x3 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -775,7 +1308,8 @@ void bli_dgemmsup_rv_zen4_asm_24x3 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -836,6 +1370,11 @@ void bli_dgemmsup_rv_zen4_asm_16x3 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -848,6 +1387,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1352,8 +1893,88 @@ void bli_dgemmsup_rv_zen4_asm_16x3 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Second 8x3 tile updated jmp(.DDONE) // jump to end. 
@@ -1373,8 +1994,86 @@ void bli_dgemmsup_rv_zen4_asm_16x3 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x3 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1399,7 +2098,8 @@ void bli_dgemmsup_rv_zen4_asm_16x3 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1460,7 +2160,12 @@ void bli_dgemmsup_rv_zen4_asm_8x3 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; - // ------------------------------------------------------------------------- + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + + // ------------------------------------------------------------------------- begin_asm() mov(var(a), rax) // load address of a @@ -1472,6 +2177,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1860,8 +2567,77 @@ void bli_dgemmsup_rv_zen4_asm_8x3 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //8x3 tile updated jmp(.DDONE) // jump to end. 
@@ -1878,8 +2654,75 @@ void bli_dgemmsup_rv_zen4_asm_8x3 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) + //8x3 tile updated label(.DDONE) @@ -1904,7 +2747,8 @@ void bli_dgemmsup_rv_zen4_asm_8x3 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c index fb067e685..e77e38054 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, 
mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_4_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_3_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_2_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_1_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_8 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_7 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x4 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -845,8 +1197,96 @@ void bli_dgemmsup_rv_zen4_asm_24x4 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x4 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x4 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 8x4 tile updated jmp(.DDONE) // jump to end. 
@@ -872,8 +1312,94 @@ void bli_dgemmsup_rv_zen4_asm_24x4 label(.DROWSTORBZ) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x4 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x4 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -898,7 +1424,8 @@ void bli_dgemmsup_rv_zen4_asm_24x4 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -959,6 +1486,11 @@ void bli_dgemmsup_rv_zen4_asm_16x4 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -971,6 +1503,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1565,8 +2099,88 @@ void bli_dgemmsup_rv_zen4_asm_16x4 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x4 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Second 8x4 tile updated jmp(.DDONE) // jump to end. 
@@ -1588,8 +2202,86 @@ void bli_dgemmsup_rv_zen4_asm_16x4 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x4 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1614,7 +2306,8 @@ void bli_dgemmsup_rv_zen4_asm_16x4 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1675,6 +2368,11 @@ void bli_dgemmsup_rv_zen4_asm_8x4 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1687,6 +2385,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -2135,8 +2835,77 @@ void bli_dgemmsup_rv_zen4_asm_8x4 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //8x4 tile updated jmp(.DDONE) // jump to end. 
@@ -2154,8 +2923,74 @@ void bli_dgemmsup_rv_zen4_asm_8x4 label(.DROWSTORBZ) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -2180,7 +3015,8 @@ void bli_dgemmsup_rv_zen4_asm_8x4 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c index 991fe53be..eba43abe6 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, 
mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_4_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_3_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_2_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_1_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/
+#define UPDATE_MASKED_C_8 \
+\
+    vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm30,zmm0 ) \
+\
+    vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm4 ) \
+\
+    vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm12,zmm2 ) \
+\
+    vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm16,zmm6 ) \
+\
+    vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm14,zmm1 ) \
+\
+    vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm18,zmm5 ) \
+\
+    vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm3 ) \
+\
+    vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm12,zmm8 ) \
+\
+    vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
+    vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
+    vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
+    vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
+    vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
+    vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
+    vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\
+    vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\
+    add(r14, rcx)
+
+/**
+ * Loads elements from C row only if corresponding bits in
+ * mask register are set, Scales it with Beta and adds FMA result to it.
+ * Stores back the C row.
+*/
+#define UPDATE_MASKED_C_7 \
+\
+    vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm30,zmm0 ) \
+\
+    vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm4 ) \
+\
+    vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm12,zmm2 ) \
+\
+    vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm16,zmm6 ) \
+\
+    vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm14,zmm1 ) \
+\
+    vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm18,zmm5 ) \
+\
+    vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \
+    vfmadd231pd( zmm31,zmm10,zmm3 ) \
+\
+    vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\
+    vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\
+    vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\
+    vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\
+    vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\
+    vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\
+    vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))
+
+/**
+ * Loads elements from C row only if corresponding bits in
+ * mask register are set, Scales it with Beta and adds FMA result to it.
+ * Stores back the C row.
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x5 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -998,8 +1350,103 @@ void bli_dgemmsup_rv_zen4_asm_24x5 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0) + //Third 8x8 tile updated jmp(.DDONE) // jump to end. @@ -1028,8 +1475,101 @@ void bli_dgemmsup_rv_zen4_asm_24x5 label(.DROWSTORBZ) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x5 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1054,7 +1594,8 @@ void bli_dgemmsup_rv_zen4_asm_24x5 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1115,6 +1656,11 @@ void bli_dgemmsup_rv_zen4_asm_16x5 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1127,6 +1673,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1844,8 +2392,92 @@ void bli_dgemmsup_rv_zen4_asm_16x5 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // r12 = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // r13 = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Second 8x5 tile updated jmp(.DDONE) // jump to end. 
@@ -1869,8 +2501,90 @@ void bli_dgemmsup_rv_zen4_asm_16x5 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x5 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1895,7 +2609,8 @@ void bli_dgemmsup_rv_zen4_asm_16x5 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1956,6 +2671,11 @@ void bli_dgemmsup_rv_zen4_asm_8x5 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1968,6 +2688,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -2509,8 +3231,79 @@ void bli_dgemmsup_rv_zen4_asm_8x5 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //8x5 tile updated jmp(.DDONE) // jump to end. 
@@ -2529,8 +3322,76 @@ void bli_dgemmsup_rv_zen4_asm_8x5 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -2555,7 +3416,8 @@ void bli_dgemmsup_rv_zen4_asm_8x5 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c index dc874680c..449b7b232 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, 
mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_4_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_3_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_2_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_1_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_8 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_7 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x6 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1116,8 +1468,102 @@ void bli_dgemmsup_rv_zen4_asm_24x6 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x6 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0) + //Third 7x6 tile updated jmp(.DDONE) // jump to end. @@ -1149,8 +1595,100 @@ void bli_dgemmsup_rv_zen4_asm_24x6 label(.DROWSTORBZ) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x6 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm25, zmm24, zmm0) + vunpckhpd(zmm25, zmm24, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1175,7 +1713,8 @@ void bli_dgemmsup_rv_zen4_asm_24x6 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1236,6 +1775,11 @@ void bli_dgemmsup_rv_zen4_asm_16x6 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1248,6 +1792,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -2053,8 +2599,92 @@ void bli_dgemmsup_rv_zen4_asm_16x6 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Second 7x6 tile updated jmp(.DDONE) // jump to end. 
@@ -2080,8 +2710,90 @@ void bli_dgemmsup_rv_zen4_asm_16x6 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x6 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + vunpcklpd(zmm17, zmm15, zmm0) + vunpckhpd(zmm17, zmm15, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -2106,7 +2818,8 @@ void bli_dgemmsup_rv_zen4_asm_16x6 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -2167,6 +2880,11 @@ void bli_dgemmsup_rv_zen4_asm_8x6 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -2179,6 +2897,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -2778,8 +3498,79 @@ void bli_dgemmsup_rv_zen4_asm_8x6 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //7x6 tile updated jmp(.DDONE) // jump to end. 
@@ -2799,8 +3590,76 @@ void bli_dgemmsup_rv_zen4_asm_8x6 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + vunpcklpd(zmm16, zmm14, zmm0) + vunpckhpd(zmm16, zmm14, zmm1) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -2825,7 +3684,8 @@ void bli_dgemmsup_rv_zen4_asm_8x6 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c index bc8bf3d26..95b3ca452 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c @@ -38,10 +38,355 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_8_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) \ +\ + vmovupd( zmm8, mem(rcx, rdx, 1) MASK_(k(3))) \ + add(r14, rcx) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_7_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) \ +\ + vmovupd( zmm3, mem(rcx, r12, 2) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_6_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) \ +\ + vmovupd( zmm5, mem(rcx, r13, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_5_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, 
mem(rcx, r12, 1) MASK_(k(3)) ) \ +\ + vmovupd( zmm1, mem(rcx, rsi, 4) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_4_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) \ +\ + vmovupd( zmm6, mem(rcx, r12, 1) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_3_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) \ +\ + vmovupd( zmm2, mem(rcx, rsi, 2) MASK_(k(3)) ) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_2_BZ \ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) \ +\ + vmovupd( zmm4, mem(rcx, rsi, 1) MASK_(k(3))) + +/** + * mask register is set, stores the fma result back to C +*/ +#define UPDATE_MASKED_C_1_BZ \ +\ + vmovupd( zmm0, mem(rcx) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_8 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( mem(rcx, rdx, 1, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm8 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3)))\ + vmovupd( zmm8, (rcx, rdx, 1) MASK_(k(3)))\ + add(r14, rcx) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_7 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( mem(rcx, r12, 2, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm3 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3)))\ + vmovupd( zmm3, (rcx, r12, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_6 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( mem(rcx, r13, 1, 0), zmm18 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm18,zmm5 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3)))\ + vmovupd( zmm5, (rcx, r13, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_5 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( mem(rcx, rsi, 4, 0), zmm14 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm14,zmm1 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3)))\ + vmovupd( zmm1, (rcx, rsi, 4) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. 
+ * Stores back the C row. +*/ +#define UPDATE_MASKED_C_4 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( mem(rcx, r12, 1, 0), zmm16 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm16,zmm6 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3)))\ + vmovupd( zmm6, (rcx, r12, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_3 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( mem(rcx, rsi, 2, 0), zmm12 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm12,zmm2 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3)))\ + vmovupd( zmm2, (rcx, rsi, 2) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_MASKED_C_2 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( mem(rcx, rsi, 1, 0), zmm10 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm10,zmm4 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/\ + vmovupd( zmm4, (rcx, rsi, 1) MASK_(k(3))) + +/** + * Loads elements from C row only if correspondnig bits in + * mask register is set, Scales it with Beta and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_MASKED_C_1 \ +\ + vmovupd( mem(rcx), zmm30 MASK_(k(3)) MASK_(z) ) \ + vfmadd231pd( zmm31,zmm30,zmm0 ) \ +\ + vmovupd( zmm0, (rcx) MASK_(k(3))) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -93,6 +438,11 @@ void bli_dgemmsup_rv_zen4_asm_24x7 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -105,6 +455,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -1234,8 +1586,100 @@ void bli_dgemmsup_rv_zen4_asm_24x7 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8 + //Second 8x7 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0) + //Third 7x8 tile updated jmp(.DDONE) // jump to end. @@ -1270,8 +1714,98 @@ void bli_dgemmsup_rv_zen4_asm_24x7 label(.DROWSTORBZ) + lea(mem(rsi, rsi, 2), r12) + lea(mem(r12, rsi, 2), r13) + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_MASKED_C_8_BZ + //Second 8x7 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0BZ) label(.DDONE) @@ -1296,7 +1830,8 @@ void bli_dgemmsup_rv_zen4_asm_24x7 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1357,6 +1892,11 @@ void bli_dgemmsup_rv_zen4_asm_16x7 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -1369,6 +1909,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -2263,8 +2805,90 @@ void bli_dgemmsup_rv_zen4_asm_16x7 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_MASKED_C_8 + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 7x8 tile updated jmp(.DDONE) // jump to end. 
@@ -2292,8 +2916,88 @@ void bli_dgemmsup_rv_zen4_asm_16x7 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_MASKED_C_8_BZ + //First 8x7 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -2318,7 +3022,8 @@ void bli_dgemmsup_rv_zen4_asm_16x7 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -2379,6 +3084,11 @@ void bli_dgemmsup_rv_zen4_asm_8x7 // So, mask becomes 0xff(11111111) if (mask == 0) mask = 0xff; + uint8_t mask_n0 = 0xff >> (0x8 - (n0 & 7)); // calculate mask based on n_left + // For special cases where n_left = 8, all 8 elements have to be loaded/stored + // So, mask becomes 0xff(11111111) + if (mask_n0 == 0) mask_n0 = 0xff; + // ------------------------------------------------------------------------- begin_asm() @@ -2391,6 +3101,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7 mov(var(cs_c), rdi) // load cs_c mov(var(mask), rdx) // load mask kmovw(edx, k(2)) // move mask to k2 register + mov(var(mask_n0), rdx) // load mask + kmovw(edx, k(3)) // move mask to k3 register lea(mem(, r8, 8), r8) // rs_b *= sizeof(double) lea(mem(, r9, 8), r9) // cs_b *= sizeof(double) lea(mem(, r10, 8), r10) // cs_a *= sizeof(double) @@ -3048,8 +3760,78 @@ void bli_dgemmsup_rv_zen4_asm_8x7 jmp(.DDONE) // jump to end. 
label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_MASKED_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_MASKED_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_MASKED_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_MASKED_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_MASKED_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_MASKED_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_MASKED_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_MASKED_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 7x7 tile updated jmp(.DDONE) // jump to end. 
@@ -3070,8 +3852,75 @@ void bli_dgemmsup_rv_zen4_asm_8x7 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_MASKED_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_MASKED_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_MASKED_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_MASKED_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_MASKED_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_MASKED_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_MASKED_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_MASKED_C_1_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE0BZ) label(.DDONE) @@ -3096,7 +3945,8 @@ void bli_dgemmsup_rv_zen4_asm_8x7 [cs_c] "m" (cs_c), [n0] "m" (n0), [m0] "m" (m0), - [mask] "m" (mask) + [mask] "m" (mask), + [mask_n0] "m" (mask_n0) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c index 8bf041cbe..60f23206b 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c @@ -38,10 +38,318 @@ #include "bli_x86_asm_macros.h" #define TAIL_NITER 3 +/** + * Shuffle 2 double-precision elements selected by imm8 from S1 and S2, + * and store the results in D1. + * S1 : 1 9 3 11 5 13 7 15 + * S2 : 2 10 4 12 6 14 8 16 + * D1 : 1 9 5 13 2 10 6 14 + * D2 : 3 11 7 15 4 12 8 16 +*/ +#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \ + VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \ + VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \ + +/** + * Unpacks and interleave low half and high half of each + * 128-bit lane in S1 and S2 and store into D1 and D2 + * respectively. 
+ * S1 : 1 2 3 4 5 6 7 8 + * S2 : 9 10 11 12 13 14 15 16 + * D1 : 1 9 3 11 5 13 7 15 + * D2 : 2 10 4 12 6 14 8 16 +*/ +#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \ +\ + vunpcklpd( zmm(S1), zmm(S2), zmm(D1)) \ + vunpckhpd( zmm(S1), zmm(S2), zmm(D2)) \ + vunpcklpd( zmm(S3), zmm(S4), zmm(D3)) \ + vunpckhpd( zmm(S3), zmm(S4), zmm(D4)) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_8_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) \ +\ + vmovupd( zmm1, (rcx, rsi, 4) ) \ +\ + vmovupd( zmm5, (rcx, r13, 1) ) \ +\ + vmovupd( zmm3, (rcx, r12, 2) ) \ +\ + vmovupd( zmm8, (rcx, rdx, 1) ) \ + add(r14, rcx) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_7_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) \ +\ + vmovupd( zmm1, (rcx, rsi, 4) ) \ +\ + vmovupd( zmm5, (rcx, r13, 1) ) \ +\ + vmovupd( zmm3, (rcx, r12, 2) ) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_6_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) \ +\ + vmovupd( zmm1, (rcx, rsi, 4) ) \ +\ + vmovupd( zmm5, (rcx, r13, 1) ) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_5_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) \ +\ + vmovupd( zmm1, (rcx, rsi, 4) ) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_4_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) \ +\ + vmovupd( zmm6, (rcx, r12, 1) ) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_3_BZ \ +\ + vmovupd( zmm0, (rcx) ) 
/*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) \ +\ + vmovupd( zmm2, (rcx, rsi, 2) ) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_2_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ +\ + vmovupd( zmm4, (rcx, rsi, 1) ) + +/** + * Stores fma result back to C +*/ +#define UPDATE_C_1_BZ \ +\ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ \ + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_C_8 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \ + vmovupd( zmm1, (rcx, rsi, 4) )\ +\ + vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \ + vmovupd( zmm5, (rcx, r13, 1) )\ +\ + vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \ + vmovupd( zmm3, (rcx, r12, 2) )\ +\ + vfmadd231pd( mem(rcx, rdx, 1),zmm31,zmm8 ) \ + vmovupd( zmm8, (rcx, rdx, 1) )\ + add(r14, rcx) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_C_7 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \ + vmovupd( zmm1, (rcx, rsi, 4) )\ +\ + vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \ + vmovupd( zmm5, (rcx, r13, 1) )\ +\ + vfmadd231pd( mem(rcx, r12, 2),zmm31,zmm3 ) \ + vmovupd( zmm3, (rcx, r12, 2) ) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_C_6 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \ + vmovupd( zmm1, (rcx, rsi, 4) )\ +\ + vfmadd231pd( mem(rcx, r13, 1),zmm31,zmm5 ) \ + vmovupd( zmm5, (rcx, r13, 1) ) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_C_5 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 4),zmm31,zmm1 ) \ + vmovupd( zmm1, (rcx, rsi, 4) ) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_C_4 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) )\ +\ + vfmadd231pd( mem(rcx, r12, 1),zmm31,zmm6 ) \ + vmovupd( zmm6, (rcx, r12, 1) ) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_C_3 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) )\ +\ + vfmadd231pd( mem(rcx, rsi, 2),zmm31,zmm2 ) \ + vmovupd( zmm2, (rcx, rsi, 2) ) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. +*/ +#define UPDATE_C_2 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/\ +\ + vfmadd231pd( mem(rcx, rsi, 1),zmm31,zmm4 ) \ + vmovupd( zmm4, (rcx, rsi, 1) ) + +/** + * Loads elements from C row, Scales it with Beta + * and adds FMA result to it. + * Stores back the C row. 
+*/ +#define UPDATE_C_1 \ +\ + vfmadd231pd( mem(rcx),zmm31,zmm0 ) /*Scale by Beta and add it to fma result*/ \ + vmovupd( zmm0, (rcx) ) /*Stores back to C*/ + /* These kernels Assume that A matrix needs to be in col-major order * B matrix can be col/row-major - * C matrix can be col/row-major though support for row-major order will - * be added by a separate commit. + * C matrix can be col/row-major * Prefetch for C is done assuming that C is col-stored. * Prefetch of B is done assuming that the matrix is col-stored. * Prefetch for B and C matrices when row-stored is yet to be added. @@ -1352,8 +1660,102 @@ void bli_dgemmsup_rv_zen4_asm_24x8 jmp(.DDONE) // jump to end. label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_C_8 + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C_8 + //Second 8x8 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), 
rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //Third 7x8 tile updated jmp(.DDONE) // jump to end. @@ -1391,8 +1793,100 @@ void bli_dgemmsup_rv_zen4_asm_24x8 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_C_8_BZ + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + UPDATE_C_8_BZ + //Second 8x8 tile updated + + UNPACK_LO_HIGH(29, 28, 0, 1, 27, 26, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(25, 24, 0, 1, 23, 22, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(16), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), 
rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_C_1_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0BZ) label(.DDONE) @@ -2473,8 +2967,90 @@ void bli_dgemmsup_rv_zen4_asm_16x8 jmp(.DDONE) // jump to end. label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + UPDATE_C_8 + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_C_8 + jmp(.DDONE) + + 
LABEL(.UPDATE7) + UPDATE_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //7x8 tile updated jmp(.DDONE) // jump to end. @@ -2504,8 +3080,88 @@ void bli_dgemmsup_rv_zen4_asm_16x8 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + UPDATE_C_8_BZ + //First 8x8 tile updated + + UNPACK_LO_HIGH(9, 7, 0, 1, 13, 11, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 7, 9) + + UNPACK_LO_HIGH(17, 15, 0, 1, 21, 19, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 7, 4, 5, 12, 9, 6, 8) + + mov(var(m0), rdi) + sub(imm(8), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_C_4_BZ + jmp(.DDONE) // jump to end. 
+ + LABEL(.UPDATE3BZ) + UPDATE_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_C_1_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0BZ) label(.DDONE) @@ -3318,8 +3974,78 @@ void bli_dgemmsup_rv_zen4_asm_8x8 jmp(.DDONE) // jump to end. label(.DROWSTORED) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + vbroadcastsd(mem(rax), zmm31) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8) + cmp(imm(7), rdi) + JZ(.UPDATE7) + cmp(imm(6), rdi) + JZ(.UPDATE6) + cmp(imm(5), rdi) + JZ(.UPDATE5) + cmp(imm(4), rdi) + JZ(.UPDATE4) + cmp(imm(3), rdi) + JZ(.UPDATE3) + cmp(imm(2), rdi) + JZ(.UPDATE2) + cmp(imm(1), rdi) + JZ(.UPDATE1) + cmp(imm(0), rdi) + JZ(.UPDATE0) + + LABEL(.UPDATE8) + UPDATE_C_8 + jmp(.DDONE) + + LABEL(.UPDATE7) + UPDATE_C_7 + jmp(.DDONE) + + LABEL(.UPDATE6) + UPDATE_C_6 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5) + UPDATE_C_5 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4) + UPDATE_C_4 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3) + UPDATE_C_3 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2) + UPDATE_C_2 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1) + UPDATE_C_1 + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0) + //7x8 tile updated jmp(.DDONE) // jump to end. 
@@ -3341,8 +4067,75 @@ void bli_dgemmsup_rv_zen4_asm_8x8 label(.DROWSTORBZ) + // rdx = 3*rs_c + lea(mem(rsi, rsi, 2), r12) + // rdx = 5*rs_c + lea(mem(r12, rsi, 2), r13) + // rdx = 7*rs_c + lea(mem(r12, rsi, 4), rdx) + lea(mem( , rsi, 8), r14) + UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3) + SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31) - // yet to be implemented + UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3) + SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12) + + SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3) + SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8) + + mov(var(m0), rdi) + cmp(imm(8), rdi) + JZ(.UPDATE8BZ) + cmp(imm(7), rdi) + JZ(.UPDATE7BZ) + cmp(imm(6), rdi) + JZ(.UPDATE6BZ) + cmp(imm(5), rdi) + JZ(.UPDATE5BZ) + cmp(imm(4), rdi) + JZ(.UPDATE4BZ) + cmp(imm(3), rdi) + JZ(.UPDATE3BZ) + cmp(imm(2), rdi) + JZ(.UPDATE2BZ) + cmp(imm(1), rdi) + JZ(.UPDATE1BZ) + cmp(imm(0), rdi) + JZ(.UPDATE0BZ) + + LABEL(.UPDATE8BZ) + UPDATE_C_8_BZ + jmp(.DDONE) + + LABEL(.UPDATE7BZ) + UPDATE_C_7_BZ + jmp(.DDONE) + + LABEL(.UPDATE6BZ) + UPDATE_C_6_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE5BZ) + UPDATE_C_5_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE4BZ) + UPDATE_C_4_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE3BZ) + UPDATE_C_3_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE2BZ) + UPDATE_C_2_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE1BZ) + UPDATE_C_1_BZ + jmp(.DDONE) // jump to end. + + LABEL(.UPDATE0BZ) label(.DDONE)