diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index 84bc76c21..112fe6473 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -776,6 +776,7 @@ #define VMOVHPD(...) INSTR_(vmovhpd, __VA_ARGS__) #define VMOVDQA(_0, _1) INSTR_(vmovdqa, _0, _1) #define VMOVDQA32(_0, _1) INSTR_(vmovdqa32, _0, _1) +#define VMOVDQU(_0, _1) INSTR_(vmovdqu, _0, _1) #define VMOVDQA64(_0, _1) INSTR_(vmovdqa64, _0, _1) #define VBROADCASTSS(_0, _1) INSTR_(vbroadcastss, _0, _1) #define VBROADCASTSD(_0, _1) INSTR_(vbroadcastsd, _0, _1) @@ -809,6 +810,7 @@ #define vmovhpd(...) VMOVHPD(__VA_ARGS__) #define vmovdqa(_0, _1) VMOVDQA(_0, _1) #define vmovdqa32(_0, _1) VMOVDQA32(_0, _1) +#define vmovdqu(_0, _1) VMOVDQU(_0, _1) #define vmovdqa64(_0, _1) VMOVDQA64(_0, _1) #define vbroadcastss(_0, _1) VBROADCASTSS(_0, _1) #define vbroadcastsd(_0, _1) VBROADCASTSD(_0, _1) @@ -911,6 +913,7 @@ #define VCOMISS(_0, _1) INSTR_(vcomiss, _0, _1) #define VCOMISD(_0, _1) INSTR_(vcomisd, _0, _1) +#define VMASKMOVPD(_0, _1, _2) INSTR_(vmaskmovpd, _0, _1, _2) #define VFMADD132SS(_0, _1, _2) INSTR_(vfmadd132ss, _0, _1, _2) #define VFMADD213SS(_0, _1, _2) INSTR_(vfmadd213ss, _0, _1, _2) #define VFMADD231SS(_0, _1, _2) INSTR_(vfmadd231ss, _0, _1, _2) @@ -1236,7 +1239,7 @@ #define vblendpd(_0, _1, _2, _3) VBLENDPD(_0, _1, _2, _3) #define vblendmps(_0, _1, _2) VBLENDMSD(_0, _1, _2) #define vblendmpd(_0, _1, _2) VBLENDMPD(_0, _1, _2) - +#define vmaskmovpd(_0, _1, _2) VMASKMOVPD(_0, _1, _2) // Prefetches #define PREFETCH(_0, _1) INSTR_(prefetcht##_0, _1) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 05c240d2d..cdd698982 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -38,6 +38,360 @@ #define BLIS_ASM_SYNTAX_ATT #include "bli_x86_asm_macros.h" +static const int64_t mask_3[4] = {-1, -1, -1, 0}; +static const int64_t mask_1[4] = {-1, 0, 0, 0}; + +static void bli_dgemmsup_rv_haswell_asm_6x7m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +static void bli_dgemmsup_rv_haswell_asm_6x5m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +static void bli_dgemmsup_rv_haswell_asm_6x3m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +#define C_TRANSPOSE_6x7_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + 
vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3, R4 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + /*Storing it back to C matrix.*/ \ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + /*Moving to operate on last 2 rows of 6 rows.*/ \ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2, 3 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm15, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm15, xmm2)\ + vfmadd231pd(mem(rdx, rax, 1), xmm15, xmm3)\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpckhpd(ymm(R8), ymm(R7), ymm1)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vunpckhpd(ymm(R10), ymm(R9), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm7)\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm9)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm5)\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm7)\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm9)\ +\ + vmovupd(ymm5, mem(rcx ))\ + vmovupd(ymm7, mem(rcx, rsi, 1))\ + vmovupd(ymm9, mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ + vunpckhpd(ymm(R12), ymm(R11), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm4)\ +\ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm15, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm15, xmm2)\ +\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x7_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Storing transposed 4x4 tile back to C matrix*/\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Storing transposed 2x4 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpckhpd(ymm(R8), ymm(R7), ymm1)\ + vunpcklpd(ymm(R10), 
ymm(R9), ymm2)\ + vunpckhpd(ymm(R10), ymm(R9), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm7)\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm9)\ +\ + /*Storing transposed 4x3 tile back to C matrix*/\ + vmovupd(ymm5, mem(rcx ))\ + vmovupd(ymm7, mem(rcx, rsi, 1))\ + vmovupd(ymm9, mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ + vunpckhpd(ymm(R12), ymm(R11), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm4)\ +\ + /*Storing transposed 2x3 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x5_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3, R4 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2, 3 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm15, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm15, xmm2)\ + vfmadd231pd(mem(rdx, rax, 1), xmm15, xmm3)\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm5)\ + vmovupd(ymm5, mem(rcx ))\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ +\ + vmovupd(xmm0, mem(rdx )) + +#define C_TRANSPOSE_6x5_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Storing transposed 4x4 tile back to C matrix*/\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Storing 
transposed 4x2 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ +\ + /*Storing transposed 4x1 tile back to C matrix*/\ + vmovupd(ymm5, mem(rcx ))\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ +\ + /*Storing transposed 2x1 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx )) + +#define C_TRANSPOSE_6x3_TILE(R1, R2, R3, R4, R5, R6) \ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm(R3))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm3, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2)\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x3_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm4)\ +\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + /* rrr: -------- ------ -------- @@ -108,93 +462,114 @@ void bli_dgemmsup_rv_haswell_asm_6x8m double* restrict bj = b; double* restrict ai = a; - if ( 6 <= n_left ) + switch(n_left) { - const dim_t nr_cur = 6; - - bli_dgemmsup_rv_haswell_asm_6x6m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 4 <= n_left ) - { - const dim_t nr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_6x4m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_6x2m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; 
n_left -= nr_cur; - } - if ( 1 == n_left ) - { -#if 0 - const dim_t nr_cur = 1; - - bli_dgemmsup_r_haswell_ref - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); -#else - dim_t ps_a0 = bli_auxinfo_ps_a( data ); - - if ( ps_a0 == 6 * rs_a0 ) + case 7: { - // Since A is not packed, we can use one gemv. - bli_dgemv_ex + bli_dgemmsup_rv_haswell_asm_6x7m ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx ); + break; } - else + case 6: { - const dim_t mr = 6; - - // Since A is packed into row panels, we must use a loop over - // gemv. - dim_t m_iter = ( m0 + mr - 1 ) / mr; - dim_t m_left = m0 % mr; - - double* restrict ai_ii = ai; - double* restrict cij_ii = cij; - - for ( dim_t ii = 0; ii < m_iter; ii += 1 ) - { - dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) - ? mr : m_left ); - - bli_dgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, - alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, - beta, cij_ii, rs_c0, cntx, NULL - ); - cij_ii += mr*rs_c0; ai_ii += ps_a0; - } + bli_dgemmsup_rv_haswell_asm_6x6m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 5: + { + bli_dgemmsup_rv_haswell_asm_6x5m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 4: + { + bli_dgemmsup_rv_haswell_asm_6x4m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 3: + { + bli_dgemmsup_rv_haswell_asm_6x3m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 2: + { + bli_dgemmsup_rv_haswell_asm_6x2m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 1: + { + dim_t ps_a0 = bli_auxinfo_ps_a( data ); + + if ( ps_a0 == 6 * rs_a0 ) + { + // Since A is not packed, we can use one gemv. + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + else + { + const dim_t mr = 6; + + // Since A is packed into row panels, we must use a loop over + // gemv. + dim_t m_iter = ( m0 + mr - 1 ) / mr; + dim_t m_left = m0 % mr; + + double* restrict ai_ii = ai; + double* restrict cij_ii = cij; + + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) + { + dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) + ? mr : m_left ); + + bli_dgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, + beta, cij_ii, rs_c0, cntx, NULL + ); + cij_ii += mr*rs_c0; ai_ii += ps_a0; + } + } + break; + } + default: + { + break; } -#endif } return; } @@ -916,53 +1291,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m double* restrict ai = a + m_iter * ps_a; double* restrict bj = b; -#if 0 - // We add special handling for slightly inflated MR blocksizes - // at edge cases, up to a maximum of 9. 
- if ( 6 < m_left ) - { - dgemmsup_ker_ft ker_fp1 = NULL; - dgemmsup_ker_ft ker_fp2 = NULL; - dim_t mr1, mr2; - - if ( m_left == 7 ) - { - mr1 = 4; mr2 = 3; - ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; - ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x8; - } - else if ( m_left == 8 ) - { - mr1 = 4; mr2 = 4; - ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; - ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x8; - } - else // if ( m_left == 9 ) - { - mr1 = 4; mr2 = 5; - ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; - ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x8; - } - - ker_fp1 - ( - conja, conjb, mr1, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr1*rs_c0; ai += mr1*rs_a0; - - ker_fp2 - ( - conja, conjb, mr2, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - - return; - } -#endif - dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -8129,6 +8457,1728 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) } +static void bli_dgemmsup_rv_haswell_asm_6x7m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + int64_t const *mask_vec = mask_3; + + if ( m_iter == 0 ) goto consider_edge_cases_7; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + + mov(var(a), r14) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X7I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm3) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + lea(mem(rdx, rsi, 2), rcx) // rcx = c + 5*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
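+	// Illustrative note (added commentary, not part of the original patch):
+	// in intrinsics terms, each k iteration of the main loop below does
+	// roughly the following for every row i of the A micro-panel, with the
+	// mask already resident in ymm15 (variable names here are hypothetical):
+	//
+	//   #include <immintrin.h>
+	//   __m256i m3   = _mm256_loadu_si256( (__m256i const*)mask_3 );
+	//   __m256d b_lo = _mm256_loadu_pd( b );              // B columns 0..3
+	//   __m256d b_hi = _mm256_maskload_pd( b + 4, m3 );   // B columns 4..6; 4th lane zeroed
+	//   __m256d a_i  = _mm256_broadcast_sd( a + i*rs_a ); // broadcast A(i,l)
+	//   c_lo[i] = _mm256_fmadd_pd( b_lo, a_i, c_lo[i] );  // ymm3,5,7,9,11,13
+	//   c_hi[i] = _mm256_fmadd_pd( b_hi, a_i, c_hi[i] );  // ymm4,6,8,10,12,14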
+ + label(.DLOOPKITER) // MAIN LOOP + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + 
vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + vmulpd(ymm0, ymm7, ymm7) // scale by alpha + vmulpd(ymm0, ymm8, ymm8) + + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+	jz(.DCOLSTORED)                    // jump to column storage case
+
+	label(.DROWSTORED)
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm4)
+	//Stores 4 elements
+	vmovupd(ymm3, mem(rcx, 0*32))
+	//Stores 3 elements as per mask_3 mask vector
+	vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------1
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm6)
+
+	vmovupd(ymm5, mem(rcx, 0*32))
+	vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------2
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm8)
+
+	vmovupd(ymm7, mem(rcx, 0*32))
+	vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------3
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm9)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm10)
+
+	vmovupd(ymm9, mem(rcx, 0*32))
+	vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------4
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm11)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm12)
+
+	vmovupd(ymm11, mem(rcx, 0*32))
+	vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------5
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm13)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm14)
+
+	vmovupd(ymm13, mem(rcx, 0*32))
+	vmaskmovpd(ymm14, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------6
+
+	jmp(.DDONE)                        // jump to end.
+
+	label(.DCOLSTORED)
+	C_TRANSPOSE_6x7_TILE(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14)
+	jmp(.RESETPARAM)
+
+	label(.DBETAZERO)
+
+	cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+	jz(.DCOLSTORBZ)                    // jump to column storage case
+
+	label(.DROWSTORBZ)
+
+	vmovupd(ymm3, mem(rcx, 0*32))
+	vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------1
+
+	vmovupd(ymm5, mem(rcx, 0*32))
+	vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------2
+
+	vmovupd(ymm7, mem(rcx, 0*32))
+	vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------3
+
+	vmovupd(ymm9, mem(rcx, 0*32))
+	vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------4
+
+	vmovupd(ymm11, mem(rcx, 0*32))
+	vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------5
+
+	vmovupd(ymm13, mem(rcx, 0*32))
+	vmaskmovpd(ymm14, ymm15, mem(rcx, 1*32))
+
+	//-----------------------6
+
+
+	jmp(.DDONE)                        // jump to end.
+
+
+	label(.DCOLSTORBZ)
+
+	C_TRANSPOSE_6x7_TILE_BZ(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14)
+	jmp(.RESETPARAM)
+
+	label(.RESETPARAM)
+	mov(var(mask_vec), rdx)
+	vmovdqu(mem(rdx), ymm15)           //load mask
+	jmp(.DDONE)
+
+	label(.DDONE)
+	lea(mem(r12, rdi, 4), r12)         //
+	lea(mem(r12, rdi, 2), r12)         // c_ii = r12 += 6*rs_c
+
+	mov(var(ps_a8), rax)               // load ps_a8
+	lea(mem(r14, rax, 1), r14)         // a_ii = r14 += ps_a8
+
+	dec(r11)                           // ii -= 1;
+	jne(.DLOOP6X7I)                    // iterate again if ii != 0.
+
+
+	label(.DRETURN)
+	vzeroupper()
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	  [m_iter] "m" (m_iter),
+	  [k_iter] "m" (k_iter),
+	  [k_left] "m" (k_left),
+	  [a] "m" (a),
+	  [rs_a] "m" (rs_a),
+	  [cs_a] "m" (cs_a),
+	  [ps_a8] "m" (ps_a8),
+	  [b] "m" (b),
+	  [rs_b] "m" (rs_b),
+	  [cs_b] "m" (cs_b),
+	  [alpha] "m" (alpha),
+	  [beta] "m" (beta),
+	  [c] "m" (c),
+	  [mask_vec] "m" (mask_vec),
+	  [rs_c] "m" (rs_c),
+	  [n0] "m" (n0),
+	  [cs_c] "m" (cs_c)
+	: // register clobber list
+	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	  "xmm0", "xmm1", "xmm2", "xmm3",
+	  "xmm4", "xmm5", "xmm6", "xmm7",
+	  "xmm8", "xmm9", "xmm10", "xmm11",
+	  "xmm12", "xmm13", "xmm14", "xmm15",
+	  "ymm0", "ymm1", "ymm2", "ymm3", "ymm4",
+	  "ymm5", "ymm6", "ymm7", "ymm8", "ymm9",
+	  "ymm10", "ymm11", "ymm12", "ymm13", "ymm14",
+	  "ymm15",
+	  "memory"
+	)
+
+	consider_edge_cases_7:
+	// Handle edge cases in the m dimension, if they exist.
+	if ( m_left )
+	{
+		const dim_t nr_cur = n0;
+		const dim_t i_edge = m0 - ( dim_t )m_left;
+
+		double* restrict cij = c + i_edge*rs_c;
+		double* restrict ai  = a + m_iter * ps_a;
+		double* restrict bj  = b;
+
+		dgemmsup_ker_ft ker_fps[6] =
+		{
+			NULL,
+			bli_dgemmsup_rv_haswell_asm_1x7,
+			bli_dgemmsup_rv_haswell_asm_2x7,
+			bli_dgemmsup_rv_haswell_asm_3x7,
+			bli_dgemmsup_rv_haswell_asm_4x7,
+			bli_dgemmsup_rv_haswell_asm_5x7
+		};
+
+		dgemmsup_ker_ft ker_fp = ker_fps[ m_left ];
+
+		ker_fp
+		(
+		  conja, conjb, m_left, nr_cur, k0,
+		  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+		  beta, cij, rs_c0, cs_c0, data, cntx
+		);
+
+		return;
+	}
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
+}
+
+static void bli_dgemmsup_rv_haswell_asm_6x5m
+     (
+       conj_t conja,
+       conj_t conjb,
+       dim_t m0,
+       dim_t n0,
+       dim_t k0,
+       double* restrict alpha,
+       double* restrict a, inc_t rs_a0, inc_t cs_a0,
+       double* restrict b, inc_t rs_b0, inc_t cs_b0,
+       double* restrict beta,
+       double* restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t* restrict cntx
+     )
+{
+	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5);
+
+// Sets up the mask for loading relevant remainder elements in load direction
+// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register.
+//
+//  Low end                  High end
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  2  |  3  |  4  |  ----> Source vector
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  0  |  0  |  0  |  ----> Destination vector
+//  |_____|_____|_____|_____|
+//
+// Since we have 5 elements to load, kernel will use one normal load
+// that loads 4 elements into vector register and for remainder 1 element,
+// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the
+// 1 element will be loaded and other 3 elements will be set to 0 in destination vector.
+//
+	//void* a_next = bli_auxinfo_next_a( data );
+	//void* b_next = bli_auxinfo_next_b( data );
+
+	// Typecast local copies of integers in case dim_t and inc_t are a
+	// different size than is expected by load instructions.
+	uint64_t k_iter = k0 / 4;
+	uint64_t k_left = k0 % 4;
+
+	uint64_t m_iter = m0 / 6;
+	uint64_t m_left = m0 % 6;
+
+	uint64_t rs_a = rs_a0;
+	uint64_t cs_a = cs_a0;
+	uint64_t rs_b = rs_b0;
+	uint64_t cs_b = cs_b0;
+	uint64_t rs_c = rs_c0;
+	uint64_t cs_c = cs_c0;
+
+	// Query the panel stride of A and convert it to units of bytes.
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + int64_t const *mask_vec = mask_1; + + if ( m_iter == 0 ) goto consider_edge_cases_5; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X5I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm3) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
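+	// Illustrative note (added commentary, not part of the original patch):
+	// same pattern as the 6x7m kernel above, except the high part of each
+	// B row is a single masked column; per row i of A (hypothetical names):
+	//
+	//   __m256i m1   = _mm256_loadu_si256( (__m256i const*)mask_1 );
+	//   __m256d b_lo = _mm256_loadu_pd( b );              // B columns 0..3
+	//   __m256d b_hi = _mm256_maskload_pd( b + 4, m1 );   // B column 4; upper 3 lanes zeroed
+	//   __m256d a_i  = _mm256_broadcast_sd( a + i*rs_a ); // broadcast A(i,l)
+	//   c_lo[i] = _mm256_fmadd_pd( b_lo, a_i, c_lo[i] );  // ymm3,5,7,9,11,13
+	//   c_hi[i] = _mm256_fmadd_pd( b_hi, a_i, c_hi[i] );  // ymm4,6,8,10,12,14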
+ + label(.DLOOPKITER) // MAIN LOOP + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + 
vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + vmulpd(ymm0, ymm7, ymm7) // scale by alpha + vmulpd(ymm0, ymm8, ymm8) + + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+	jz(.DCOLSTORED)                    // jump to column storage case
+
+	label(.DROWSTORED)
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm4)
+	//Stores 4 elements
+	vmovupd(ymm3, mem(rcx, 0*32))
+	//Stores 1 element as per mask_1 mask vector
+	vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------1
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm6)
+
+	vmovupd(ymm5, mem(rcx, 0*32))
+	vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------2
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm8)
+
+	vmovupd(ymm7, mem(rcx, 0*32))
+	vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------3
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm9)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm10)
+
+	vmovupd(ymm9, mem(rcx, 0*32))
+	vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------4
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm11)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm12)
+
+	vmovupd(ymm11, mem(rcx, 0*32))
+	vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------5
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm13)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm14)
+
+	vmovupd(ymm13, mem(rcx, 0*32))
+	vmaskmovpd(ymm14, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------6
+
+	jmp(.DDONE)                        // jump to end.
+
+	label(.DCOLSTORED)
+
+	C_TRANSPOSE_6x5_TILE(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14)
+	jmp(.RESETPARAM)
+
+	label(.DBETAZERO)
+
+	cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+	jz(.DCOLSTORBZ)                    // jump to column storage case
+
+	label(.DROWSTORBZ)
+
+	vmovupd(ymm3, mem(rcx, 0*32))
+	vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------1
+
+	vmovupd(ymm5, mem(rcx, 0*32))
+	vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------2
+
+	vmovupd(ymm7, mem(rcx, 0*32))
+	vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------3
+
+	vmovupd(ymm9, mem(rcx, 0*32))
+	vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------4
+
+	vmovupd(ymm11, mem(rcx, 0*32))
+	vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------5
+
+	vmovupd(ymm13, mem(rcx, 0*32))
+	vmaskmovpd(ymm14, ymm15, mem(rcx, 1*32))
+
+	//-----------------------6
+
+
+	jmp(.DDONE)                        // jump to end.
+
+
+	label(.DCOLSTORBZ)
+
+	C_TRANSPOSE_6x5_TILE_BZ(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14)
+	jmp(.RESETPARAM)
+
+	label(.RESETPARAM)
+	mov(var(mask_vec), rdx)
+	vmovdqu(mem(rdx), ymm15)           //load mask
+	jmp(.DDONE)
+
+	label(.DDONE)
+	lea(mem(r12, rdi, 4), r12)         //
+	lea(mem(r12, rdi, 2), r12)         // c_ii = r12 += 6*rs_c
+
+	mov(var(ps_a8), rax)               // load ps_a8
+	lea(mem(r14, rax, 1), r14)         // a_ii = r14 += ps_a8
+
+	dec(r11)                           // ii -= 1;
+	jne(.DLOOP6X5I)                    // iterate again if ii != 0.
+ + + label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [mask_vec] "m" (mask_vec), + [rs_c] "m" (rs_c), + [n0] "m" (n0), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", + "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", + "memory" + ) + + consider_edge_cases_5: + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x5, + bli_dgemmsup_rv_haswell_asm_2x5, + bli_dgemmsup_rv_haswell_asm_3x5, + bli_dgemmsup_rv_haswell_asm_4x5, + bli_dgemmsup_rv_haswell_asm_5x5 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_5); +} + +static void bli_dgemmsup_rv_haswell_asm_6x3m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// + +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. 
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + int64_t const *mask_vec = mask_3; + + if ( m_iter == 0 ) goto consider_edge_cases_nleft_3; + + begin_asm() + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X3I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 2*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
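+	// Illustrative note (added commentary, not part of the original patch):
+	// with only 3 columns of B, every B-row load in the loops below is a
+	// masked load; rows of A are processed two broadcasts at a time
+	// (hypothetical names):
+	//
+	//   __m256i m3 = _mm256_loadu_si256( (__m256i const*)mask_3 );
+	//   __m256d b0 = _mm256_maskload_pd( b, m3 );  // B columns 0..2; 4th lane zeroed
+	//   c[i]   = _mm256_fmadd_pd( b0, _mm256_broadcast_sd( a + i*rs_a ),     c[i]   );
+	//   c[i+1] = _mm256_fmadd_pd( b0, _mm256_broadcast_sd( a + (i+1)*rs_a ), c[i+1] );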
+ + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm8) + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm10) + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm12) + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm14) + vmaskmovpd(ymm14, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + + C_TRANSPOSE_6x3_TILE(4, 6, 8, 10, 12, 14) + jmp(.RESETPARAM) + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm14, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORBZ) + + C_TRANSPOSE_6x3_TILE_BZ(4, 6, 8, 10, 12, 14) + jmp(.RESETPARAM) + + label(.RESETPARAM) + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + jmp(.DDONE) // jump to end. 
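The .DROWSTORED block above is the masked analogue of the usual store-back: each 3-element row of C is masked-loaded, folded into the alpha-scaled accumulator with an FMA against the broadcast beta, and masked-stored, so the fourth lane of C is never read or written. A sketch of one row, continuing the illustrative naming of the previous note:

#include <immintrin.h>

// Store-back of one 3-element row of C: c(0:3) = beta*c(0:3) + acc,
// where acc already holds alpha * sum_k a(i,k)*b(k,0:3).
static inline void dgemm_store_row3
     (
       double* c, __m256i mask, __m256d vbeta, __m256d acc
     )
{
    __m256d c0 = _mm256_maskload_pd( c, mask );    // reads c[0..2] only
    acc        = _mm256_fmadd_pd( c0, vbeta, acc );
    _mm256_maskstore_pd( c, mask, acc );           // writes c[0..2] only
}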
+ + label(.DDONE) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X3I) // iterate again if ii != 0. + + label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [n0] "m" (n0), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm14", + "memory" + ) + + consider_edge_cases_nleft_3: + if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x3, + bli_dgemmsup_rv_haswell_asm_2x3, + bli_dgemmsup_rv_haswell_asm_3x3, + bli_dgemmsup_rv_haswell_asm_4x3, + bli_dgemmsup_rv_haswell_asm_5x3 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + + void bli_dgemmsup_rv_haswell_asm_6x6m ( conj_t conja, @@ -8912,6 +10962,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } + void bli_dgemmsup_rv_haswell_asm_6x4m ( conj_t conja, @@ -10206,5 +12257,3 @@ void bli_dgemmsup_rv_haswell_asm_6x2m } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } - - diff --git a/kernels/haswell/3/sup/d6x8/CMakeLists.txt b/kernels/haswell/3/sup/d6x8/CMakeLists.txt index c74dff937..24edd62ba 100644 --- a/kernels/haswell/3/sup/d6x8/CMakeLists.txt +++ b/kernels/haswell/3/sup/d6x8/CMakeLists.txt @@ -8,8 +8,11 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx8.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx2.c +${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx4.c +${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx5.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx6.c +${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx7.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx8.c ) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx3.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx3.c new file mode 100644 index 000000000..795ca5772 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx3.c @@ -0,0 +1,2137 @@ +/* + + BLIS + An object-based framework for 
developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + + +#define C_TRANSPOSE_5x3_TILE(R1, R2, R3, R4, R5)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm3, ymm1, ymm(R4))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 4x1 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R5), ymm3, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_5x3_TILE_BZ(R1, R2, R3, R4, R5)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm3, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(1), ymm(R5), xmm1)\ + vmovlpd(xmm(R5), mem(rdx ))\ + vmovhpd(xmm(R5), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + + +#define C_TRANSPOSE_4x3_TILE(R1, R2, R3, R4)\ + 
/*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm(R3))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_4x3_TILE_BZ(R1, R2, R3, R4)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_3x3_TILE(R1, R2, R3)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(10), ymm(R3), ymm2)\ + vunpckhpd(ymm(10), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm(R3))\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vfmadd231sd(mem(rdx ), xmm3, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_3x3_TILE_BZ(R1, R2, R3)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(10), ymm(R3), ymm2)\ + vunpckhpd(ymm(10), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vmovlpd(xmm(12), mem(rdx ))\ + vmovlpd(xmm(13), mem(rdx, rsi, 1))\ + vmovlpd(xmm(14), mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_2x3_TILE(R1, R2)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_2x3_TILE_BZ(R1, R2)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, 
mem(rcx, rsi, 2))
+
+#define C_TRANSPOSE_1x3_TILE(R1)\
+ vmovlpd(mem(rcx ), xmm0, xmm0)\
+ vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\
+ vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\
+ vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\
+\
+ vbroadcastsd(mem(rbx), ymm3)\
+ vfmadd213pd(ymm(R1), ymm3, ymm0)\
+\
+ vextractf128(imm(1), ymm0, xmm1)\
+ vmovlpd(xmm0, mem(rcx ))\
+ vmovhpd(xmm0, mem(rcx, rsi, 1))\
+ vmovlpd(xmm1, mem(rcx, rsi, 2))
+
+#define C_TRANSPOSE_1x3_TILE_BZ(R1)\
+ vextractf128(imm(1), ymm(R1), xmm1)\
+ vmovlpd(xmm(R1), mem(rcx ))\
+ vmovhpd(xmm(R1), mem(rcx, rsi, 1))\
+ vmovlpd(xmm1, mem(rcx, rsi, 2))
+
+static const int64_t mask_3[4] = {-1, -1, -1, 0};
+
+void bli_dgemmsup_rv_haswell_asm_5x3
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ uint64_t k_iter = k0 / 4;
+ uint64_t k_left = k0 % 4;
+
+ uint64_t rs_a = rs_a0;
+ uint64_t cs_a = cs_a0;
+ uint64_t rs_b = rs_b0;
+ uint64_t cs_b = cs_b0;
+ uint64_t rs_c = rs_c0;
+ uint64_t cs_c = cs_c0;
+
+// Sets up the mask for loading the relevant remainder elements in the load direction.
+// An int64_t array of size 4 represents the mask for the 4 elements of an AVX2 vector register.
+//
+// Low end High end
+// ________________________
+// | | | | |
+// | 1 | 2 | 3 | 4 | ----> Source vector
+// |_____|_____|_____|_____|
+//
+// ________________________
+// | | | | |
+// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 )
+// |_____|_____|_____|_____|
+//
+// ________________________
+// | | | | |
+// | 1 | 2 | 3 | 0 | ----> Destination vector
+// |_____|_____|_____|_____|
+//
+// The kernel uses mask_3, which is set to {-1, -1, -1, 0}, so that the first
+// 3 elements are loaded and the 4th element is set to 0 in the destination vector.
+//
+ int64_t const *mask_vec = mask_3;
+ // -------------------------------------------------------------------------
+
+ begin_asm()
+
+ vzeroall() // zero all xmm/ymm registers.
+ mov(var(mask_vec), rdx)
+ vmovdqu(mem(rdx), ymm15) //load mask
+ mov(var(a), rax) // load address of a.
+ mov(var(rs_a), r8) // load rs_a
+ mov(var(cs_a), r9) // load cs_a
+ lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
+ lea(mem(, r9, 8), r9) // cs_a *= sizeof(double)
+ lea(mem(r9, r9, 2), r15) // r15 = 3*cs_a
+ lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
+ //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
+
+ mov(var(b), rbx) // load address of b.
+ mov(var(rs_b), r10) // load rs_b
+ //mov(var(cs_b), r11) // load cs_b
+ lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+ //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ // NOTE: We cannot pre-load elements of a or b
+ // because it could eventually, in the last
+ // unrolled iter or the cleanup loop, result
+ // in reading beyond the bounds of the allocated mem
+ // (the likely result: a segmentation fault).
+
+ mov(var(c), rcx) // load address of c
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+
+
+ cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, r15, 1, 5*8)) // a_prefetch += 3*cs_a; + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + 
vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------1 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------2 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------3 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------4 + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm12) + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_5x3_TILE(4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_5x3_TILE_BZ(4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
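As in the other rv kernels, rdx is parked 16 columns ahead of the current position in A (rdx = a + 16*cs_a) and one cache line is touched per unrolled group, so the upcoming stretch of the panel is resident by the time the k loop reaches it. Roughly, in C (the 16-column distance is this kernel's tuning choice, not a requirement; the helper name is illustrative):

#include <immintrin.h>

// Prefetch pattern of the main k loop (sketch): touch A 16 columns ahead.
static void prefetch_a_ahead( const double* a, long cs_a, long k_iter )
{
    const double* a_pf = a + 16*cs_a;          // mirrors: rdx = a + 16*cs_a
    for ( long i = 0; i < k_iter; ++i )
    {
        _mm_prefetch( (const char*)a_pf + 4*8, _MM_HINT_T0 );  // prefetcht0
        /* ... four rank-1 update steps would go here ... */
        a_pf += 4*cs_a;                        // mirrors: lea(mem(rdx, r9, 4), rdx)
    }
}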
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
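The .DLOOPKITER / .DLOOPKLEFT pair above implements the standard 4-way unroll with remainder that every kernel in this file shares; k_iter = k0/4 and k_left = k0%4 are computed at the top of each function. Stripped of the vector work, the control flow is just (kstep() stands for one rank-1 update like dgemm_6x3_kstep() in the earlier note):

// Control flow of the k dimension (sketch).
for ( uint64_t i = 0; i < k_iter; ++i )    // .DLOOPKITER: k0 / 4 groups
{
    kstep(); kstep(); kstep(); kstep();    // iterations 0..3, unrolled
}
for ( uint64_t i = 0; i < k_left; ++i )    // .DLOOPKLEFT: k0 % 4 leftovers
{
    kstep();
}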
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------1 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------2 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------3 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_4x3_TILE(4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_4x3_TILE_BZ(4, 6, 8, 10) + jmp(.DDONE) // jump to end. 
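The column-stored path cannot reuse the row mask, so C_TRANSPOSE_4x3_TILE first rotates the four row accumulators into column vectors in registers. The unpck/insertf128/perm2f128 sequence is the standard AVX 4x4 double transpose; here is what it computes, written with intrinsics (rows r0..r3 correspond to R1..R4, and only three output columns are formed since n = 3):

#include <immintrin.h>

// 4x4 in-register transpose, keeping columns 0..2 (sketch).
// Input rows: r0 = {a0,a1,a2,a3}, r1 = {b0,...}, r2 = {c0,...}, r3 = {d0,...}.
static inline void transpose_4x3
     (
       __m256d r0, __m256d r1, __m256d r2, __m256d r3, __m256d col[3]
     )
{
    __m256d t0 = _mm256_unpacklo_pd( r0, r1 );   // a0 b0 a2 b2
    __m256d t1 = _mm256_unpackhi_pd( r0, r1 );   // a1 b1 a3 b3
    __m256d t2 = _mm256_unpacklo_pd( r2, r3 );   // c0 d0 c2 d2
    __m256d t3 = _mm256_unpackhi_pd( r2, r3 );   // c1 d1 c3 d3

    col[0] = _mm256_insertf128_pd( t0, _mm256_castpd256_pd128( t2 ), 1 ); // a0 b0 c0 d0
    col[1] = _mm256_insertf128_pd( t1, _mm256_castpd256_pd128( t3 ), 1 ); // a1 b1 c1 d1
    col[2] = _mm256_permute2f128_pd( t0, t2, 0x31 );                      // a2 b2 c2 d2
    // the would-be col[3] (perm2f128 of t1/t3) is simply not formed for n = 3
}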
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 3*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. 
+ // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------1 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------2 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_3x3_TILE(4, 6, 8) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_3x3_TILE_BZ(4, 6, 8) + jmp(.DDONE) // jump to end. 
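The vucomisd/je pair above is why each kernel carries a separate .DBETAZERO variant: when beta is exactly zero, C is overwritten without being read, which both skips the masked loads and avoids folding in possibly uninitialized C values (a NaN or Inf in C would survive multiplication by 0.0). In scalar terms, per element (cij and ab are illustrative names):

// Per-element beta handling (sketch of the .DBETAZERO split).
if ( *beta != 0.0 )
    *cij = (*beta) * (*cij) + (*alpha) * ab;   // .DROWSTORED / .DCOLSTORED
else
    *cij = (*alpha) * ab;                      // .DROWSTORBZ / .DCOLSTORBZ:
                                               // never read *cij; 0.0 * NaN == NaN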
+ + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 2*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 2*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
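This is also the underlying reason for the masked loads (see the NOTE about not pre-loading a or b earlier in this kernel): a plain 32-byte vmovupd of a 3-double row would read one lane past the row, and if the panel ends flush against an unmapped page that read can fault. A sketch of the hazard, with a hypothetical buffer:

#include <immintrin.h>

// Why B rows are loaded with vmaskmovpd instead of vmovupd (sketch).
// Suppose b_last points at the final 3-double row of B, and the byte right
// after b_last[2] lies on an unmapped page:
void load_last_row( const double* b_last )
{
    __m256i mask_3 = _mm256_set_epi64x( 0, -1, -1, -1 );

    __m256d ok = _mm256_maskload_pd( b_last, mask_3 );  // safe: lane 3 untouched
    // __m256d bad = _mm256_loadu_pd( b_last );         // may touch the next
    //                                                  // page and segfault
    (void)ok;
}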
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + add(rdi, rcx) + //-----------------------1 + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_2x3_TILE(4, 6) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORBZ) + + C_TRANSPOSE_2x3_TILE_BZ(4, 6) + jmp(.DDONE) // jump to end. 
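Small fixed-m kernels like the 2x3 one above (and the 1x3 one that follows) exist so that the m-remainder fallback in the 6x3m driver earlier in this patch can finish the last few rows with a single indirect call; the driver indexes a function-pointer table by m_left. The pattern, as used in the driver:

// m-remainder dispatch used by bli_dgemmsup_rv_haswell_asm_6x3m (from this patch).
dgemmsup_ker_ft ker_fps[6] =
{
    NULL,                                // m_left == 0: never dispatched
    bli_dgemmsup_rv_haswell_asm_1x3,
    bli_dgemmsup_rv_haswell_asm_2x3,
    bli_dgemmsup_rv_haswell_asm_3x3,
    bli_dgemmsup_rv_haswell_asm_4x3,
    bli_dgemmsup_rv_haswell_asm_5x3
};

ker_fps[ m_left ]( conja, conjb, m_left, nr_cur, k0,
                   alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
                   beta, cij, rs_c0, cs_c0, data, cntx );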
+ + label(.DDONE) + vzeroupper() + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 1*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 1*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
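Before the scaling and store-back that follow, it is worth stating what this smallest kernel actually computes; with m = 1 the whole micro-tile reduces to one masked accumulator. A plain-C reference of the 1x3 case under the patch's stride conventions (a sketch of the semantics, not the fast path):

// Reference semantics of the 1x3 kernel:
// c(0, 0:3) = beta * c(0, 0:3) + alpha * sum_k a(0, k) * b(k, 0:3).
for ( dim_t j = 0; j < 3; ++j )
{
    double ab = 0.0;
    for ( dim_t k = 0; k < k0; ++k )
        ab += a[ k*cs_a0 ] * b[ k*rs_b0 + j*cs_b0 ];

    if ( *beta != 0.0 ) c[ j*cs_c0 ] = (*beta) * c[ j*cs_c0 ] + (*alpha) * ab;
    else                c[ j*cs_c0 ] = (*alpha) * ab;
}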
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_1x3_TILE(4) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_1x3_TILE_BZ(4) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm15", + "memory" + ) +} diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx5.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx5.c new file mode 100644 index 000000000..ac12db75c --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx5.c @@ -0,0 +1,2519 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +//3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14 +#define C_TRANSPOSE_5x5_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + /*Transposing 4x1 tile*/ \ + vfmadd213pd(ymm(R5), ymm15, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R6))\ + vmovupd(ymm(R6), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R10), ymm15, ymm0)\ + vmovlpd(xmm0, mem(rdx )) + +#define C_TRANSPOSE_5x5_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(1), ymm(R5), xmm1)\ + vmovlpd(xmm(R5), mem(rdx ))\ + vmovhpd(xmm(R5), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), 
rdx)\ +\ + /*Transposing 1x4 tile*/ \ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ +\ + vmovupd(ymm(R6), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vmovlpd(xmm(R10), mem(rdx )) + + +#define C_TRANSPOSE_4x5_TILE(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R5))\ + vmovupd(ymm(R5), mem(rcx )) + +#define C_TRANSPOSE_4x5_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ +\ + vmovupd(ymm(R5), mem(rcx )) + +//3, 5, 7, 4, 6, 8 +#define C_TRANSPOSE_3x5_TILE(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vbroadcastsd(mem(rbx), ymm11)\ +\ + vfmadd231pd(mem(rcx ), xmm11, xmm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm11, xmm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm11, xmm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), xmm11, xmm10)\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vfmadd231sd(mem(rdx ), xmm11, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm11, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm11, xmm14)\ + vfmadd231sd(mem(rdx, rax, 1), xmm11, xmm15)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + 
lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ +\ + vextractf128(imm(0x1), ymm(R4), xmm12)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm(R4))\ + vmovupd(xmm(R4), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vfmadd231sd(mem(rdx ), xmm3, xmm12)\ + vmovsd(xmm12, mem(rdx )) + +#define C_TRANSPOSE_3x5_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + /*Transposing 1x4 tile*/ \ + lea(mem(rcx, rsi, 4), rcx)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ +\ + vextractf128(imm(0x1), ymm(R4), xmm12)\ +\ + vmovupd(xmm(R4), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vmovsd(xmm12, mem(rdx )) + +//3, 5, 4, 6 +#define C_TRANSPOSE_2x5_TILE(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm7)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vmovupd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_2x5_TILE_BZ(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ +\ + vmovupd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_1x5_TILE(R1, R2) \ + /*Transposing 1x4 tile*/ \ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vbroadcastsd(mem(rbx), ymm15)\ + vfmadd213pd(ymm(R1), ymm15, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vperm2f128(imm(0x20), ymm1, 
ymm0, ymm0)\
+\
+    vfmadd213pd(ymm(R2), ymm15, ymm0)\
+\
+    /*Transposing 1x1 tile*/ \
+    vextractf128(imm(1), ymm0, xmm1)\
+    vmovlpd(xmm0, mem(rcx        ))
+
+#define C_TRANSPOSE_1x5_TILE_BZ(R1, R2) \
+    vextractf128(imm(1), ymm(R1), xmm1)\
+    vmovlpd(xmm(R1), mem(rcx        ))\
+    vmovhpd(xmm(R1), mem(rcx, rsi, 1))\
+    vmovlpd(xmm1, mem(rcx, rsi, 2))\
+    vmovhpd(xmm1, mem(rcx, rax, 1))\
+\
+    lea(mem(rcx, rsi, 4), rcx)\
+\
+    vextractf128(imm(1), ymm(R2), xmm1)\
+    vmovlpd(xmm(R2), mem(rcx        ))
+
+static const int64_t mask_1[4] = {-1, 0, 0, 0};
+
+void bli_dgemmsup_rv_haswell_asm_5x5
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    uint64_t k_iter = k0 / 4;
+    uint64_t k_left = k0 % 4;
+
+    uint64_t rs_a   = rs_a0;
+    uint64_t cs_a   = cs_a0;
+    uint64_t rs_b   = rs_b0;
+    uint64_t cs_b   = cs_b0;
+    uint64_t rs_c   = rs_c0;
+    uint64_t cs_c   = cs_c0;
+
+// Sets up the mask used to load the remainder elements in the load direction.
+// An int64_t array of size 4 represents the mask for the 4 elements of an AVX2 vector register.
+//
+//  Low end                   High end
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  2  |  3  |  4  |  ----> Source vector
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  0  |  0  |  0  |  ----> Destination vector
+//  |_____|_____|_____|_____|
+//
+// Since there are 5 elements to load, the kernel uses one normal load,
+// which fills a vector register with 4 elements; for the 1 remaining
+// element it uses mask_1, which is set to -1, 0, 0, 0 so that the
+// 1 element is loaded and the other 3 elements are set to 0 in the destination vector.
+//
+    int64_t const *mask_vec = mask_1;
+    // -------------------------------------------------------------------------
+
+    begin_asm()
+
+    vzeroall()                         // zero all xmm/ymm registers.
+    mov(var(mask_vec), rdx)
+    vmovdqu(mem(rdx), ymm15)           // load mask_1 into ymm15
+    mov(var(a), rax)                   // load address of a.
+    mov(var(rs_a), r8)                 // load rs_a
+    mov(var(cs_a), r9)                 // load cs_a
+    lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+    lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+    lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+    //lea(mem(r8, r8, 4), r15)         // r15 = 5*rs_a
+
+    mov(var(b), rbx)                   // load address of b.
+    mov(var(rs_b), r10)                // load rs_b
+    //mov(var(cs_b), r11)              // load cs_b
+    lea(mem(, r10, 8), r10)            // rs_b *= sizeof(double)
+    //lea(mem(, r11, 8), r11)          // cs_b *= sizeof(double)
+
+    // NOTE: We cannot pre-load elements of a or b
+    // because it could eventually, in the last
+    // unrolled iter or the cleanup loop, result
+    // in reading beyond the bounds allocated mem
+    // (the likely result: a segmentation fault).
+
+    mov(var(c), rcx)                   // load address of c
+    mov(var(rs_c), rdi)                // load rs_c
+    lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 4*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 4*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 4*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + 
vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm11) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm12) + + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_5x5_TILE(3, 5, 7, 9, 11, 4, 6, 8, 10, 12) + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. 
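+
+    // Illustration only (not executed by this kernel): each beta == 0 row
+    // store above maps to one full store plus one masked store in AVX2
+    // intrinsics; the helper name store_row_4p1 is hypothetical.
+    //
+    //   static inline void store_row_4p1( double* c, __m256d lo, __m256d hi,
+    //                                     __m256i m1 )
+    //   {
+    //       _mm256_storeu_pd( c, lo );             // columns 0-3: full store
+    //       _mm256_maskstore_pd( c + 4, m1, hi );  // column 4: lane 0 only
+    //   }
+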
+
+
+
+    label(.DCOLSTORBZ)
+    C_TRANSPOSE_5x5_TILE_BZ(3, 5, 7, 9, 11, 4, 6, 8, 10, 12)
+
+    label(.DDONE)
+    vzeroupper()
+
+
+
+    end_asm(
+    : // output operands (none)
+    : // input operands
+      [k_iter] "m" (k_iter),
+      [k_left] "m" (k_left),
+      [a]      "m" (a),
+      [rs_a]   "m" (rs_a),
+      [cs_a]   "m" (cs_a),
+      [b]      "m" (b),
+      [rs_b]   "m" (rs_b),
+      [cs_b]   "m" (cs_b),
+      [alpha]  "m" (alpha),
+      [beta]   "m" (beta),
+      [c]      "m" (c),
+      [rs_c]   "m" (rs_c),
+      [mask_vec] "m" (mask_vec),
+      [cs_c]   "m" (cs_c)
+    : // register clobber list
+      "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+      "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+      "xmm0", "xmm1", "xmm2", "xmm3",
+      "xmm4", "xmm5", "xmm6", "xmm7",
+      "xmm8", "xmm9", "xmm10", "xmm11",
+      "xmm12", "xmm13", "xmm14", "xmm15",
+      "ymm0", "ymm1", "ymm2", "ymm3", "ymm4",
+      "ymm6", "ymm8", "ymm10", "ymm12",
+      "ymm5", "ymm7", "ymm9", "ymm11", "ymm15",
+      "memory"
+    )
+}
+
+
+void bli_dgemmsup_rv_haswell_asm_4x5
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    uint64_t k_iter = k0 / 4;
+    uint64_t k_left = k0 % 4;
+
+    uint64_t rs_a   = rs_a0;
+    uint64_t cs_a   = cs_a0;
+    uint64_t rs_b   = rs_b0;
+    uint64_t cs_b   = cs_b0;
+    uint64_t rs_c   = rs_c0;
+    uint64_t cs_c   = cs_c0;
+
+// Sets up the mask used to load the remainder elements in the load direction.
+// An int64_t array of size 4 represents the mask for the 4 elements of an AVX2 vector register.
+//
+//  Low end                   High end
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  2  |  3  |  4  |  ----> Source vector
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  0  |  0  |  0  |  ----> Destination vector
+//  |_____|_____|_____|_____|
+//
+// Since there are 5 elements to load, the kernel uses one normal load,
+// which fills a vector register with 4 elements; for the 1 remaining
+// element it uses mask_1, which is set to -1, 0, 0, 0 so that the
+// 1 element is loaded and the other 3 elements are set to 0 in the destination vector.
+//
+    int64_t const *mask_vec = mask_1;
+    // -------------------------------------------------------------------------
+
+    begin_asm()
+
+    vzeroall()                         // zero all xmm/ymm registers.
+    mov(var(mask_vec), rdx)
+    vmovdqu(mem(rdx), ymm15)           // load mask_1 into ymm15
+    mov(var(a), rax)                   // load address of a.
+    mov(var(rs_a), r8)                 // load rs_a
+    mov(var(cs_a), r9)                 // load cs_a
+    lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+    lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+    lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+    //lea(mem(r8, r8, 4), r15)         // r15 = 5*rs_a
+
+    mov(var(b), rbx)                   // load address of b.
+    mov(var(rs_b), r10)                // load rs_b
+    //mov(var(cs_b), r11)              // load cs_b
+    lea(mem(, r10, 8), r10)            // rs_b *= sizeof(double)
+    //lea(mem(, r11, 8), r11)          // cs_b *= sizeof(double)
+
+    // NOTE: We cannot pre-load elements of a or b
+    // because it could eventually, in the last
+    // unrolled iter or the cleanup loop, result
+    // in reading beyond the bounds allocated mem
+    // (the likely result: a segmentation fault).
+
+    mov(var(c), rcx)                   // load address of c
+    mov(var(rs_c), rdi)                // load rs_c
+    lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 4*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 3*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + 
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+    jz(.DCOLSTORED)                    // jump to column storage case
+
+
+
+    label(.DROWSTORED)
+
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm4)
+
+    vmovupd(ymm3, mem(rcx, 0*32))
+    vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------1
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm6)
+
+    vmovupd(ymm5, mem(rcx, 0*32))
+    vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------2
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm8)
+
+    vmovupd(ymm7, mem(rcx, 0*32))
+    vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------3
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm9)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm10)
+
+    vmovupd(ymm9, mem(rcx, 0*32))
+    vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+    //-----------------------4
+
+    jmp(.DDONE)                        // jump to end.
+
+
+
+    label(.DCOLSTORED)
+    C_TRANSPOSE_4x5_TILE(3, 5, 7, 9, 4, 6, 8, 10)
+    jmp(.DDONE)                        // jump to end.
+
+
+
+
+    label(.DBETAZERO)
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+    jz(.DCOLSTORBZ)                    // jump to column storage case
+
+
+
+    label(.DROWSTORBZ)
+
+    vmovupd(ymm3, mem(rcx, 0*32))
+    vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------1
+
+    vmovupd(ymm5, mem(rcx, 0*32))
+    vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------2
+
+    vmovupd(ymm7, mem(rcx, 0*32))
+    vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------3
+
+    vmovupd(ymm9, mem(rcx, 0*32))
+    vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+    //-----------------------4
+
+
+    jmp(.DDONE)                        // jump to end.
+
+
+
+    label(.DCOLSTORBZ)
+    C_TRANSPOSE_4x5_TILE_BZ(3, 5, 7, 9, 4, 6, 8, 10)
+
+    label(.DDONE)
+    vzeroupper()
+
+
+
+    end_asm(
+    : // output operands (none)
+    : // input operands
+      [k_iter] "m" (k_iter),
+      [k_left] "m" (k_left),
+      [a]      "m" (a),
+      [rs_a]   "m" (rs_a),
+      [cs_a]   "m" (cs_a),
+      [b]      "m" (b),
+      [rs_b]   "m" (rs_b),
+      [cs_b]   "m" (cs_b),
+      [alpha]  "m" (alpha),
+      [beta]   "m" (beta),
+      [c]      "m" (c),
+      [rs_c]   "m" (rs_c),
+      [mask_vec] "m" (mask_vec),
+      [cs_c]   "m" (cs_c)
+    : // register clobber list
+      "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+      "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+      "xmm0", "xmm1", "xmm2", "xmm3",
+      "xmm4", "xmm5", "xmm6", "xmm7",
+      "xmm8", "xmm9", "xmm10", "xmm11",
+      "xmm12", "xmm13", "xmm14", "xmm15",
+      "ymm0", "ymm1", "ymm2", "ymm3", "ymm4",
+      "ymm6", "ymm8", "ymm10", "ymm12",
+      "ymm5", "ymm7", "ymm9", "ymm15",
+      "memory"
+    )
+}
+
+void bli_dgemmsup_rv_haswell_asm_3x5
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    uint64_t k_iter = k0 / 4;
+    uint64_t k_left = k0 % 4;
+
+    uint64_t rs_a   = rs_a0;
+    uint64_t cs_a   = cs_a0;
+    uint64_t rs_b   = rs_b0;
+    uint64_t cs_b   = cs_b0;
+    uint64_t rs_c   = rs_c0;
+    uint64_t cs_c   = cs_c0;
+
+// Sets up the mask used to load the remainder elements in the load direction.
+// An int64_t array of size 4 represents the mask for the 4 elements of an AVX2 vector register.
+//
+//  Low end                   High end
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  2  |  3  |  4  |  ----> Source vector
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  0  |  0  |  0  |  ----> Destination vector
+//  |_____|_____|_____|_____|
+//
+// Since there are 5 elements to load, the kernel uses one normal load,
+// which fills a vector register with 4 elements; for the 1 remaining
+// element it uses mask_1, which is set to -1, 0, 0, 0 so that the
+// 1 element is loaded and the other 3 elements are set to 0 in the destination vector.
+//
+    int64_t const *mask_vec = mask_1;
+    // -------------------------------------------------------------------------
+
+    begin_asm()
+
+    vzeroall()                         // zero all xmm/ymm registers.
+    mov(var(mask_vec), rdx)
+    vmovdqu(mem(rdx), ymm15)           // load mask_1 into ymm15
+    mov(var(a), rax)                   // load address of a.
+    mov(var(rs_a), r8)                 // load rs_a
+    mov(var(cs_a), r9)                 // load cs_a
+    lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+    lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+    lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+    //lea(mem(r8, r8, 4), r15)         // r15 = 5*rs_a
+
+    mov(var(b), rbx)                   // load address of b.
+    mov(var(rs_b), r10)                // load rs_b
+    //mov(var(cs_b), r11)              // load cs_b
+    lea(mem(, r10, 8), r10)            // rs_b *= sizeof(double)
+    //lea(mem(, r11, 8), r11)          // cs_b *= sizeof(double)
+
+    // NOTE: We cannot pre-load elements of a or b
+    // because it could eventually, in the last
+    // unrolled iter or the cleanup loop, result
+    // in reading beyond the bounds allocated mem
+    // (the likely result: a segmentation fault).
+
+    mov(var(c), rcx)                   // load address of c
+    mov(var(rs_c), rdi)                // load rs_c
+    lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+    jz(.DCOLPFETCH)                    // jump to column storage case
+    label(.DROWPFETCH)                 // row-stored prefetching on c
+
+    lea(mem(rcx, rdi, 2), rdx)         //
+    lea(mem(rdx, rdi, 1), rdx)         // rdx = c + 3*rs_c;
+    prefetch(0, mem(rcx, 4*8))         // prefetch c + 0*rs_c
+    prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c
+    prefetch(0, mem(rcx, rdi, 2, 4*8)) // prefetch c + 2*rs_c
+
+    jmp(.DPOSTPFETCH)                  // jump to end of prefetching c
+    label(.DCOLPFETCH)                 // column-stored prefetching c
+
+    mov(var(cs_c), rsi)                // load cs_c to rsi (temporarily)
+    lea(mem(, rsi, 8), rsi)            // cs_c *= sizeof(double)
+    lea(mem(rsi, rsi, 2), rdx)         // rdx = 3*cs_c;
+    prefetch(0, mem(rcx, 2*8))         // prefetch c + 0*cs_c
+    prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c
+    prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c
+    prefetch(0, mem(rcx, rdx, 1, 2*8)) // prefetch c + 3*cs_c
+    prefetch(0, mem(rcx, rdx, 2, 2*8)) // prefetch c + 4*cs_c
+
+    label(.DPOSTPFETCH)                // done prefetching c
+
+
+#if 1
+    lea(mem(rax, r9, 8), rdx)          //
+    lea(mem(rdx, r9, 8), rdx)          // rdx = a + 16*cs_a;
+#endif
+
+
+
+
+    mov(var(k_iter), rsi)              // i = k_iter;
+    test(rsi, rsi)                     // check i via logical AND.
+    je(.DCONSIDKLEFT)                  // if i == 0, jump to code that
+                                       // contains the k_left loop.
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
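+
+    // Illustration only (not executed by this kernel): the two loops above
+    // compute, in scalar terms (strides in element units, and assuming
+    // cs_b == 1 as this row-major-B path does; acc stands for ymm3..ymm8):
+    //
+    //   for ( uint64_t l = 0; l < k0; ++l )   // k_iter*4 + k_left steps
+    //       for ( int i = 0; i < 3; ++i )     // 3 rows of the micro-tile
+    //           for ( int j = 0; j < 5; ++j ) // 4 columns + 1 masked column
+    //               acc[i][j] += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j ];
+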
+
+
+
+    label(.DPOSTACCUM)
+
+
+
+    mov(var(alpha), rax)               // load address of alpha
+    mov(var(beta), rbx)                // load address of beta
+    vbroadcastsd(mem(rax), ymm0)       // load alpha and duplicate
+    vbroadcastsd(mem(rbx), ymm1)       // load beta and duplicate
+
+    vmulpd(ymm0, ymm3, ymm3)           // scale by alpha
+    vmulpd(ymm0, ymm4, ymm4)           // scale by alpha
+    vmulpd(ymm0, ymm5, ymm5)           // scale by alpha
+    vmulpd(ymm0, ymm6, ymm6)
+    vmulpd(ymm0, ymm7, ymm7)
+    vmulpd(ymm0, ymm8, ymm8)
+
+
+    mov(var(cs_c), rsi)                // load cs_c
+    lea(mem(, rsi, 8), rsi)            // rsi = cs_c * sizeof(double)
+
+    //lea(mem(rcx, rsi, 4), rdx)       // load address of c + 4*cs_c;
+    lea(mem(rcx, rdi, 2), rdx)         // load address of c + 2*rs_c (last row of the 3x5 tile);
+
+    lea(mem(rsi, rsi, 2), rax)         // rax = 3*cs_c;
+
+
+
+    // now avoid loading C if beta == 0
+
+    vxorpd(ymm0, ymm0, ymm0)           // set ymm0 to zero.
+    vucomisd(xmm0, xmm1)               // set ZF if beta == 0.
+    je(.DBETAZERO)                     // if ZF = 1, jump to beta == 0 case
+
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+    jz(.DCOLSTORED)                    // jump to column storage case
+
+
+
+    label(.DROWSTORED)
+
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm4)
+
+    vmovupd(ymm3, mem(rcx, 0*32))
+    vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------1
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm6)
+
+    vmovupd(ymm5, mem(rcx, 0*32))
+    vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------2
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm8)
+
+    vmovupd(ymm7, mem(rcx, 0*32))
+    vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+    jmp(.DDONE)                        // jump to end.
+
+
+
+    label(.DCOLSTORED)
+    C_TRANSPOSE_3x5_TILE(3, 5, 7, 4, 6, 8)
+    jmp(.DDONE)                        // jump to end.
+
+
+
+
+    label(.DBETAZERO)
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+    jz(.DCOLSTORBZ)                    // jump to column storage case
+
+
+
+    label(.DROWSTORBZ)
+
+
+    vmovupd(ymm3, mem(rcx, 0*32))
+    vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+    add(rdi, rcx)
+
+
+    vmovupd(ymm5, mem(rcx, 0*32))
+    vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+    add(rdi, rcx)
+
+
+    vmovupd(ymm7, mem(rcx, 0*32))
+    vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+
+    jmp(.DDONE)                        // jump to end.
+
+
+
+    label(.DCOLSTORBZ)
+    C_TRANSPOSE_3x5_TILE_BZ(3, 5, 7, 4, 6, 8)
+
+    label(.DDONE)
+    vzeroupper()
+
+
+
+    end_asm(
+    : // output operands (none)
+    : // input operands
+      [k_iter] "m" (k_iter),
+      [k_left] "m" (k_left),
+      [a]      "m" (a),
+      [rs_a]   "m" (rs_a),
+      [cs_a]   "m" (cs_a),
+      [b]      "m" (b),
+      [rs_b]   "m" (rs_b),
+      [cs_b]   "m" (cs_b),
+      [alpha]  "m" (alpha),
+      [beta]   "m" (beta),
+      [c]      "m" (c),
+      [rs_c]   "m" (rs_c),
+      [mask_vec] "m" (mask_vec),
+      [cs_c]   "m" (cs_c)
+    : // register clobber list
+      "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+      "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+      "xmm0", "xmm1", "xmm2", "xmm3",
+      "xmm4", "xmm5", "xmm6", "xmm7",
+      "xmm8", "xmm9", "xmm10", "xmm11",
+      "xmm12", "xmm13", "xmm14", "xmm15",
+      "ymm0", "ymm1", "ymm2", "ymm3", "ymm4",
+      "ymm6", "ymm8", "ymm10", "ymm12",
+      "ymm5", "ymm7", "ymm11", "ymm15",
+      "memory"
+    )
+}
+
+void bli_dgemmsup_rv_haswell_asm_2x5
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    uint64_t k_iter = k0 / 4;
+    uint64_t k_left = k0 % 4;
+
+    uint64_t rs_a   = rs_a0;
+    uint64_t cs_a   = cs_a0;
+    uint64_t rs_b   = rs_b0;
+    uint64_t cs_b   = cs_b0;
+    uint64_t rs_c   = rs_c0;
+    uint64_t cs_c   = cs_c0;
+
+// Sets up the mask used to load the remainder elements in the load direction.
+// An int64_t array of size 4 represents the mask for the 4 elements of an AVX2 vector register.
+//
+//  Low end                   High end
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  2  |  3  |  4  |  ----> Source vector
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+//  |_____|_____|_____|_____|
+//
+//   ________________________
+//  |     |     |     |     |
+//  |  1  |  0  |  0  |  0  |  ----> Destination vector
+//  |_____|_____|_____|_____|
+//
+// Since there are 5 elements to load, the kernel uses one normal load,
+// which fills a vector register with 4 elements; for the 1 remaining
+// element it uses mask_1, which is set to -1, 0, 0, 0 so that the
+// 1 element is loaded and the other 3 elements are set to 0 in the destination vector.
+//
+    int64_t const *mask_vec = mask_1;
+    // -------------------------------------------------------------------------
+
+    begin_asm()
+
+    vzeroall()                         // zero all xmm/ymm registers.
+    mov(var(mask_vec), rdx)
+    vmovdqu(mem(rdx), ymm15)           // load mask_1 into ymm15
+    mov(var(a), rax)                   // load address of a.
+    mov(var(rs_a), r8)                 // load rs_a
+    mov(var(cs_a), r9)                 // load cs_a
+    lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+    lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+    lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+    //lea(mem(r8, r8, 4), r15)         // r15 = 5*rs_a
+
+    mov(var(b), rbx)                   // load address of b.
+    mov(var(rs_b), r10)                // load rs_b
+    //mov(var(cs_b), r11)              // load cs_b
+    lea(mem(, r10, 8), r10)            // rs_b *= sizeof(double)
+    //lea(mem(, r11, 8), r11)          // cs_b *= sizeof(double)
+
+    // NOTE: We cannot pre-load elements of a or b
+    // because it could eventually, in the last
+    // unrolled iter or the cleanup loop, result
+    // in reading beyond the bounds allocated mem
+    // (the likely result: a segmentation fault).
+
+    mov(var(c), rcx)                   // load address of c
+    mov(var(rs_c), rdi)                // load rs_c
+    lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 1*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. 
+ // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_2x5_TILE(3, 5, 4, 6) + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
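+
+    // Note: the vxorpd()/vucomisd()/je() sequence earlier in this function
+    // is the scalar test
+    //
+    //   if ( *beta == 0.0 ) goto DBETAZERO;  // store alpha*A*B, never read C
+    //
+    // so this beta == 0 path writes C without the vfmadd231pd read-update.
+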
+
+
+
+	label(.DCOLSTORBZ)
+	C_TRANSPOSE_2x5_TILE_BZ(3, 5, 4, 6)
+	label(.DDONE)
+
+
+
+	end_asm(
+	: // output operands (none)
+	: // input operands
+	[k_iter] "m" (k_iter),
+	[k_left] "m" (k_left),
+	[a] "m" (a),
+	[rs_a] "m" (rs_a),
+	[cs_a] "m" (cs_a),
+	[b] "m" (b),
+	[rs_b] "m" (rs_b),
+	[cs_b] "m" (cs_b),
+	[alpha] "m" (alpha),
+	[beta] "m" (beta),
+	[c] "m" (c),
+	[rs_c] "m" (rs_c),
+	[mask_vec] "m" (mask_vec),
+	[cs_c] "m" (cs_c)
+	: // register clobber list
+	"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
+	"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+	"xmm0", "xmm1", "xmm2", "xmm3",
+	"xmm4", "xmm5", "xmm6", "xmm7",
+	"xmm8", "xmm9", "xmm10", "xmm11",
+	"xmm12", "xmm13", "xmm14", "xmm15",
+	"ymm0", "ymm1", "ymm2", "ymm3", "ymm4",
+	"ymm6", "ymm8", "ymm10", "ymm12",
+	"ymm5", "ymm15",
+	"memory"
+	)
+}
+
+void bli_dgemmsup_rv_haswell_asm_1x5
+     (
+       conj_t              conja,
+       conj_t              conjb,
+       dim_t               m0,
+       dim_t               n0,
+       dim_t               k0,
+       double*    restrict alpha,
+       double*    restrict a, inc_t rs_a0, inc_t cs_a0,
+       double*    restrict b, inc_t rs_b0, inc_t cs_b0,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+	uint64_t k_iter = k0 / 4;
+	uint64_t k_left = k0 % 4;
+
+	uint64_t rs_a   = rs_a0;
+	uint64_t cs_a   = cs_a0;
+	uint64_t rs_b   = rs_b0;
+	uint64_t cs_b   = cs_b0;
+	uint64_t rs_c   = rs_c0;
+	uint64_t cs_c   = cs_c0;
+
+// Sets up the mask for loading relevant remainder elements in load direction
+// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register.
+//
+//  Low end                  High end
+//  ________________________
+// |     |     |     |     |
+// |  1  |  2  |  3  |  4  |  ----> Source vector
+// |_____|_____|_____|_____|
+//
+//  ________________________
+// |     |     |     |     |
+// | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+// |_____|_____|_____|_____|
+//
+//  ________________________
+// |     |     |     |     |
+// |  1  |  0  |  0  |  0  |  ----> Destination vector
+// |_____|_____|_____|_____|
+//
+// Since we have 5 elements to load, the kernel uses one normal load,
+// which loads 4 elements into a vector register, and for the remaining
+// 1 element it uses mask_1, which is set to -1, 0, 0, 0 so that the
+// 1 element is loaded and the other 3 elements of the destination
+// vector are set to 0.
+//
+	int64_t const *mask_vec = mask_1;
+	// -------------------------------------------------------------------------
+
+	begin_asm()
+
+	vzeroall() // zero all xmm/ymm registers.
+	mov(var(mask_vec), rdx)
+	vmovdqu(mem(rdx), ymm15) // load mask_1 into ymm15
+	mov(var(a), rax) // load address of a.
+	mov(var(rs_a), r8) // load rs_a
+	mov(var(cs_a), r9) // load cs_a
+	lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
+	lea(mem(, r9, 8), r9) // cs_a *= sizeof(double)
+
+	lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
+	//lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a
+
+	mov(var(b), rbx) // load address of b.
+	mov(var(rs_b), r10) // load rs_b
+	//mov(var(cs_b), r11) // load cs_b
+	lea(mem(, r10, 8), r10) // rs_b *= sizeof(double)
+	//lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+	// NOTE: We cannot pre-load elements of a or b
+	// because it could eventually, in the last
+	// unrolled iter or the cleanup loop, result
+	// in reading beyond the bounds allocated mem
+	// (the likely result: a segmentation fault).
+
+	mov(var(c), rcx) // load address of c
+	mov(var(rs_c), rdi) // load rs_c
+	lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+
+
+	cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8.
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 0*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_1x5_TILE(3, 4) + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + C_TRANSPOSE_1x5_TILE_BZ(3, 4) + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", + "ymm15", + "memory" + ) +} diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx7.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx7.c new file mode 100644 index 000000000..8c14eba4a --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx7.c @@ -0,0 +1,2602 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +//3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14 +#define C_TRANSPOSE_5x7_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3, R4 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + /*Storing it back to C matrix.*/ \ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + /*Moving to operate on last 1 row of 5 rows.*/ \ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R5), ymm15, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpckhpd(ymm(R7), ymm(R6), ymm1)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vunpckhpd(ymm(R9), ymm(R8), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R7))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R8))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R6))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R7))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R8))\ + vmovupd(ymm(R6), mem(rcx ))\ + vmovupd(ymm(R7), mem(rcx, rsi, 1))\ + vmovupd(ymm(R8), mem(rcx, rsi, 2))\ +\ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, 
rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + /*Transposing 1x3 tile*/ \ + vfmadd213pd(ymm(R10), ymm15, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_5x7_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(1), ymm(R5), xmm1)\ + vmovlpd(xmm(R5), mem(rdx ))\ + vmovhpd(xmm(R5), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpckhpd(ymm(R7), ymm(R6), ymm1)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vunpckhpd(ymm(R9), ymm(R8), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R7))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R8))\ +\ + vmovupd(ymm(R6), mem(rcx ))\ + vmovupd(ymm(R7), mem(rcx, rsi, 1))\ + vmovupd(ymm(R8), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(1), ymm(R10), xmm1)\ + vmovlpd(xmm(R10), mem(rdx ))\ + vmovhpd(xmm(R10), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_4x7_TILE(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vunpckhpd(ymm(R8), ymm(R7), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R6))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R7))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R5))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R6))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R7))\ + vmovupd(ymm(R5), mem(rcx ))\ + vmovupd(ymm(R6), mem(rcx, rsi, 1))\ + vmovupd(ymm(R7), mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_4x7_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, 
ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vunpckhpd(ymm(R8), ymm(R7), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R6))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R7))\ +\ + vmovupd(ymm(R5), mem(rcx ))\ + vmovupd(ymm(R6), mem(rcx, rsi, 1))\ + vmovupd(ymm(R7), mem(rcx, rsi, 2)) + +//3, 5, 7, 4, 6, 8 +#define C_TRANSPOSE_3x7_TILE(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vbroadcastsd(mem(rbx), ymm11)\ +\ + vfmadd231pd(mem(rcx ), xmm11, xmm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm11, xmm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm11, xmm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), xmm11, xmm10)\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vfmadd231sd(mem(rdx ), xmm11, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm11, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm11, xmm14)\ + vfmadd231sd(mem(rdx, rax, 1), xmm11, xmm15)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpckhpd(ymm(R5), ymm(R4), ymm1)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vunpckhpd(ymm11, ymm(R6), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R5))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R6))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(0x1), ymm(R4), xmm12)\ + vextractf128(imm(0x1), ymm(R5), xmm13)\ + vextractf128(imm(0x1), ymm(R6), xmm14)\ +\ + vfmadd231pd(mem(rcx ), xmm11, xmm(R4))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm11, xmm(R5))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm11, xmm(R6))\ + vmovupd(xmm(R4), mem(rcx ))\ + vmovupd(xmm(R5), mem(rcx, rsi, 1))\ + vmovupd(xmm(R6), mem(rcx, rsi, 2))\ +\ + vfmadd231sd(mem(rdx ), xmm11, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm11, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm11, xmm14)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_3x7_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + 
vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpckhpd(ymm(R5), ymm(R4), ymm1)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vunpckhpd(ymm11, ymm(R6), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R5))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R6))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(0x1), ymm(R4), xmm12)\ + vextractf128(imm(0x1), ymm(R5), xmm13)\ + vextractf128(imm(0x1), ymm(R6), xmm14)\ +\ + vmovupd(xmm(R4), mem(rcx ))\ + vmovupd(xmm(R5), mem(rcx, rsi, 1))\ + vmovupd(xmm(R6), mem(rcx, rsi, 2))\ +\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2)) + +//3, 5, 4, 6 +#define C_TRANSPOSE_2x7_TILE(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm7)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ + vunpckhpd(ymm(R4), ymm(R3), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_2x7_TILE_BZ(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ + vunpckhpd(ymm(R4), ymm(R3), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_1x7_TILE(R1, R2) \ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vbroadcastsd(mem(rbx), ymm15)\ + vfmadd213pd(ymm(R1), ymm15, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R2), ymm15, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + 
vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_1x7_TILE_BZ(R1, R2) \ + vextractf128(imm(1), ymm(R1), xmm1)\ + vmovlpd(xmm(R1), mem(rcx ))\ + vmovhpd(xmm(R1), mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ + vextractf128(imm(1), ymm(R2), xmm1)\ + vmovlpd(xmm(R2), mem(rcx ))\ + vmovhpd(xmm(R2), mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2)) + +static const int64_t mask_3[4] = {-1, -1, -1, 0}; + +void bli_dgemmsup_rv_haswell_asm_5x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 6*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 6*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 6*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, 
ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. 
+ je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm11) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm12) + + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_5x7_TILE(3, 5, 7, 9, 11, 4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_5x7_TILE_BZ(3, 5, 7, 9, 11, 4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. 
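+	// All four store paths (row-/column-stored, beta zero or nonzero)
+	// converge at .DDONE below, where vzeroupper clears the upper ymm
+	// state so that later SSE code does not pay AVX-SSE transition
+	// penalties.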
+ + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", + "ymm5", "ymm7", "ymm9", "ymm15", + "memory" + ) +} + + +void bli_dgemmsup_rv_haswell_asm_4x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 6*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 6*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + 
vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + //-----------------------4 + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_4x7_TILE(3, 5, 7, 9, 4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + //-----------------------4 + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_4x7_TILE_BZ(3, 5, 7, 9, 4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", + "ymm5", "ymm7", "ymm11", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. 
+// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 6*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
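+	// The loop below unrolls the k dimension by 4: each .DLOOPKITER pass
+	// applies four rank-1 updates to the 3x7 tile of C. A rough scalar
+	// sketch of one update at k index l (illustrative only; acc is a
+	// hypothetical name for the ymm3..ymm8 accumulators, and a unit
+	// column stride of b is assumed, as in this rv kernel):
+	//
+	//   for ( dim_t i = 0; i < 3; ++i )     // one a element broadcast per row
+	//     for ( dim_t j = 0; j < 7; ++j )   // 4 plain + 3 masked lanes of b
+	//       acc[ i ][ j ] += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j ];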
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
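+	// At this point ymm3..ymm8 hold the unscaled accumulated products:
+	// pairs (ymm3,ymm4), (ymm5,ymm6), (ymm7,ymm8) are rows 0..2 of the
+	// 3x7 tile, the first register of each pair covering columns 0..3
+	// and the second the three masked columns 4..6.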
+
+    label(.DPOSTACCUM)
+
+    mov(var(alpha), rax)               // load address of alpha
+    mov(var(beta), rbx)                // load address of beta
+    vbroadcastsd(mem(rax), ymm0)       // load alpha and duplicate
+    vbroadcastsd(mem(rbx), ymm1)       // load beta and duplicate
+
+    vmulpd(ymm0, ymm3, ymm3)           // scale by alpha
+    vmulpd(ymm0, ymm4, ymm4)           // scale by alpha
+    vmulpd(ymm0, ymm5, ymm5)           // scale by alpha
+    vmulpd(ymm0, ymm6, ymm6)           // scale by alpha
+    vmulpd(ymm0, ymm7, ymm7)           // scale by alpha
+    vmulpd(ymm0, ymm8, ymm8)           // scale by alpha
+
+    mov(var(cs_c), rsi)                // load cs_c
+    lea(mem(, rsi, 8), rsi)            // rsi = cs_c * sizeof(double)
+
+    //lea(mem(rcx, rsi, 4), rdx)       // load address of c + 4*cs_c;
+    lea(mem(rcx, rdi, 2), rdx)         // load address of c + 2*rs_c;
+
+    lea(mem(rsi, rsi, 2), rax)         // rax = 3*cs_c;
+
+    // now avoid loading C if beta == 0
+
+    vxorpd(ymm0, ymm0, ymm0)           // set ymm0 to zero.
+    vucomisd(xmm0, xmm1)               // set ZF if beta == 0.
+    je(.DBETAZERO)                     // if ZF = 1, jump to beta == 0 case
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+    jz(.DCOLSTORED)                    // jump to column storage case
+
+    label(.DROWSTORED)
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm4)
+
+    vmovupd(ymm3, mem(rcx, 0*32))
+    vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------1
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm6)
+
+    vmovupd(ymm5, mem(rcx, 0*32))
+    vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------2
+
+    vfmadd231pd(mem(rcx, 0*32), ymm1, ymm7)
+    vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+    vfmadd231pd(ymm0, ymm1, ymm8)
+
+    vmovupd(ymm7, mem(rcx, 0*32))
+    vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+    jmp(.DDONE)                        // jump to end.
+
+    label(.DCOLSTORED)
+
+    C_TRANSPOSE_3x7_TILE(3, 5, 7, 4, 6, 8)
+    jmp(.DDONE)                        // jump to end.
+
+    label(.DBETAZERO)
+
+    cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+    jz(.DCOLSTORBZ)                    // jump to column storage case
+
+    label(.DROWSTORBZ)
+
+    vmovupd(ymm3, mem(rcx, 0*32))
+    vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------1
+    vmovupd(ymm5, mem(rcx, 0*32))
+    vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+    add(rdi, rcx)
+    //-----------------------2
+    vmovupd(ymm7, mem(rcx, 0*32))
+    vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+    jmp(.DDONE)                        // jump to end.
+
+    label(.DCOLSTORBZ)
+
+    C_TRANSPOSE_3x7_TILE_BZ(3, 5, 7, 4, 6, 8)
+    jmp(.DDONE)                        // jump to end.
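Two dispatch decisions drive everything after .DPOSTACCUM: whether beta is zero (tested with vucomisd, so C is never read when it may be uninitialized) and whether C is row- or column-stored (rs_c * sizeof(double) == 8 means unit row stride, i.e. column storage). A scalar sketch of the row-stored policy, where ab[] is a hypothetical stand-in for the already alpha-scaled accumulators:

// Per 7-wide row of C, done above with one plain and one masked vector op.
if ( beta != 0.0 )
    for ( int j = 0; j < 7; ++j ) c[ j ] = beta * c[ j ] + ab[ j ];  // .DROWSTORED
else
    for ( int j = 0; j < 7; ++j ) c[ j ] = ab[ j ];                  // .DROWSTORBZ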
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", + "ymm5", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
+ + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_2x7_TILE(3, 5, 4, 6) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_2x7_TILE_BZ(3, 5, 4, 6) + jmp(.DDONE) // jump to end. 
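The C_TRANSPOSE_* macros invoked by the column-stored paths are built from the unpack/insert/permute idiom sketched below in intrinsic form. Here r0..r3 hold rows of the register tile and c0..c3 come out as columns of the transposed 4x4 double tile (illustrative only; the macros themselves are defined earlier in this file):

__m256d t0 = _mm256_unpacklo_pd( r0, r1 );  // r0[0] r1[0] r0[2] r1[2]
__m256d t1 = _mm256_unpackhi_pd( r0, r1 );  // r0[1] r1[1] r0[3] r1[3]
__m256d t2 = _mm256_unpacklo_pd( r2, r3 );
__m256d t3 = _mm256_unpackhi_pd( r2, r3 );
__m256d c0 = _mm256_insertf128_pd( t0, _mm256_castpd256_pd128( t2 ), 1 );
__m256d c1 = _mm256_insertf128_pd( t1, _mm256_castpd256_pd128( t3 ), 1 );
__m256d c2 = _mm256_permute2f128_pd( t0, t2, 0x31 );
__m256d c3 = _mm256_permute2f128_pd( t1, t3, 0x31 );

The smaller tiles (2x4, 4x3, 2x3) drop or extract lanes from this same pattern.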
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", + "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_1x7_TILE(3, 4) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_1x7_TILE_BZ(3, 4) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", + "ymm15", + "memory" + ) +} diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index d841d715f..8c4e3c44e 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -219,6 +219,12 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x8 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x8 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x7 ) + GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x6 ) @@ -226,6 +232,12 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x5 ) + GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x4 ) @@ -233,6 +245,12 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x3 ) + GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x2 ) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 477c71047..1d1c5105f 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -2392,939 +2392,1119 @@ err_t bli_dgemm_small } m_remainder = M - row_idx; - - if (m_remainder >= 12) + if(m_remainder) { - m_remainder -= 12; + // Sets up the mask for loading relevant remainder elements in load direction + // int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. + // + // Low end High end * Low end High end + // ________________________ * ________________________ + // | | | | | * | | | | | + // | 1 | 2 | 3 | 4 | ----> Source vector * | 1 | 2 | 3 | 4 | ----> Source vector + // |_____|_____|_____|_____| * |_____|_____|_____|_____| + // * + // ________________________ * ________________________ + // | | | | | * | | | | | + // | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) | -1 | -1 | 0 | 0 | ----> Mask vector( mask_2 ) + // |_____|_____|_____|_____| * |_____|_____|_____|_____| + // * + // ________________________ * ________________________ + // | | | | | * | | | | | + // | 1 | 2 | 3 | 0 | ----> Destination vector * | 1 | 2 | 0 | 0 | ----> Destination vector + // |_____|_____|_____|_____| * |_____|_____|_____|_____| + // + // -1 sets all the bits to 1. 
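The constant masks defined below all follow one rule: lane i is enabled (-1) exactly when i < n for an n-element tail. A generic builder is shown here only to make that rule explicit (hypothetical helper; the patch deliberately writes the four patterns out as plain arrays):

static __m256i tail_mask( int n )  // n = number of valid tail elements, 0 < n <= 4
{
    int64_t m[4];
    for ( int i = 0; i < 4; ++i )
        m[ i ] = ( i < n ) ? -1 : 0;  // -1 enables the lane, 0 suppresses it
    return _mm256_loadu_si256( (__m256i *)m );
}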
+ // + dim_t m_rem = 0; + int64_t mask_4[4] = {0}; + mask_4[0] = -1; + mask_4[1] = -1; + mask_4[2] = -1; + mask_4[3] = -1; - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + int64_t mask_3[4] = {0}; + mask_3[0] = -1; + mask_3[1] = -1; + mask_3[2] = -1; + mask_3[3] = 0; + + int64_t mask_2[4] = {0}; + mask_2[0] = -1; + mask_2[1] = -1; + mask_2[2] = 0; + mask_2[3] = 0; + + int64_t mask_1[4] = {0}; + mask_1[0] = -1; + mask_1[1] = 0; + mask_1[2] = 0; + mask_1[3] = 0; + + int64_t *mask_ptr[] = {mask_4, mask_1, mask_2, mask_3, mask_4}; + if(m_remainder > 12) { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - - // multiply C by beta and accumulate. - double *ttC = tC +ldc; - ymm2 = _mm256_loadu_pd(ttC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(ttC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - - // multiply C by beta and accumulate. 
- ttC += ldc; - ymm2 = _mm256_loadu_pd(ttC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(ttC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - } - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - _mm256_storeu_pd(tC + 8, ymm6); - - tC += ldc; - - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - - tC += ldc; - - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); - ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); - ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - - double *ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(ttC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - } - _mm256_storeu_pd(tC + 0, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - - tC += ldc; - - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - - col_idx += 2; - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
- ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - } - _mm256_storeu_pd(tC + 0, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - } - - row_idx += 12; - } - - if (m_remainder >= 8) - { - m_remainder -= 8; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); - ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - double* ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); - - ttC += ldc; - - // multiply C by beta and accumulate. 
- ymm2 = _mm256_loadu_pd(ttC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - } - - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - tC += ldc; - _mm256_storeu_pd(tC, ymm6); - _mm256_storeu_pd(tC + 4, ymm7); - - tC += ldc; - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - double* ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); - } - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - tC += ldc; - _mm256_storeu_pd(tC, ymm6); - _mm256_storeu_pd(tC + 4, ymm7); - - col_idx += 2; - - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. 
- ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - } - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - } - - row_idx += 8; - } - - if (m_remainder >= 4) - { - m_remainder -= 4; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - - double* ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - ttC += ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - } - _mm256_storeu_pd(tC, ymm4); - - tC += ldc; - _mm256_storeu_pd(tC, ymm5); - - tC += ldc; - _mm256_storeu_pd(tC, ymm6); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - - double* ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - } - _mm256_storeu_pd(tC, ymm4); - - tC += ldc; - _mm256_storeu_pd(tC, ymm5); - - col_idx += 2; - - } - // if the N is not multiple of 3. - // handling edge case. 
- if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - ymm4 = _mm256_mul_pd(ymm4, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - - } - _mm256_storeu_pd(tC, ymm4); - - } - - row_idx += 4; - } - // M is not a multiple of 32. - // The handling of edge case where the remainder - // dimension is less than 8. The padding takes place - // to handle this case. - if ((m_remainder) && (lda > 3)) - { - double f_temp[8] = {0.0}; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. 
- ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - - if(is_beta_non_zero) - { - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - - double* ttC = tC + ldc; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = ttC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); - - ttC += ldc; - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = ttC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - } - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - _mm256_storeu_pd(f_temp, ymm9); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tA += lda; - } - - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - - if(is_beta_non_zero) - { - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - double* ttC = tC + ldc; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = ttC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); - - } - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
- ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - tA += lda; - } - - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - // multiply C by beta and accumulate. - ymm5 = _mm256_mul_pd(ymm5, ymm0); - - if(is_beta_non_zero) - { - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - } - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - m_remainder = 0; - } - - if (m_remainder) - { - double result; - for (; row_idx < M; row_idx += 1) - { - for (col_idx = 0; col_idx < N; col_idx += 1) + // Handles edge cases where remainder elements are between 12-16(13, 14, 15). + // Here m_rem gives index in mask_ptr that points which mask to be used based + // on remainder elements which could be 1, 2, or 3 here. + m_rem = (m_remainder % 12); + __m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[m_rem]); + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { //pointer math to point to proper memory tC = C + ldc * col_idx + row_idx; tB = B + tb_inc_col * col_idx; tA = A + row_idx; - result = 0; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + for (k = 0; k < K; ++k) { - result += (*tA) * (*tB); - tA += lda; + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); + + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 12, maskVec); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); + + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. 
+ ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + // Masked load the relevant remaider elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 12, maskVec); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + // multiply C by beta and accumulate, col 2. + double* ttC = tC + ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + // Masked load the relevant remaider elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 12, maskVec); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 3. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + // Masked load the relevant remaider elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 12, maskVec); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + // Masked store the relevant remaider elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm7); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + // Masked store the relevant remaider elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm11); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + // Masked store the relevant remaider elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm15); + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 12, maskVec); + ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate, col 1. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + // Masked load the relevant remaider elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 12, maskVec); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 2. + double *ttC = tC + ldc; + + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + // Masked load the relevant remaider elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 12, maskVec); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); } - result *= (*alpha_cast); + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + // Masked store the relevant remaider elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm11); + + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + // Masked store the relevant remaider elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm15); + col_idx += 2; + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); + + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + + ymm3 = _mm256_loadu_pd(tA + 8); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + // Masked load the relevant remainder elements only + // using maskVec. 
+                    ymm3 = _mm256_maskload_pd(tA + 12, maskVec);
+                    ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+                ymm14 = _mm256_mul_pd(ymm14, ymm0);
+                ymm15 = _mm256_mul_pd(ymm15, ymm0);
+                if(is_beta_non_zero)
-                    (*tC) = (*tC) * (*beta_cast) + result;
-                else
-                    (*tC) = result;
+                {
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(tC + 0);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+                    ymm2 = _mm256_loadu_pd(tC + 4);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                    ymm2 = _mm256_loadu_pd(tC + 8);
+                    ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 12, maskVec);
+                    ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
+                }
+
+                _mm256_storeu_pd(tC + 0, ymm12);
+                _mm256_storeu_pd(tC + 4, ymm13);
+                _mm256_storeu_pd(tC + 8, ymm14);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 12, maskVec, ymm15);
+            }
+        }
+        else if(m_remainder > 8)
+        {
+            // Handles edge cases where the remainder elements are
+            // between 9 and 12 (9, 10, 11, 12).
+            // Here m_rem gives the index into mask_ptr that selects which mask
+            // to use, based on the remainder elements, which can be 1, 2, 3 or 4 here.
+            m_rem = (m_remainder % 8);
+            __m256i maskVec = _mm256_loadu_si256((__m256i *)mask_ptr[m_rem]);
+
+            for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm4 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+                ymm6 = _mm256_setzero_pd();
+                ymm8 = _mm256_setzero_pd();
+                ymm9 = _mm256_setzero_pd();
+                ymm10 = _mm256_setzero_pd();
+                ymm12 = _mm256_setzero_pd();
+                ymm13 = _mm256_setzero_pd();
+                ymm14 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
+                    ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+                    ymm3 = _mm256_loadu_pd(tA);
+                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+                    ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
+
+                    ymm3 = _mm256_loadu_pd(tA + 4);
+                    ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
+                    ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA + 8, maskVec);
+                    ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
+                    ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
+                    ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
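+                // The masked-off lanes of the accumulators already hold zeros
+                // (maskload zeroed them every iteration), so scaling every
+                // accumulator by alpha is safe.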
+                ymm4 = _mm256_mul_pd(ymm4, ymm0);
+                ymm5 = _mm256_mul_pd(ymm5, ymm0);
+                ymm6 = _mm256_mul_pd(ymm6, ymm0);
+                ymm8 = _mm256_mul_pd(ymm8, ymm0);
+                ymm9 = _mm256_mul_pd(ymm9, ymm0);
+                ymm10 = _mm256_mul_pd(ymm10, ymm0);
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+                ymm14 = _mm256_mul_pd(ymm14, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(tC);
+                    ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
+
+                    ymm2 = _mm256_loadu_pd(tC + 4);
+                    ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 8, maskVec);
+                    ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
+
+                    // multiply C by beta and accumulate.
+                    double *ttC = tC + ldc;
+                    ymm2 = _mm256_loadu_pd(ttC);
+                    ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
+
+                    ymm2 = _mm256_loadu_pd(ttC + 4);
+                    ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC + 8, maskVec);
+                    ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
+
+                    // multiply C by beta and accumulate.
+                    ttC += ldc;
+                    ymm2 = _mm256_loadu_pd(ttC);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+                    ymm2 = _mm256_loadu_pd(ttC + 4);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC + 8, maskVec);
+                    ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+                }
+                _mm256_storeu_pd(tC, ymm4);
+                _mm256_storeu_pd(tC + 4, ymm5);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 8, maskVec, ymm6);
+
+                tC += ldc;
+
+                _mm256_storeu_pd(tC, ymm8);
+                _mm256_storeu_pd(tC + 4, ymm9);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 8, maskVec, ymm10);
+
+                tC += ldc;
+
+                _mm256_storeu_pd(tC, ymm12);
+                _mm256_storeu_pd(tC + 4, ymm13);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 8, maskVec, ymm14);
+            }
+            n_remainder = N - col_idx;
+            // If N is not a multiple of 3,
+            // handle the edge case.
+            if (n_remainder == 2)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm8 = _mm256_setzero_pd();
+                ymm9 = _mm256_setzero_pd();
+                ymm10 = _mm256_setzero_pd();
+                ymm12 = _mm256_setzero_pd();
+                ymm13 = _mm256_setzero_pd();
+                ymm14 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+                    ymm3 = _mm256_loadu_pd(tA);
+                    ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
+                    ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
+
+                    ymm3 = _mm256_loadu_pd(tA + 4);
+                    ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
+                    ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA + 8, maskVec);
+                    ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
+                    ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm8 = _mm256_mul_pd(ymm8, ymm0);
+                ymm9 = _mm256_mul_pd(ymm9, ymm0);
+                ymm10 = _mm256_mul_pd(ymm10, ymm0);
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+                ymm14 = _mm256_mul_pd(ymm14, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(tC + 0);
+                    ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
+
+                    ymm2 = _mm256_loadu_pd(tC + 4);
+                    ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 8, maskVec);
+                    ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
+
+                    double *ttC = tC + ldc;
+
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(ttC);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+                    ymm2 = _mm256_loadu_pd(ttC + 4);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC + 8, maskVec);
+                    ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+                }
+                _mm256_storeu_pd(tC + 0, ymm8);
+                _mm256_storeu_pd(tC + 4, ymm9);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 8, maskVec, ymm10);
+
+                tC += ldc;
+
+                _mm256_storeu_pd(tC, ymm12);
+                _mm256_storeu_pd(tC + 4, ymm13);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 8, maskVec, ymm14);
+
+                col_idx += 2;
+            }
+            // If N is not a multiple of 3,
+            // handle the edge case.
+            if (n_remainder == 1)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm12 = _mm256_setzero_pd();
+                ymm13 = _mm256_setzero_pd();
+                ymm14 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+                    ymm3 = _mm256_loadu_pd(tA);
+                    ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
+
+                    ymm3 = _mm256_loadu_pd(tA + 4);
+                    ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA + 8, maskVec);
+                    ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+                ymm14 = _mm256_mul_pd(ymm14, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(tC + 0);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+
+                    ymm2 = _mm256_loadu_pd(tC + 4);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 8, maskVec);
+                    ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
+                }
+                _mm256_storeu_pd(tC + 0, ymm12);
+                _mm256_storeu_pd(tC + 4, ymm13);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 8, maskVec, ymm14);
+            }
+        }
+        else if(m_remainder > 4)
+        {
+            // Handles edge cases where the remainder elements are
+            // between 5 and 8 (5, 6, 7, 8).
+            // Here m_rem gives the index into mask_ptr that selects which mask
+            // to use, based on the remainder elements, which can be 0, 1, 2 or 3 here.
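+            // NOTE: when m_remainder == 8, m_rem is 0; this assumes that
+            // mask_ptr[0] holds the all-ones (full) mask so that all four
+            // lanes of the last vector group are kept.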
+            m_rem = (m_remainder % 4);
+            __m256i maskVec = _mm256_loadu_si256((__m256i *)mask_ptr[m_rem]);
+            for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm4 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+                ymm6 = _mm256_setzero_pd();
+                ymm7 = _mm256_setzero_pd();
+                ymm8 = _mm256_setzero_pd();
+                ymm9 = _mm256_setzero_pd();
+                ymm10 = _mm256_setzero_pd();
+                ymm11 = _mm256_setzero_pd();
+                ymm12 = _mm256_setzero_pd();
+                ymm13 = _mm256_setzero_pd();
+                ymm14 = _mm256_setzero_pd();
+                ymm15 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
+                    ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+                    ymm3 = _mm256_loadu_pd(tA);
+                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+                    ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA + 4, maskVec);
+                    ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
+                    ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm4 = _mm256_mul_pd(ymm4, ymm0);
+                ymm5 = _mm256_mul_pd(ymm5, ymm0);
+                ymm6 = _mm256_mul_pd(ymm6, ymm0);
+                ymm8 = _mm256_mul_pd(ymm8, ymm0);
+                ymm9 = _mm256_mul_pd(ymm9, ymm0);
+                ymm10 = _mm256_mul_pd(ymm10, ymm0);
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+                ymm14 = _mm256_mul_pd(ymm14, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(tC);
+                    ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 4, maskVec);
+                    ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
+
+                    // multiply C by beta and accumulate.
+                    double *ttC = tC + ldc;
+                    ymm2 = _mm256_loadu_pd(ttC);
+                    ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC + 4, maskVec);
+                    ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
+
+                    // multiply C by beta and accumulate.
+                    ttC += ldc;
+                    ymm2 = _mm256_loadu_pd(ttC);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC + 4, maskVec);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                }
+                _mm256_storeu_pd(tC, ymm4);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 4, maskVec, ymm5);
+
+                tC += ldc;
+
+                _mm256_storeu_pd(tC, ymm8);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 4, maskVec, ymm9);
+
+                tC += ldc;
+
+                _mm256_storeu_pd(tC, ymm12);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 4, maskVec, ymm13);
+            }
+            n_remainder = N - col_idx;
+            // If N is not a multiple of 3,
+            // handle the edge case.
+            if (n_remainder == 2)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm8 = _mm256_setzero_pd();
+                ymm9 = _mm256_setzero_pd();
+                ymm10 = _mm256_setzero_pd();
+                ymm12 = _mm256_setzero_pd();
+                ymm13 = _mm256_setzero_pd();
+                ymm14 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+                    ymm3 = _mm256_loadu_pd(tA);
+                    ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
+                    ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA + 4, maskVec);
+                    ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
+                    ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm8 = _mm256_mul_pd(ymm8, ymm0);
+                ymm9 = _mm256_mul_pd(ymm9, ymm0);
+                ymm10 = _mm256_mul_pd(ymm10, ymm0);
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(tC + 0);
+                    ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 4, maskVec);
+                    ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
+
+                    double *ttC = tC + ldc;
+
+                    // multiply C by beta and accumulate.
+                    ymm2 = _mm256_loadu_pd(ttC);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC + 4, maskVec);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                }
+                _mm256_storeu_pd(tC + 0, ymm8);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 4, maskVec, ymm9);
+
+                tC += ldc;
+
+                _mm256_storeu_pd(tC, ymm12);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 4, maskVec, ymm13);
+
+                col_idx += 2;
+            }
+            // If N is not a multiple of 3,
+            // handle the edge case.
+            if (n_remainder == 1)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm12 = _mm256_setzero_pd();
+                ymm13 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+                    ymm3 = _mm256_loadu_pd(tA);
+                    ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA + 4, maskVec);
+                    ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm12 = _mm256_mul_pd(ymm12, ymm0);
+                ymm13 = _mm256_mul_pd(ymm13, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // multiply C by beta and accumulate.
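+                    // Each vfmadd below computes C*beta + alpha*A*B in one
+                    // instruction: ymm2 holds the loaded C values, ymm1 holds beta.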
+                    ymm2 = _mm256_loadu_pd(tC + 0);
+                    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC + 4, maskVec);
+                    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
+                }
+                _mm256_storeu_pd(tC + 0, ymm12);
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC + 4, maskVec, ymm13);
+            }
+        }
+        else
+        {
+            // Handles the remaining edge cases, where the remainder elements
+            // are between 1 and 4.
+            __m256i maskVec = _mm256_loadu_si256((__m256i *)mask_ptr[m_remainder]);
+            for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                // Clear the scratch registers.
+                ymm4 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+                ymm6 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
+                    ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA, maskVec);
+                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+                    ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
+                    ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm4 = _mm256_mul_pd(ymm4, ymm0);
+                ymm5 = _mm256_mul_pd(ymm5, ymm0);
+                ymm6 = _mm256_mul_pd(ymm6, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC, maskVec);
+                    ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
+
+                    double* ttC = tC + ldc;
+
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC, maskVec);
+                    ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
+
+                    ttC += ldc;
+
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC, maskVec);
+                    ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
+                }
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC, maskVec, ymm4);
+
+                tC += ldc;
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC, maskVec, ymm5);
+
+                tC += ldc;
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC, maskVec, ymm6);
+            }
+            n_remainder = N - col_idx;
+            // If N is not a multiple of 3,
+            // handle the edge case.
+            if (n_remainder == 2)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                ymm4 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA, maskVec);
+                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+                    ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
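+                // Broadcast the scalar alpha and beta across all four lanes
+                // for the vector scaling below.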
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                // Multiply A*B by alpha.
+                ymm4 = _mm256_mul_pd(ymm4, ymm0);
+                ymm5 = _mm256_mul_pd(ymm5, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC, maskVec);
+                    ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
+
+                    double* ttC = tC + ldc;
+
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(ttC, maskVec);
+                    ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
+                }
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC, maskVec, ymm4);
+
+                tC += ldc;
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC, maskVec, ymm5);
+
+                col_idx += 2;
+            }
+            // If N is not a multiple of 3,
+            // handle the edge case.
+            if (n_remainder == 1)
+            {
+                // Pointer math to point to the proper memory.
+                tC = C + ldc * col_idx + row_idx;
+                tB = B + tb_inc_col * col_idx;
+                tA = A + row_idx;
+
+                ymm4 = _mm256_setzero_pd();
+
+                for (k = 0; k < K; ++k)
+                {
+                    // The inner loop broadcasts the B matrix data and
+                    // multiplies it with the A matrix.
+                    ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
+                    tB += tb_inc_row;
+
+                    // The broadcast B elements are multiplied
+                    // with the matrix A columns.
+
+                    // Masked load the relevant remainder elements only,
+                    // using maskVec.
+                    ymm3 = _mm256_maskload_pd(tA, maskVec);
+                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+
+                    tA += lda;
+                }
+                // alpha, beta multiplication.
+                ymm0 = _mm256_broadcast_sd(alpha_cast);
+                ymm1 = _mm256_broadcast_sd(beta_cast);
+
+                ymm4 = _mm256_mul_pd(ymm4, ymm0);
+
+                if(is_beta_non_zero)
+                {
+                    // Masked load the relevant remainder elements of the C matrix.
+                    // Scale by beta.
+                    ymm2 = _mm256_maskload_pd(tC, maskVec);
+                    ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
+                }
+                // Masked store the relevant remainder elements of the C matrix.
+                _mm256_maskstore_pd(tC, maskVec, ymm4);
             }
         }
     }
-        // Return the buffer to pool
+    // Return the buffer to pool
     if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s ))
     {
 #ifdef BLIS_ENABLE_MEM_TRACING
         printf( "bli_dgemm_small(): releasing mem pool block\n" );