diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index 2af7e9f45..1023821b8 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -66,249 +66,11 @@ typedef void (*gemmt_ker_ft) dim_t m0, dim_t n0, dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=0 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=6 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=12 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant combined kernel for m_offset=12, n_offset=16 and m_offset=18, n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=6 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=12 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=18 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=0 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=6 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=12 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant combined kernel for m_offset=0, n_offset=0 and m_offset=6, n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=6 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=12 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=18 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ); @@ -2177,11 +1939,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2203,11 +1965,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2849,11 +2611,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2874,11 +2636,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 107917d07..eb734fe0d 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -1127,11 +1127,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -1541,7 +1541,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L label(.DDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -1604,11 +1604,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -1861,7 +1861,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L label(.DDONE) - + vzeroupper() end_asm( @@ -1925,11 +1925,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -2032,19 +2032,21 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L mov(var(cs_c), rsi) lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) - cmp(imm(8), rdi) //rs_c == 0? + cmp(imm(8), rdi) //rs_c == 0? je(.DCOLSTOR) label(.DROWSTOR) - lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) vmovlpd(xmm12, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) vmovlpd(xmm14, mem(rcx)) vmovhpd(xmm14, mem(rcx, rsi, 1)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) @@ -2058,8 +2060,35 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm14) vmovupd(xmm12, mem(rdx )) vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 0? + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm14, mem(rcx)) + vmovhpd(xmm14, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rdi, 4), rdx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm12) + vinsertf128(imm(0x1), xmm3, ymm1, ymm14) + + vmovupd(xmm12, mem(rdx )) + vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -2125,11 +2154,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -2143,7 +2172,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); - double* a_next = a + rs_a * 6; + double* a_next = ( (double*)a ) + rs_a * 6; begin_asm() mov(var(a), r14) mov(var(b), rbx) @@ -2428,6 +2457,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 8? je(.DCOLSTOR) @@ -2484,7 +2516,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) vmovupd(ymm15, mem(rcx, 1*32)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) vbroadcastsd(mem(rbx), ymm3) @@ -2575,10 +2607,131 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm7, mem(rcx)) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + //For lower 6x8 block + lea(mem(rcx, rdi, 1), rcx) //rcx += 1 * rdi + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + vmovupd(ymm12, mem(rcx, 0*32)) + + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx, 0*32)) + + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + vbroadcastsd(mem(rbx), ymm3) + + lea(mem(rcx, rdi, 4), rdx) //rdx = rcx + 4* rs_c + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(xmm5, mem(rdx )) + vmovhpd(xmm7, mem(rdx, rsi, 1, 1*8)) + + lea(mem(rcx, rdi, 4), rcx) + lea(mem(rcx, rdi, 2), rcx) + lea(mem(rcx, rdi, 4), rdx) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -2639,11 +2792,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -3097,7 +3250,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L vmovupd(xmm4, mem(rdx, rax, 1)) label(.DDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -3162,11 +3315,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -3624,6 +3777,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L vmovupd(xmm4, mem(rdx, rax, 1)) label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -3688,11 +3842,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4140,6 +4294,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L label(.DDONE) label(.DRETURN) + vzeroupper() end_asm( : // output operands (none) @@ -4203,11 +4358,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4572,6 +4727,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -4632,11 +4788,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4993,10 +5149,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx ), xmm3, xmm0) - vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) vmovlpd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -5029,6 +5181,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -5090,11 +5243,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -5480,6 +5633,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -5541,11 +5695,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -5656,36 +5810,67 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U mov(var(cs_c), rsi) lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) cmp(imm(8), rdi) je(.DCOLSTOR) label(.DROWSTOR) - lea(mem(rcx, 1*32), rcx) - lea(mem(rcx, 1*16), rcx) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) - vmovlpd(xmm5, mem(rcx)) + vmovlpd(xmm5, mem(rcx)) vmovhpd(xmm5, mem(rcx, rsi, 1)) add(rdi, rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm7) vmovhpd(xmm7, mem(rcx, rsi, 1)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) - vbroadcastsd(mem(rbx), ymm3) + vbroadcastsd(mem(rbx), ymm3) lea(mem(rcx, rsi, 4), rcx) lea(mem(rcx, rsi, 2), rcx) vunpcklpd(xmm7, xmm5, xmm0) vunpckhpd(xmm7, xmm5, xmm1) - vfmadd231pd(mem(rcx ), xmm3, xmm0) - vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) - vmovlpd(xmm0, mem(rcx )) - vmovupd(xmm1, mem(rcx, rsi, 1)) + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) + + vmovlpd(xmm5, mem(rcx)) + vmovhpd(xmm5, mem(rcx, rsi, 1)) + add(rdi, rcx) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 2), rcx) + vunpcklpd(xmm7, xmm5, xmm0) + vunpckhpd(xmm7, xmm5, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -5745,11 +5930,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -6008,10 +6193,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U vmovupd(xmm7, mem(rcx, rsi, 1)) vmovupd(xmm9, mem(rcx, rsi, 2)) vextractf128(imm(0x1), ymm9, xmm9) - vmovupd(ymm9, mem(rcx, rsi, 2, 1*16)) - + vmovlpd(xmm9, mem(rcx, rsi, 2, 2*8)) + vmovupd(ymm11, mem(rcx, rax, 1)) label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6073,11 +6259,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -6394,7 +6580,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U vmovupd(xmm5, mem(rcx )) vextractf128(imm(0x1), ymm5, xmm5) - vmovlpd(xmm5, mem(rcx )) + vmovlpd(xmm5, mem(rcx, 2*8 )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -6409,6 +6595,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6476,11 +6663,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -7012,9 +7199,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - lea(mem(rcx, rsi, 4), rbp) + lea(mem(rcx, rdi, 4), rbp) + lea(mem(rbp, rdi, 2), rbp) lea(mem(rbp, rsi, 2), rbp) - lea(mem(rbp, 1*32+1*16), rbp) vmovlpd(xmm2, mem(rbp)) vmovupd(xmm4, mem(rbp, rsi, 1)) @@ -7047,6 +7234,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -7741,8 +7929,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -8400,8 +8587,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -9034,8 +9220,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 1c35122a4..5b4c8a05b 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -278,6 +278,22 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 ) // gemmsup_rd (mkernel in m dim) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_16x12_combined_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) + GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m )