From 8adef27aca822f1867f3aeaa2bcb6cc0a11f27db Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Tue, 16 Aug 2022 12:27:46 +0530 Subject: [PATCH] Optimization of DGEMMT SUP kernel for beta zero cases. Details: 1. In kernels for non-transpose variants, changes are made to optimize the cases of beta zero. 2. Validated the changes with BLIS Testsuite, GTestSuite(Functionality, Valgrind, Integer Tests) and Netlib Tests. 3. Fixed warnings during the build process. AMD-Internal: [CPUPL-2341] Change-Id: I8bb53ad619eb2413c999fe18eafd67c75fe1f83a --- frame/3/gemmt/bli_gemmt_sup_var1n2m.c | 288 ++----------- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 395 +++++++++++++----- kernels/haswell/bli_kernels_haswell.h | 16 + 3 files changed, 331 insertions(+), 368 deletions(-) diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index 2af7e9f45..1023821b8 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -66,249 +66,11 @@ typedef void (*gemmt_ker_ft) dim_t m0, dim_t n0, dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=0 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=6 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* 
restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=12 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant combined kernel for m_offset=12, n_offset=16 and m_offset=18, n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=6 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=12 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t 
cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=18 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=0 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=6 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=12 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant combined kernel for m_offset=0, n_offset=0 and m_offset=6, n_offset=0 in 24x24 block -BLIS_INLINE void 
bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=6 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=12 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=18 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ); @@ -2177,11 +1939,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, 
\ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2203,11 +1965,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2849,11 +2611,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2874,11 +2636,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 107917d07..eb734fe0d 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -1127,11 +1127,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - 
double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -1541,7 +1541,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L label(.DDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -1604,11 +1604,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -1861,7 +1861,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L label(.DDONE) - + vzeroupper() end_asm( @@ -1925,11 +1925,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -2032,19 +2032,21 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L mov(var(cs_c), rsi) lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) - cmp(imm(8), rdi) //rs_c == 0? + cmp(imm(8), rdi) //rs_c == 8? 
je(.DCOLSTOR) label(.DROWSTOR) - lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) vmovlpd(xmm12, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) vmovlpd(xmm14, mem(rcx)) vmovhpd(xmm14, mem(rcx, rsi, 1)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) @@ -2058,8 +2060,35 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm14) vmovupd(xmm12, mem(rdx )) vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 8? + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm14, mem(rcx)) + vmovhpd(xmm14, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rdi, 4), rdx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm12) + vinsertf128(imm(0x1), xmm3, ymm1, ymm14) + + vmovupd(xmm12, mem(rdx )) + vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -2125,11 +2154,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -2143,7 +2172,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); - double* a_next = a + rs_a * 6; + double* a_next = ( (double*)a ) + 
rs_a * 6; begin_asm() mov(var(a), r14) mov(var(b), rbx) @@ -2428,6 +2457,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 8? je(.DCOLSTOR) @@ -2484,7 +2516,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) vmovupd(ymm15, mem(rcx, 1*32)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) vbroadcastsd(mem(rbx), ymm3) @@ -2575,10 +2607,131 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm7, mem(rcx)) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + //For lower 6x8 block + lea(mem(rcx, rdi, 1), rcx) //rcx += 1 * rdi + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + vmovupd(ymm12, mem(rcx, 0*32)) + + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx, 0*32)) + + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + vbroadcastsd(mem(rbx), ymm3) + + lea(mem(rcx, rdi, 4), rdx) //rdx = rcx + 4* rs_c + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(xmm5, mem(rdx )) + vmovhpd(xmm7, mem(rdx, rsi, 1, 1*8)) + + lea(mem(rcx, rdi, 4), rcx) + lea(mem(rcx, 
rdi, 2), rcx) + lea(mem(rcx, rdi, 4), rdx) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -2639,11 +2792,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, 
inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -3097,7 +3250,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L vmovupd(xmm4, mem(rdx, rax, 1)) label(.DDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -3162,11 +3315,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -3624,6 +3777,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L vmovupd(xmm4, mem(rdx, rax, 1)) label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -3688,11 +3842,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4140,6 +4294,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L label(.DDONE) label(.DRETURN) + vzeroupper() end_asm( : // output operands (none) @@ -4203,11 +4358,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t 
cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4572,6 +4727,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -4632,11 +4788,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4993,10 +5149,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx ), xmm3, xmm0) - vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) vmovlpd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -5029,6 +5181,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -5090,11 +5243,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* 
restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -5480,6 +5633,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -5541,11 +5695,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -5656,36 +5810,67 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U mov(var(cs_c), rsi) lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) cmp(imm(8), rdi) je(.DCOLSTOR) label(.DROWSTOR) - lea(mem(rcx, 1*32), rcx) - lea(mem(rcx, 1*16), rcx) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) - vmovlpd(xmm5, mem(rcx)) + vmovlpd(xmm5, mem(rcx)) vmovhpd(xmm5, mem(rcx, rsi, 1)) add(rdi, rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm7) vmovhpd(xmm7, mem(rcx, rsi, 1)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) - vbroadcastsd(mem(rbx), ymm3) + vbroadcastsd(mem(rbx), ymm3) lea(mem(rcx, rsi, 4), rcx) lea(mem(rcx, rsi, 2), rcx) vunpcklpd(xmm7, xmm5, xmm0) vunpckhpd(xmm7, xmm5, xmm1) - vfmadd231pd(mem(rcx ), xmm3, xmm0) - vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) - vmovlpd(xmm0, mem(rcx )) - vmovupd(xmm1, mem(rcx, rsi, 1)) + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) + + vmovlpd(xmm5, mem(rcx)) + vmovhpd(xmm5, 
mem(rcx, rsi, 1)) + add(rdi, rcx) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 2), rcx) + vunpcklpd(xmm7, xmm5, xmm0) + vunpckhpd(xmm7, xmm5, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -5745,11 +5930,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -6008,10 +6193,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U vmovupd(xmm7, mem(rcx, rsi, 1)) vmovupd(xmm9, mem(rcx, rsi, 2)) vextractf128(imm(0x1), ymm9, xmm9) - vmovupd(ymm9, mem(rcx, rsi, 2, 1*16)) - + vmovlpd(xmm9, mem(rcx, rsi, 2, 2*8)) + vmovupd(ymm11, mem(rcx, rax, 1)) label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6073,11 +6259,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -6394,7 +6580,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U vmovupd(xmm5, mem(rcx )) vextractf128(imm(0x1), ymm5, xmm5) - vmovlpd(xmm5, mem(rcx )) + 
vmovlpd(xmm5, mem(rcx, 2*8 )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -6409,6 +6595,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6476,11 +6663,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -7012,9 +7199,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - lea(mem(rcx, rsi, 4), rbp) + lea(mem(rcx, rdi, 4), rbp) + lea(mem(rbp, rdi, 2), rbp) lea(mem(rbp, rsi, 2), rbp) - lea(mem(rbp, 1*32+1*16), rbp) vmovlpd(xmm2, mem(rbp)) vmovupd(xmm4, mem(rbp, rsi, 1)) @@ -7047,6 +7234,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -7741,8 +7929,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -8400,8 +8587,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -9034,8 +9220,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 1c35122a4..5b4c8a05b 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -278,6 +278,22 @@ GEMMSUP_KER_PROT( double, d, 
gemmsup_rd_haswell_asm_1x1 ) // gemmsup_rd (mkernel in m dim) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_16x12_combined_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) + GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m )