diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 2ff56c0ba..3a8a2f0d7 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -1631,7 +1631,7 @@ void bli_cntx_set_l3_thresh_funcs( dim_t n_funcs, ... ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_thresh_funcs(): " ); #endif - l1vkr_t* func_ids = bli_malloc_intl( n_funcs * sizeof( opid_t ) ); + opid_t* func_ids = bli_malloc_intl( n_funcs * sizeof( opid_t ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_thresh_funcs(): " ); diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c index c31384cc4..b99b6eef2 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_cpackm_haswell_asm_3xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c index 02c894a39..4cad0c90c 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_cpackm_haswell_asm_8xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c index 26b98f4da..06fcf1438 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_zpackm_haswell_asm_3xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c index 655231754..25a8b6181 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All rights reserved. 
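/* A minimal sketch of the pointer-type fix in the bli_cntx.c hunk above: the
   scratch array holds opid_t values and is already sized with sizeof(opid_t),
   so its pointer is now declared opid_t* instead of l1vkr_t*. The names below
   are hypothetical stand-ins for illustration only, not the BLIS API. */
#include <stdlib.h>

typedef int opid_t_;                     /* stand-in for the BLIS opid_t enum */

static opid_t_* alloc_func_ids( size_t n_funcs )
{
    /* the pointer type and the sizeof() element type now agree */
    opid_t_* func_ids = malloc( n_funcs * sizeof( opid_t_ ) );
    return func_ids;                     /* caller releases with free() */
}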
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_zpackm_haswell_asm_4xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 315894b17..b4ac979e1 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -2224,40 +2224,24 @@ void bli_zgemm_haswell_asm_3x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - //handling case when alpha and beta are real and +/-1. - uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); - uint64_t beta_real_one = *((uint64_t*)(&beta->real)); - - uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); - uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); - char alpha_mul_type = BLIS_MUL_DEFAULT; char beta_mul_type = BLIS_MUL_DEFAULT; - if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) - { - alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 - if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 - } - } + //handling case when alpha and beta are real and +/-1. - if(beta->imag == 0)// beta is real - { - if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) - { - beta_mul_type = BLIS_MUL_ONE; - if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - beta_mul_type = BLIS_MUL_MINUS_ONE; - } - } - else if(beta_real_one == 0) - { - beta_mul_type = BLIS_MUL_ZERO; - } - } + if(alpha->imag == 0.0)// (alpha is real) + { + if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; + else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; + else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; + } + + if(beta->imag == 0.0)// (beta is real) + { + if(beta->real == 1.0) beta_mul_type = BLIS_MUL_ONE; + else if(beta->real == -1.0) beta_mul_type = BLIS_MUL_MINUS_ONE; + else if(beta->real == 0.0) beta_mul_type = BLIS_MUL_ZERO; + } begin_asm() diff --git a/kernels/zen/1f/bli_axpyf_zen_int_6.c b/kernels/zen/1f/bli_axpyf_zen_int_6.c index d27dce6cf..99b544db1 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_6.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_6.c @@ -83,10 +83,9 @@ void bli_saxpyf_zen_int_6 v8sf_t chi0v, chi1v, chi2v, chi3v; v8sf_t chi4v,chi5v; - v8sf_t a00v, a01v, a02v, a03v; - v8sf_t a04v,a05v; + v8sf_t a00v, a01v; - v8sf_t y0v, y1v; + v8sf_t y0v; float chi0, chi1, chi2, chi3; float chi4,chi5; diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c index 1e9bacd9a..64aedb879 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. 
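/* The bli_zgemm_haswell_asm_3x4 hunk above replaces bit-pattern tests on the
   uint64 reinterpretation of alpha/beta with plain floating-point comparisons
   when classifying each scalar as +1, -1, 0, or "general". A self-contained
   sketch of the same classification; the MUL_* values are placeholders and
   need not match the library's BLIS_MUL_* definitions. */
#include <stdio.h>

typedef struct { double real, imag; } dcomplex_;   /* stand-in for BLIS dcomplex */

enum { MUL_ZERO = 0, MUL_ONE = 1, MUL_DEFAULT = 2, MUL_MINUS_ONE = -1 };

static int classify_scalar( const dcomplex_* s )
{
    int type = MUL_DEFAULT;                        /* general case: full complex scaling */
    if ( s->imag == 0.0 )                          /* only purely real scalars qualify */
    {
        if      ( s->real ==  1.0 ) type = MUL_ONE;
        else if ( s->real == -1.0 ) type = MUL_MINUS_ONE;
        else if ( s->real ==  0.0 ) type = MUL_ZERO;
    }
    return type;
}

int main( void )
{
    dcomplex_ beta = { -1.0, 0.0 };
    printf( "beta class: %d\n", classify_scalar( &beta ) );   /* prints -1 */
    return 0;
}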
+ Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,57 +42,57 @@ and store outputs to ymm0 (creal,cimag)*(betar,beati) where c is stored in col major order*/ #define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) + vmovupd(mem(rcx), xmm0) \ + vmovupd(mem(rcx, rsi, 1), xmm3) \ + vinsertf128(imm(1), xmm3, ymm0, ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) //(creal,cimag)*(betar,beati) where c is stored in row major order #define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovupd(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) + vmovupd(mem(rcx), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) #define ZGEMM_INPUT_RS_BETA_ONE \ - vmovupd(mem(rcx), ymm0) + vmovupd(mem(rcx), ymm0) #define ZGEMM_OUTPUT_RS \ - vmovupd(ymm0, mem(rcx)) \ + vmovupd(ymm0, mem(rcx)) \ -/*(cNextRowreal,cNextRowimag)*(betar,beati) +/*(cNextRowreal,cNextRowimag)*(betar,beati) where c is stored in row major order rsi = cs_c * sizeof((real +imag)dt)*numofElements numofElements = 2, 2 elements are processed at a time*/ #define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ - vmovupd(mem(rcx, rsi, 1), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) + vmovupd(mem(rcx, rsi, 1), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) #define ZGEMM_INPUT_RS_BETA_ONE_NEXT \ - vmovupd(mem(rcx, rsi, 1), ymm0) + vmovupd(mem(rcx, rsi, 1), ymm0) #define ZGEMM_OUTPUT_RS_NEXT \ - vmovupd(ymm0, mem(rcx, rsi, 1)) + vmovupd(ymm0, mem(rcx, rsi, 1)) /* rrr: - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : rcr: - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : Assumptions: - B is row-stored; @@ -108,11 +108,11 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | += ------ - -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | += ------ + -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : */ void bli_zgemmsup_rv_zen_asm_3x4m ( @@ -130,666 +130,650 @@ void bli_zgemmsup_rv_zen_asm_3x4m cntx_t* restrict cntx ) { - uint64_t n_left = n0 % 4; + uint64_t n_left = n0 % 4; - // First check whether this is a edge case in the n dimension. If so, - // dispatch other 3x?m kernels, as needed. - if (n_left ) - { + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 3x?m kernels, as needed. 
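/* The ZGEMM_INPUT_SCALE_* macros above implement a complex multiply on packed
   (real, imag) pairs: vpermilpd swaps each pair, the two vmulpd scale by the
   broadcast beta_r and beta_i, and vaddsubpd forms the cross terms. Per element
   the arithmetic reduces to the sketch below (a scalar restatement for clarity,
   not BLIS code). */
typedef struct { double r, i; } zelem_;

static zelem_ zscale_( zelem_ beta, zelem_ c )
{
    zelem_ y;
    y.r = beta.r * c.r - beta.i * c.i;   /* even lane: vaddsubpd subtracts */
    y.i = beta.r * c.i + beta.i * c.r;   /* odd lane:  vaddsubpd adds      */
    return y;
}
/* Each loaded C element is replaced by zscale_(beta, c) before the alpha-scaled
   accumulator is added in the store path. */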
+ if (n_left ) + { dcomplex* cij = c; dcomplex* bj = b; dcomplex* ai = a; - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_zgemmsup_rv_zen_asm_3x2m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - - return; - } - - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - - uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - if ( m_iter == 0 ) goto consider_edge_cases; - - //handling case when alpha and beta are real and +/-1. - uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); - uint64_t beta_real_one = *((uint64_t*)(&beta->real)); - - uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); - uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); - - char alpha_mul_type = BLIS_MUL_DEFAULT; - char beta_mul_type = BLIS_MUL_DEFAULT; - - if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) - { - alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 - if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 - } - } - - if(beta->imag == 0)// beta is real - { - if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) - { - beta_mul_type = BLIS_MUL_ONE; - if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - beta_mul_type = BLIS_MUL_MINUS_ONE; - } - } - else if(beta_real_one == 0) - { - beta_mul_type = BLIS_MUL_ZERO; - } - } - - // ------------------------------------------------------------------------- - - begin_asm() - - mov(var(a), r14) // load address of a. - mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 8), r8) // rs_a *= sizeof(real dt) - lea(mem(, r8, 2), r8) // rs_a *= sizeof((real + imag) dt) - lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt) - lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt) - - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt) - lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt) - - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). 
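/* The 3x4m kernel above peels the n-dimension remainder (n0 % 4) before
   entering the assembly body: a two-column strip is dispatched to the 3x2m
   kernel and a final single column falls back to gemv. A schematic of that
   dispatch with the real kernel calls reduced to hypothetical stand-ins and
   the pointer/stride bookkeeping omitted. */
typedef long dim_ti;                     /* stand-in for BLIS dim_t */

static void kern_3x2m_( dim_ti m, dim_ti k ) { (void)m; (void)k; }  /* 2-column microkernel   */
static void kern_gemv_( dim_ti m, dim_ti k ) { (void)m; (void)k; }  /* single-column fallback */

static int dispatch_n_edge_( dim_ti m0, dim_ti n0, dim_ti k0 )
{
    dim_ti n_left = n0 % 4;              /* columns not covered by NR = 4 */
    if ( n_left == 0 ) return 0;         /* no remainder: run the main kernel instead */

    if ( n_left >= 2 ) { kern_3x2m_( m0, k0 ); n_left -= 2; }  /* 2-wide strip */
    if ( n_left == 1 ) { kern_gemv_( m0, k0 ); }               /* last column  */
    return 1;                            /* remainder fully handled, caller returns */
}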
- - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) - lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) - - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + char alpha_mul_type = BLIS_MUL_DEFAULT; + char beta_mul_type = BLIS_MUL_DEFAULT; + + //handling case when alpha and beta are real and +/-1. + + if(alpha->imag == 0.0)// (alpha is real) + { + if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; + else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; + else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; + } + + if(beta->imag == 0.0)// (beta is real) + { + if(beta->real == 1.0) beta_mul_type = BLIS_MUL_ONE; + else if(beta->real == -1.0) beta_mul_type = BLIS_MUL_MINUS_ONE; + else if(beta->real == 0.0) beta_mul_type = BLIS_MUL_ZERO; + } + + // ------------------------------------------------------------------------- - label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(real dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof((real + imag) dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). - vzeroall() // zero all xmm/ymm registers. + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.ZCOLPFETCH) // jump to column storage case - label(.ZROWPFETCH) // row-stored pre-fetching on c // not used + mov(var(m_iter), r11) // ii = m_iter; - jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c - label(.ZCOLPFETCH) // column-stored pre-fetching c + label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - label(.ZPOSTPFETCH) // done prefetching c + vzeroall() // zero all xmm/ymm registers. - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. - label(.ZLOOPKITER) // MAIN LOOP + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLPFETCH) // jump to column storage case + label(.ZROWPFETCH) // row-stored pre-fetching on c // not used - // ---------------------------------- iteration 0 + jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c + label(.ZCOLPFETCH) // column-stored pre-fetching c - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + label(.ZPOSTPFETCH) // done prefetching c - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + label(.ZLOOPKITER) // MAIN LOOP - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + // ---------------------------------- iteration 0 - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - // ---------------------------------- iteration 1 + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + // ---------------------------------- iteration 1 - vbroadcastsd(mem(rax, 8 ), ymm3) - 
vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - // ---------------------------------- iteration 2 + vbroadcastsd(mem(rax, 8 ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + // ---------------------------------- iteration 2 - vbroadcastsd(mem(rax, 8 ), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - // ---------------------------------- iteration 3 - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, 8 ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + // ---------------------------------- iteration 3 + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, r8, 2, 8), 
ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - dec(rsi) // i -= 1; - jne(.ZLOOPKITER) // iterate again if i != 0. + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - label(.ZCONSIDKLEFT) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - label(.ZLOOPKLEFT) // EDGE LOOP + add(r9, rax) // a += cs_a; - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + dec(rsi) // i -= 1; + jne(.ZLOOPKITER) // iterate again if i != 0. - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + label(.ZCONSIDKLEFT) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + label(.ZLOOPKLEFT) // EDGE LOOP - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - dec(rsi) // i -= 1; - jne(.ZLOOPKLEFT) // iterate again if i != 0. + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - label(.ZPOSTACCUM) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - mov(r12, rcx) // reset rcx to current utile of c. + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 - vpermilpd(imm(0x5), ymm6, ymm6) - vpermilpd(imm(0x5), ymm7, ymm7) - vpermilpd(imm(0x5), ymm10, ymm10) - vpermilpd(imm(0x5), ymm11, ymm11) - vpermilpd(imm(0x5), ymm14, ymm14) - vpermilpd(imm(0x5), ymm15, ymm15) + add(r9, rax) // a += cs_a; - // subtract/add even/odd elements - vaddsubpd(ymm6, ymm4, ymm4) - vaddsubpd(ymm7, ymm5, ymm5) + dec(rsi) // i -= 1; + jne(.ZLOOPKLEFT) // iterate again if i != 0. - vaddsubpd(ymm10, ymm8, ymm8) - vaddsubpd(ymm11, ymm9, ymm9) + label(.ZPOSTACCUM) - vaddsubpd(ymm14, ymm12, ymm12) - vaddsubpd(ymm15, ymm13, ymm13) + mov(r12, rcx) // reset rcx to current utile of c. 
- mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm7, ymm7) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm11, ymm11) + vpermilpd(imm(0x5), ymm14, ymm14) + vpermilpd(imm(0x5), ymm15, ymm15) - //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) - mov(var(alpha_mul_type), al) - cmp(imm(0xFF), al) - jne(.ALPHA_NOT_MINUS1) + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm7, ymm5, ymm5) - // when alpha = -1 and real. - vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. - vsubpd(ymm4, ymm0, ymm4) - vsubpd(ymm5, ymm0, ymm5) - vsubpd(ymm8, ymm0, ymm8) - vsubpd(ymm9, ymm0, ymm9) - vsubpd(ymm12, ymm0, ymm12) - vsubpd(ymm13, ymm0, ymm13) - jmp(.ALPHA_REAL_ONE) - - label(.ALPHA_NOT_MINUS1) - //when alpha is real and +/-1, multiplication is skipped. - cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication. - jne(.ALPHA_REAL_ONE) + vaddsubpd(ymm10, ymm8, ymm8) + vaddsubpd(ymm11, ymm9, ymm9) - /* (ar + ai) x AB */ - mov(var(alpha), rax) // load address of alpha - vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate - vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + vaddsubpd(ymm14, ymm12, ymm12) + vaddsubpd(ymm15, ymm13, ymm13) - vpermilpd(imm(0x5), ymm4, ymm3) - vmulpd(ymm0, ymm4, ymm4) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm4, ymm4) - - vpermilpd(imm(0x5), ymm5, ymm3) - vmulpd(ymm0, ymm5, ymm5) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm5, ymm5) - - vpermilpd(imm(0x5), ymm8, ymm3) - vmulpd(ymm0, ymm8, ymm8) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm8, ymm8) - - vpermilpd(imm(0x5), ymm9, ymm3) - vmulpd(ymm0, ymm9, ymm9) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm9, ymm9) - - vpermilpd(imm(0x5), ymm12, ymm3) - vmulpd(ymm0, ymm12, ymm12) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm12, ymm12) - - vpermilpd(imm(0x5), ymm13, ymm3) - vmulpd(ymm0, ymm13, ymm13) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm13, ymm13) - - label(.ALPHA_REAL_ONE) - // Beta multiplication - /* (br + bi)x C + ((ar + ai) x AB) */ - - mov(var(beta_mul_type), al) - cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) - je(.ZBETAZERO) //jump to beta == 0 case - - cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. - jz(.ZCOLSTORED) // jump to column storage case - - label(.ZROWSTORED) - - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements - - cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) - je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. - - cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE) - je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section. 
- - //CASE 1: beta is real = 1 - ZGEMM_INPUT_RS_BETA_ONE - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 1*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 2*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - jmp(.ZDONE) - - - //CASE 2: beta is real = -1 - label(.ROW_BETA_REAL_MINUS1) - ZGEMM_INPUT_RS_BETA_ONE - vsubpd(ymm0, ymm4, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vsubpd(ymm0, ymm5, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 1*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vsubpd(ymm0, ymm8, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vsubpd(ymm0, ymm9, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 2*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vsubpd(ymm0, ymm12, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vsubpd(ymm0, ymm13, ymm0) - ZGEMM_OUTPUT_RS_NEXT - jmp(.ZDONE) - - - //CASE 3: Default case with multiplication - // beta not equal to (+/-1) or zero, do normal multiplication. - label(.ROW_BETA_NOT_REAL_ONE) - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 1*rs_c - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 2*rs_c - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - jmp(.ZDONE) // jump to end. - - label(.ZCOLSTORED) - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - /*|--------| |-------| - | | | | - | 3x4 | | 4x3 | - |--------| |-------| - */ - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm4) - - add(rdi, rcx) - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm8) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm12) - - lea(mem(r12, rsi, 2), rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm5, ymm0, ymm5) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm9, ymm0, ymm9) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm13, ymm0, ymm13) - - mov(r12, rcx) // reset rcx to current utile of c. 
- - - /****3x4 tile going to save into 4x3 tile in C*****/ - - /******************Transpose top tile 4x3***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) - - add(rsi, rcx) - - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) - - add(rsi, rcx) - - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) - - add(rsi, rcx) - - vextractf128(imm(0x1), ymm5, xmm5) - vextractf128(imm(0x1), ymm9, xmm9) - vextractf128(imm(0x1), ymm13, xmm13) - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) - - jmp(.ZDONE) // jump to end. - - label(.ZBETAZERO) - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORBZ) // jump to column storage case - - label(.ZROWSTORBZ) - /* Store 3x4 elements to C matrix where is C row major order*/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + + //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) + mov(var(alpha_mul_type), al) + cmp(imm(0xFF), al) + jne(.ALPHA_NOT_MINUS1) + + // when alpha = -1 and real. + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vsubpd(ymm4, ymm0, ymm4) + vsubpd(ymm5, ymm0, ymm5) + vsubpd(ymm8, ymm0, ymm8) + vsubpd(ymm9, ymm0, ymm9) + vsubpd(ymm12, ymm0, ymm12) + vsubpd(ymm13, ymm0, ymm13) + jmp(.ALPHA_REAL_ONE) + + label(.ALPHA_NOT_MINUS1) + //when alpha is real and +/-1, multiplication is skipped. + cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication. + jne(.ALPHA_REAL_ONE) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm5, ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm5, ymm5) + + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) + + vpermilpd(imm(0x5), ymm9, ymm3) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm9, ymm9) + + vpermilpd(imm(0x5), ymm12, ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm12, ymm12) + + vpermilpd(imm(0x5), ymm13, ymm3) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm13, ymm13) + + label(.ALPHA_REAL_ONE) + // Beta multiplication + /* (br + bi)x C + ((ar + ai) x AB) */ + + mov(var(beta_mul_type), al) + cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) + je(.ZBETAZERO) //jump to beta == 0 case + + cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. + jz(.ZCOLSTORED) // jump to column storage case + + label(.ZROWSTORED) + + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements + + cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) + je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. + + cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE) + je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section. 
+ + //CASE 1: beta is real = 1 + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.ZDONE) + + + //CASE 2: beta is real = -1 + label(.ROW_BETA_REAL_MINUS1) + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm4, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm5, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm8, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm9, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm12, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm13, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.ZDONE) + + + //CASE 3: Default case with multiplication + // beta not equal to (+/-1) or zero, do normal multiplication. + label(.ROW_BETA_NOT_REAL_ONE) + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.ZDONE) // jump to end. + + label(.ZCOLSTORED) + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + /*|--------| |-------| + | | | | + | 3x4 | | 4x3 | + |--------| |-------| + */ + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm12, ymm0, ymm12) + + lea(mem(r12, rsi, 2), rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm5, ymm0, ymm5) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm9, ymm0, ymm9) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm13, ymm0, ymm13) + + mov(r12, rcx) // reset rcx to current utile of c. 
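/* The row-stored store path above branches on beta_mul_type so the common
   beta = +1, -1, and 0 cases avoid the full complex scaling of C. A scalar
   sketch of what each branch computes per element, where ab is the
   alpha-scaled accumulator value (a restatement for clarity, not BLIS code). */
typedef struct { double r, i; } zval_;

static zval_ store_elem_( int beta_type, zval_ c, zval_ ab, zval_ beta )
{
    zval_ out;
    switch ( beta_type )
    {
        case  1: out.r = c.r + ab.r;  out.i = c.i + ab.i;  break;  /* beta = +1: C + AB    */
        case -1: out.r = ab.r - c.r;  out.i = ab.i - c.i;  break;  /* beta = -1: AB - C    */
        case  0: out = ab;                                 break;  /* beta =  0: store AB  */
        default:                                                   /* general complex beta */
            out.r = beta.r * c.r - beta.i * c.i + ab.r;
            out.i = beta.r * c.i + beta.i * c.r + ab.i;
            break;
    }
    return out;
}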
+ + + /****3x4 tile going to save into 4x3 tile in C*****/ + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + jmp(.ZDONE) // jump to end. + + label(.ZBETAZERO) + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLSTORBZ) // jump to column storage case + + label(.ZROWSTORBZ) + /* Store 3x4 elements to C matrix where is C row major order*/ // rsi = cs_c * sizeof((real +imag)dt) *numofElements - lea(mem(, rsi, 2), rsi) + lea(mem(, rsi, 2), rsi) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 1)) - add(rdi, rcx) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 1)) + add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 1)) - add(rdi, rcx) + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 1)) + add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx, rsi, 1)) + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 1)) - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. - label(.ZCOLSTORBZ) + label(.ZCOLSTORBZ) - /****3x4 tile going to save into 4x3 tile in C*****/ + /****3x4 tile going to save into 4x3 tile in C*****/ - /******************Transpose top tile 4x3***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - add(rsi, rcx) + add(rsi, rcx) - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - add(rsi, rcx) - - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) - - add(rsi, rcx) - - vextractf128(imm(0x1), ymm5, xmm5) - vextractf128(imm(0x1), ymm9, xmm9) - vextractf128(imm(0x1), ymm13, xmm13) - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) + add(rsi, rcx) - label(.ZDONE) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) - lea(mem(r12, rdi, 2), r12) - lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + add(rsi, rcx) - lea(mem(r14, r8, 2), r14) - lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) - dec(r11) // ii -= 1; - jne(.ZLOOP3X4I) // iterate again if ii != 0. 
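/* For column-stored C, the 3x4 register tile above (row pairs ymm4/5, ymm8/9,
   ymm12/13, each holding the four dcomplex results of one C row) is written one
   column at a time: the low xmm half of each row register supplies column j and
   the vextractf128 upper half supplies column j+1. In scalar terms the store
   pattern is simply the sketch below, with acc[i][j] standing for the
   accumulated value destined for C(i,j). */
typedef struct { double r, i; } zc_;

static void store_col_major_( zc_* c, long rs_c, long cs_c, zc_ acc[3][4] )
{
    /* column-major C: consecutive rows within a column are contiguous (rs_c = 1) */
    for ( long j = 0; j < 4; ++j )
        for ( long i = 0; i < 3; ++i )
            c[ i*rs_c + j*cs_c ] = acc[i][j];
}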
+ label(.ZDONE) - label(.ZRETURN) + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c - end_asm( - : // output operands (none) - : // input operands + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.ZLOOP3X4I) // iterate again if ii != 0. + + label(.ZRETURN) + + end_asm( + : // output operands (none) + : // input operands [alpha_mul_type] "m" (alpha_mul_type), [beta_mul_type] "m" (beta_mul_type), [m_iter] "m" (m_iter), @@ -807,46 +791,46 @@ void bli_zgemmsup_rv_zen_asm_3x4m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" - ) + : // register clobber list + "rax", "rbx", "rcx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) - consider_edge_cases: + consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 4; - const dim_t i_edge = m0 - ( dim_t )m_left; + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; dcomplex* cij = c + i_edge*rs_c; dcomplex* ai = a + i_edge*rs_a; dcomplex* bj = b; - zgemmsup_ker_ft ker_fps[3] = - { - NULL, - bli_zgemmsup_rv_zen_asm_1x4, - bli_zgemmsup_rv_zen_asm_2x4, - }; + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4, + bli_zgemmsup_rv_zen_asm_2x4, + }; - zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; - } + } } @@ -867,393 +851,393 @@ void bli_zgemmsup_rv_zen_asm_3x2m ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; - uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; - if ( m_iter == 0 ) goto consider_edge_cases; + if ( m_iter == 0 ) goto consider_edge_cases; - // ------------------------------------------------------------------------- + // ------------------------------------------------------------------------- - begin_asm() + begin_asm() - mov(var(a), r14) // load address of a. 
- mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) - lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) - lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) - lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) - lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) - lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii - mov(var(m_iter), r11) // ii = m_iter; + mov(var(m_iter), r11) // ii = m_iter; - label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] - vzeroall() // zero all xmm/ymm registers. + vzeroall() // zero all xmm/ymm registers. - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLPFETCH) // jump to column storage case - label(.ZROWPFETCH) // row-stored pre-fetching on c // not used + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLPFETCH) // jump to column storage case + label(.ZROWPFETCH) // row-stored pre-fetching on c // not used - jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c - label(.ZCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c + label(.ZCOLPFETCH) // column-stored pre-fetching c - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - label(.ZPOSTPFETCH) // done prefetching c + label(.ZPOSTPFETCH) // done prefetching c - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
- label(.ZLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER) // MAIN LOOP - // ---------------------------------- iteration 0 + // ---------------------------------- iteration 0 - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - add(r9, rax) // a += cs_a; + add(r9, rax) // a += cs_a; - // ---------------------------------- iteration 1 + // ---------------------------------- iteration 1 - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - add(r9, rax) // a += cs_a; + add(r9, rax) // a += cs_a; - // ---------------------------------- iteration 2 + // ---------------------------------- iteration 2 - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - add(r9, rax) // a += cs_a; + add(r9, rax) // a += cs_a; - // ---------------------------------- iteration 3 - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; 
+ // ---------------------------------- iteration 3 + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - add(r9, rax) // a += cs_a; + add(r9, rax) // a += cs_a; - dec(rsi) // i -= 1; - jne(.ZLOOPKITER) // iterate again if i != 0. + dec(rsi) // i -= 1; + jne(.ZLOOPKITER) // iterate again if i != 0. - label(.ZCONSIDKLEFT) + label(.ZCONSIDKLEFT) - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. - label(.ZLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT) // EDGE LOOP - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - add(r9, rax) // a += cs_a; + add(r9, rax) // a += cs_a; - dec(rsi) // i -= 1; - jne(.ZLOOPKLEFT) // iterate again if i != 0. + dec(rsi) // i -= 1; + jne(.ZLOOPKLEFT) // iterate again if i != 0. - label(.ZPOSTACCUM) + label(.ZPOSTACCUM) - mov(r12, rcx) // reset rcx to current utile of c. + mov(r12, rcx) // reset rcx to current utile of c. 
- // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 - vpermilpd(imm(0x5), ymm6, ymm6) - vpermilpd(imm(0x5), ymm10, ymm10) - vpermilpd(imm(0x5), ymm14, ymm14) + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm14, ymm14) - // subtract/add even/odd elements - vaddsubpd(ymm6, ymm4, ymm4) - vaddsubpd(ymm10, ymm8, ymm8) - vaddsubpd(ymm14, ymm12, ymm12) + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm10, ymm8, ymm8) + vaddsubpd(ymm14, ymm12, ymm12) - /* (ar + ai) x AB */ - mov(var(alpha), rax) // load address of alpha - vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate - vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - vpermilpd(imm(0x5), ymm4, ymm3) - vmulpd(ymm0, ymm4, ymm4) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm4, ymm4) + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) - vpermilpd(imm(0x5), ymm8, ymm3) - vmulpd(ymm0, ymm8, ymm8) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm8, ymm8) + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) - vpermilpd(imm(0x5), ymm12, ymm3) - vmulpd(ymm0, ymm12, ymm12) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm12, ymm12) + vpermilpd(imm(0x5), ymm12, ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm12, ymm12) - /* (br + bi)x C + ((ar + ai) x AB) */ - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + /* (br + bi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - // now avoid loading C if beta == 0 - vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. - vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. - sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); - vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. - sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); - and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORED) // jump to column storage case + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
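/* In the 3x2m kernel, beta is tested against zero at run time (vxorpd +
   vucomisd + sete on the real and imaginary parts) so the C tile is never
   loaded when beta == 0. The scalar equivalent of that guard, as a sketch: */
typedef struct { double real, imag; } zscalar_;

static int beta_is_zero_( const zscalar_* beta )
{
    /* both components must compare equal to zero */
    return ( beta->real == 0.0 ) && ( beta->imag == 0.0 );
}
/* If beta_is_zero_(beta) holds, the kernel jumps straight to the store-without-load path. */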
+ jz(.ZCOLSTORED) // jump to column storage case - label(.ZROWSTORED) + label(.ZROWSTORED) - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS - add(rdi, rcx) // rcx = c + 1*rs_c + add(rdi, rcx) // rcx = c + 1*rs_c - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS - add(rdi, rcx) // rcx = c + 2*rs_c + add(rdi, rcx) // rcx = c + 2*rs_c - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. - label(.ZCOLSTORED) - /*|--------| |-------| - | | | | - | 3x2 | | 2x3 | - |--------| |-------| - */ + label(.ZCOLSTORED) + /*|--------| |-------| + | | | | + | 3x2 | | 2x3 | + |--------| |-------| + */ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm4) + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - add(rdi, rcx) - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm8) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm12) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) - mov(r12, rcx) // reset rcx to current utile of c. + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) - /****3x2 tile going to save into 2x3 tile in C*****/ + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm12, ymm0, ymm12) - /******************Transpose top tile 2x3***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + mov(r12, rcx) // reset rcx to current utile of c. - add(rsi, rcx) + /****3x2 tile going to save into 2x3 tile in C*****/ - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + /******************Transpose top tile 2x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + add(rsi, rcx) - jmp(.ZDONE) // jump to end. + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - label(.ZBETAZERO) - cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. - jz(.ZCOLSTORBZ) // jump to column storage case + jmp(.ZDONE) // jump to end. - label(.ZROWSTORBZ) + label(.ZBETAZERO) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - - vmovupd(ymm12, mem(rcx)) + cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. + jz(.ZCOLSTORBZ) // jump to column storage case - jmp(.ZDONE) // jump to end. 
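In the column-stored paths above, each ymm accumulator holds one row (two dcomplex) of the 3x2 microtile, so the low 128-bit halves of the three accumulators form one column of C and the high halves, obtained with vextractf128, form the next; the "3x2 tile going to save into 2x3 tile" comments refer to this lane-wise transpose on store. A hedged intrinsics sketch of that write-back, with an illustrative helper name and cs_c counted in dcomplex elements:

    #include <immintrin.h>
    #include "blis.h"   // dcomplex, dim_t (assumed)

    // Sketch only: stores a 3x2 microtile, held row-wise in r0..r2, into a
    // column-stored C, mirroring the vmovups/vextractf128 sequence above.
    static inline void store_3x2_colmajor( dcomplex* c, dim_t cs_c,
                                           __m256d r0, __m256d r1, __m256d r2 )
    {
        double* col0 = ( double* )  c;            // column 0, rows 0..2 (contiguous)
        double* col1 = ( double* )( c + cs_c );   // column 1, rows 0..2

        _mm_storeu_pd( col0 + 0, _mm256_castpd256_pd128( r0 ) );   // C(0,0)
        _mm_storeu_pd( col0 + 2, _mm256_castpd256_pd128( r1 ) );   // C(1,0)
        _mm_storeu_pd( col0 + 4, _mm256_castpd256_pd128( r2 ) );   // C(2,0)

        _mm_storeu_pd( col1 + 0, _mm256_extractf128_pd( r0, 1 ) ); // C(0,1)
        _mm_storeu_pd( col1 + 2, _mm256_extractf128_pd( r1, 1 ) ); // C(1,1)
        _mm_storeu_pd( col1 + 4, _mm256_extractf128_pd( r2, 1 ) ); // C(2,1)
    }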
+ label(.ZROWSTORBZ) - label(.ZCOLSTORBZ) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) - /****3x2 tile going to save into 2x3 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) - /******************Transpose tile 3x2***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + vmovupd(ymm12, mem(rcx)) - add(rsi, rcx) + jmp(.ZDONE) // jump to end. - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + label(.ZCOLSTORBZ) - label(.ZDONE) + /****3x2 tile going to save into 2x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - lea(mem(r12, rdi, 2), r12) - lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + /******************Transpose tile 3x2***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - lea(mem(r14, r8, 2), r14) - lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + add(rsi, rcx) - dec(r11) // ii -= 1; - jne(.ZLOOP3X2I) // iterate again if ii != 0. + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - label(.ZRETURN) + label(.ZDONE) - end_asm( - : // output operands (none) - : // input operands + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.ZLOOP3X2I) // iterate again if ii != 0. + + label(.ZRETURN) + + end_asm( + : // output operands (none) + : // input operands [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), @@ -1269,43 +1253,43 @@ void bli_zgemmsup_rv_zen_asm_3x2m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" - ) + : // register clobber list + "rax", "rbx", "rcx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) - consider_edge_cases: + consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 4; - const dim_t i_edge = m0 - ( dim_t )m_left; + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; - dcomplex* cij = c + i_edge*rs_c; - dcomplex* ai = a + i_edge*rs_a; - dcomplex* bj = b; + dcomplex* cij = c + i_edge*rs_c; + dcomplex* ai = a + i_edge*rs_a; + dcomplex* bj = b; - zgemmsup_ker_ft ker_fps[3] = - { - NULL, - bli_zgemmsup_rv_zen_asm_1x2, - bli_zgemmsup_rv_zen_asm_2x2, - }; + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x2, + bli_zgemmsup_rv_zen_asm_2x2, + }; - zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; - } -} \ No newline at end of file + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c index 44b43e741..b12f67ca9 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,16 +37,16 @@ /* rrr: - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : rcr: - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : Assumptions: - B is row-stored; @@ -62,10 +62,10 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : */ void bli_zgemmsup_rv_zen_asm_3x4n ( @@ -83,452 +83,436 @@ void bli_zgemmsup_rv_zen_asm_3x4n cntx_t* restrict cntx ) { - uint64_t m_left = m0 % 3; - if ( m_left ) - { - zgemmsup_ker_ft ker_fps[3] = - { - NULL, - bli_zgemmsup_rv_zen_asm_1x4n, - bli_zgemmsup_rv_zen_asm_2x4n, - }; - zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - ker_fp - ( - conja, conjb, m_left, n0, k0, - alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, - beta, c, rs_c0, cs_c0, data, cntx - ); - return; - } - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); + uint64_t m_left = m0 % 3; + if ( m_left ) + { + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4n, + bli_zgemmsup_rv_zen_asm_2x4n, + }; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; + } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. 
+ // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. - uint64_t k_iter = 0; + uint64_t k_iter = 0; - uint64_t n_iter = n0 / 4; - uint64_t n_left = n0 % 4; + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; - if ( n_iter == 0 ) goto consider_edge_cases; + if ( n_iter == 0 ) goto consider_edge_cases; - //handling case when alpha and beta are real and +/-1. - uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); - uint64_t beta_real_one = *((uint64_t*)(&beta->real)); + char alpha_mul_type = BLIS_MUL_DEFAULT; + char beta_mul_type = BLIS_MUL_DEFAULT; - uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); - uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); + //handling case when alpha and beta are real and +/-1. - char alpha_mul_type = BLIS_MUL_DEFAULT; - char beta_mul_type = BLIS_MUL_DEFAULT; + if(alpha->imag == 0.0)// (alpha is real) + { + if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; + else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; + else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; + } - if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) - { - alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 - if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 - } - } + if(beta->imag == 0.0)// (beta is real) + { + if(beta->real == 1.0) beta_mul_type = BLIS_MUL_ONE; + else if(beta->real == -1.0) beta_mul_type = BLIS_MUL_MINUS_ONE; + else if(beta->real == 0.0) beta_mul_type = BLIS_MUL_ZERO; + } - if(beta->imag == 0)// beta is real - { - if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) - { - beta_mul_type = BLIS_MUL_ONE; - if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - beta_mul_type = BLIS_MUL_MINUS_ONE; - } - } - else if(beta_real_one == 0) - { - beta_mul_type = BLIS_MUL_ZERO; - } - } + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm0, xmm3; - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m128d xmm0, xmm3; + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - for (n_iter = 0; n_iter < n0 / 4; n_iter++) - { - // clear scratch registers. 
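The alpha_mul_type/beta_mul_type logic introduced just above (replacing the earlier bit-pattern comparison against the BLIS_DOUBLE_TO_UINT64_* constants) tags only scalars that are exactly real 0, +1, or -1, because those cases let the kernel drop or simplify the complex scaling in C := beta*C + alpha*A*B. Restated as a standalone helper for readability; the BLIS_MUL_* values are the same tags used in this file, while the function name is illustrative:

    // Sketch only: classify a dcomplex scalar for the special-case paths above.
    static char z_mul_type( const dcomplex* s )
    {
        if ( s->imag == 0.0 )                 // scalar is real
        {
            if ( s->real ==  1.0 ) return BLIS_MUL_ONE;
            if ( s->real == -1.0 ) return BLIS_MUL_MINUS_ONE;
            if ( s->real ==  0.0 ) return BLIS_MUL_ZERO;
        }
        return BLIS_MUL_DEFAULT;              // general case: full complex multiply
    }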
- ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - dim_t ta_inc_row = rs_a; - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; - dim_t ta_inc_col = cs_a; - dim_t tb_inc_col = cs_b; - dim_t tc_inc_col = cs_c; + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; - tA = a; - tAimag = &a->imag; - tB = b + n_iter*tb_inc_col*4; - tC = c + n_iter*tc_inc_col*4; - for (k_iter = 0; k_iter imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter real == -1.0) - { - ymm0 = _mm256_setzero_pd(); - ymm4 = _mm256_sub_pd(ymm0,ymm4); - ymm5 = _mm256_sub_pd(ymm0, ymm5); - ymm8 = _mm256_sub_pd(ymm0, ymm8); - ymm9 = _mm256_sub_pd(ymm0, ymm9); - ymm12 = _mm256_sub_pd(ymm0, ymm12); - ymm13 = _mm256_sub_pd(ymm0, ymm13); - } + ymm12 = _mm256_addsub_pd( ymm12, ymm14); + ymm13 = _mm256_addsub_pd( ymm13, ymm15); - //when alpha is real and +/-1, multiplication is skipped. - if(alpha_mul_type == BLIS_MUL_DEFAULT) - { - // alpha, beta multiplication. - /* (ar + ai) x AB */ - ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate + //When alpha_real = -1.0, instead of multiplying with -1, sign is changed. + if(alpha_mul_type == BLIS_MUL_MINUS_ONE)// equivalent to if(alpha->real == -1.0) + { + ymm0 = _mm256_setzero_pd(); + ymm4 = _mm256_sub_pd(ymm0,ymm4); + ymm5 = _mm256_sub_pd(ymm0, ymm5); + ymm8 = _mm256_sub_pd(ymm0, ymm8); + ymm9 = _mm256_sub_pd(ymm0, ymm9); + ymm12 = _mm256_sub_pd(ymm0, ymm12); + ymm13 = _mm256_sub_pd(ymm0, ymm13); + } - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); + //when alpha is real and +/-1, multiplication is skipped. + if(alpha_mul_type == BLIS_MUL_DEFAULT) + { + // alpha, beta multiplication. 
+ /* (ar + ai) x AB */ + ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); - ymm3 = _mm256_permute_pd(ymm8, 5); - ymm8 = _mm256_mul_pd(ymm0, ymm8); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_addsub_pd(ymm8, ymm3); + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); - ymm3 = _mm256_permute_pd(ymm9, 5); - ymm9 = _mm256_mul_pd(ymm0, ymm9); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_addsub_pd(ymm9, ymm3); + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); - ymm3 = _mm256_permute_pd(ymm12, 5); - ymm12 = _mm256_mul_pd(ymm0, ymm12); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_addsub_pd(ymm12, ymm3); + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); - ymm3 = _mm256_permute_pd(ymm13, 5); - ymm13 = _mm256_mul_pd(ymm0, ymm13); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm13 = _mm256_addsub_pd(ymm13, ymm3); - } + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); - if(tc_inc_row == 1) //col stored - { - if(beta_mul_type == BLIS_MUL_ZERO) - { - //transpose left 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); - tC += tc_inc_col; + ymm3 = _mm256_permute_pd(ymm13, 5); + ymm13 = _mm256_mul_pd(ymm0, ymm13); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_addsub_pd(ymm13, ymm3); + } - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); - tC += tc_inc_col; + if(tc_inc_row == 1) //col stored + { + if(beta_mul_type == BLIS_MUL_ZERO) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; - //transpose right 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); - tC += tc_inc_col; + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and 
duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); - //Multiply ymm8 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm8 = _mm256_add_pd(ymm8, ymm0); + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; - //Multiply ymm12 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm12 = _mm256_add_pd(ymm12, ymm0); + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); - //transpose left 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); - tC += tc_inc_col; + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); - tC += tc_inc_col; + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + 
_mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; - //Multiply ymm5 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm5 = _mm256_add_pd(ymm5, ymm0); - //Multiply ymm9 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm9 = _mm256_add_pd(ymm9, ymm0); + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; - //Multiply ymm13 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm13 = _mm256_add_pd(ymm13, ymm0); + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); - //transpose right 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); - tC += tc_inc_col; + //Multiply ymm13 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm13 = _mm256_add_pd(ymm13, ymm0); - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); - } + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; - } - else - { - if(beta_mul_type == BLIS_MUL_ZERO) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - 
_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); - } - else if(beta_mul_type == BLIS_MUL_ONE)// equivalent to if(beta->real == 1.0) - { - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm4 = _mm256_add_pd(ymm4,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm5 = _mm256_add_pd(ymm5,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm8 = _mm256_add_pd(ymm8,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); - ymm9 = _mm256_add_pd(ymm9,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); - ymm12 = _mm256_add_pd(ymm12,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); - ymm13 = _mm256_add_pd(ymm13,ymm2); + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + } + else + { + if(beta_mul_type == BLIS_MUL_ZERO) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + else if(beta_mul_type == BLIS_MUL_ONE)// equivalent to if(beta->real == 1.0) + { + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm4 = _mm256_add_pd(ymm4,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm5 = _mm256_add_pd(ymm5,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm8 = _mm256_add_pd(ymm8,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm9 = _mm256_add_pd(ymm9,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm12 = _mm256_add_pd(ymm12,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm13 = _mm256_add_pd(ymm13,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double 
const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm13 = _mm256_add_pd(ymm13, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); - } - } - } + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_add_pd(ymm13, _mm256_addsub_pd(ymm2, ymm3)); - consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( n_left ) - { - const dim_t mr_cur = 3; - const dim_t j_edge = n0 - ( dim_t )n_left; + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + } + } - dcomplex* restrict cij = c + j_edge*cs_c; - dcomplex* restrict ai = a; - dcomplex* restrict bj = b + n_iter * 4; + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. 
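When beta is neither zero nor covered by a special case, the column-stored path above gathers two dcomplex elements of C that sit one column stride apart into a single ymm (insertf128), multiplies the pair by beta with the same permute/mul/addsub idiom, and adds the result onto the alpha-scaled accumulator. A compact sketch of that step, assuming blis.h types; the helper name is illustrative and tc_inc_col is counted in dcomplex elements:

    #include <immintrin.h>
    #include "blis.h"   // dcomplex, dim_t (assumed)

    // Sketch only: acc += beta * { C(i,j), C(i,j+1) } for column-stored C,
    // with beta_r/beta_i already broadcast as in the code above.
    static inline __m256d beta_update_pair( const dcomplex* cptr, dim_t tc_inc_col,
                                            __m256d beta_r, __m256d beta_i,
                                            __m256d acc )
    {
        __m128d lo = _mm_loadu_pd( ( const double* )  cptr );                 // C(i, j)
        __m128d hi = _mm_loadu_pd( ( const double* )( cptr + tc_inc_col ) );  // C(i, j+1)
        __m256d cc = _mm256_insertf128_pd( _mm256_castpd128_pd256( lo ), hi, 1 );

        __m256d sw = _mm256_permute_pd( cc, 0x5 );   // swap re/im per element
        cc = _mm256_mul_pd( beta_r, cc );
        sw = _mm256_mul_pd( beta_i, sw );
        cc = _mm256_addsub_pd( cc, sw );             // beta * C pair

        return _mm256_add_pd( acc, cc );             // alpha*AB + beta*C
    }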
+ if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; - bli_zgemmsup_rv_zen_asm_3x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - } + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } } @@ -549,286 +533,286 @@ void bli_zgemmsup_rv_zen_asm_2x4n ) { - uint64_t k_iter = 0; + uint64_t k_iter = 0; - uint64_t n_iter = n0 / 4; - uint64_t n_left = n0 % 4; + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; - if ( n_iter == 0 ) goto consider_edge_cases; + if ( n_iter == 0 ) goto consider_edge_cases; - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m128d xmm0, xmm3; + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m128d xmm0, xmm3; - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - for (n_iter = 0; n_iter < n0 / 4; n_iter++) - { - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - dim_t ta_inc_row = rs_a; - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. 
+ ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); - dim_t ta_inc_col = cs_a; - dim_t tb_inc_col = cs_b; - dim_t tc_inc_col = cs_c; + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; - tA = a; - tAimag = &a->imag; - tB = b + n_iter*tb_inc_col*4; - tC = c + n_iter*tc_inc_col*4; - for (k_iter = 0; k_iter imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + ymm8 = _mm256_addsub_pd(ymm8, ymm10); + ymm9 = _mm256_addsub_pd(ymm9, ymm11); - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); + // alpha, beta multiplication. - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); + /* (ar + ai) x AB */ + ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate - ymm3 = _mm256_permute_pd(ymm8, 5); - ymm8 = _mm256_mul_pd(ymm0, ymm8); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_addsub_pd(ymm8, ymm3); + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); - ymm3 = _mm256_permute_pd(ymm9, 5); - ymm9 = _mm256_mul_pd(ymm0, ymm9); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_addsub_pd(ymm9, ymm3); + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); - if(tc_inc_row == 1) //col stored - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - //transpose left 2x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - tC += tc_inc_col; + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - tC += tc_inc_col; + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); - //transpose right 2x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - tC += tc_inc_col; + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 
= _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); - //Multiply ymm8 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm8 = _mm256_add_pd(ymm8, ymm0); + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; - //transpose left 2x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - tC += tc_inc_col; + //transpose right 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; - _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - tC += tc_inc_col; + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //transpose left 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; - //Multiply ymm5 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm5 = _mm256_add_pd(ymm5, ymm0); - //Multiply ymm9 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm9 = _mm256_add_pd(ymm9, ymm0); + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = 
_mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); - //transpose right 2x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - tC += tc_inc_col; + //transpose right 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; - _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - } + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } - } - else - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); - ymm3 = 
_mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - } - } - } + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + } + } - consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( n_left ) - { - const dim_t mr_cur = 3; - const dim_t j_edge = n0 - ( dim_t )n_left; + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; - dcomplex* restrict cij = c + j_edge*cs_c; - dcomplex* restrict ai = a; - dcomplex* restrict bj = b + n_iter * 4; + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; - bli_zgemmsup_rv_zen_asm_2x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - } + bli_zgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } } @@ -848,215 +832,215 @@ void bli_zgemmsup_rv_zen_asm_1x4n cntx_t* restrict cntx ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
- uint64_t k_iter = 0; + uint64_t k_iter = 0; - uint64_t n_iter = n0 / 4; - uint64_t n_left = n0 % 4; + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; - if ( n_iter == 0 ) goto consider_edge_cases; + if ( n_iter == 0 ) goto consider_edge_cases; - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m128d xmm0, xmm3; + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m128d xmm0, xmm3; - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - for (n_iter = 0; n_iter < n0 / 4; n_iter++) - { - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; - dim_t ta_inc_col = cs_a; - dim_t tb_inc_col = cs_b; - dim_t tc_inc_col = cs_c; + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; - tA = a; - tAimag = &a->imag; - tB = b + n_iter*tb_inc_col*4; - tC = c + n_iter*tc_inc_col*4; - for (k_iter = 0; k_iter imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + // alpha, beta multiplication. 
- ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); + /* (ar + ai) x AB */ + ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); - if(tc_inc_row == 1) //col stored - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - //transpose left 1x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - tC += tc_inc_col; + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); - _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm4,1)); - tC += tc_inc_col; + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; - //transpose right 1x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); - tC += tc_inc_col; + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; - _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); + //transpose right 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - tC += tc_inc_col; + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); - _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); - tC += tc_inc_col; + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; - //Multiply ymm5 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm5 = _mm256_add_pd(ymm5, ymm0); + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; - 
_mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); - tC += tc_inc_col; + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); - _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); - } + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; - } - else - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - } - } - } + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); - consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( n_left ) - { - const dim_t mr_cur = 3; - const dim_t j_edge = n0 - ( dim_t )n_left; + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + } + } - dcomplex* restrict cij = c + j_edge*cs_c; - dcomplex* restrict ai = a; - dcomplex* restrict bj = b + n_iter * 4; + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. 
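The consider_edge_cases block that follows repeats the fringe pattern used by the other *x4n kernels above: whatever columns remain after the 4-wide n_iter blocks are handed to a narrower sup kernel (two columns at a time) and, if a single column is left, to bli_zgemv_ex. In outline, under the file's stated assumption that B is row-stored:

    // Outline only; the concrete 2-column kernel differs per caller
    // (bli_zgemmsup_rv_zen_asm_3x2 / _2x2 / _1x2 in this file).
    if ( n_left )
    {
        dcomplex* cij = c + ( n0 - n_left ) * cs_c;  // first unprocessed column of C
        dcomplex* bj  = b + ( n0 / 4 ) * 4;          // matching columns of row-stored B

        if ( 2 <= n_left )
        {
            // a 2-column sup kernel computes two columns, then cij/bj/n_left advance.
        }
        if ( 1 == n_left )
        {
            // the final column is a matrix-vector product, delegated to bli_zgemv_ex.
        }
    }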
+	if ( n_left )
+	{
+		const dim_t mr_cur = 3;
+		const dim_t j_edge = n0 - ( dim_t )n_left;
-		if ( 2 <= n_left )
-		{
-			const dim_t nr_cur = 2;
-			bli_zgemmsup_rv_zen_asm_1x2
-			(
-			  conja, conjb, mr_cur, nr_cur, k0,
-			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
-			  beta, cij, rs_c0, cs_c0, data, cntx
-			);
-			cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur;
-		}
-		if ( 1 == n_left )
-		{
-			bli_zgemv_ex
-			(
-			  BLIS_NO_TRANSPOSE, conjb, m0, k0,
-			  alpha, ai, rs_a0, cs_a0, bj, rs_b0,
-			  beta, cij, rs_c0, cntx, NULL
-			);
-		}
-	}
+		dcomplex* restrict cij = c + j_edge*cs_c;
+		dcomplex* restrict ai = a;
+		dcomplex* restrict bj = b + n_iter * 4;
+
+		if ( 2 <= n_left )
+		{
+			const dim_t nr_cur = 2;
+			bli_zgemmsup_rv_zen_asm_1x2
+			(
+			  conja, conjb, mr_cur, nr_cur, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0,
+			  beta, cij, rs_c0, cs_c0, data, cntx
+			);
+			cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur;
+		}
+		if ( 1 == n_left )
+		{
+			bli_zgemv_ex
+			(
+			  BLIS_NO_TRANSPOSE, conjb, m0, k0,
+			  alpha, ai, rs_a0, cs_a0, bj, rs_b0,
+			  beta, cij, rs_c0, cntx, NULL
+			);
+		}
+	}
 }
 void bli_zgemmsup_rv_zen_asm_3x2
@@ -1075,194 +1059,194 @@ void bli_zgemmsup_rv_zen_asm_3x2
       cntx_t*    restrict cntx
     )
 {
-	uint64_t k_iter = 0;
+	uint64_t k_iter = 0;
-	uint64_t rs_a = rs_a0;
-	uint64_t cs_a = cs_a0;
-	uint64_t rs_b = rs_b0;
-	uint64_t rs_c = rs_c0;
-	uint64_t cs_c = cs_c0;
+	uint64_t rs_a = rs_a0;
+	uint64_t cs_a = cs_a0;
+	uint64_t rs_b = rs_b0;
+	uint64_t rs_c = rs_c0;
+	uint64_t cs_c = cs_c0;
-	// -------------------------------------------------------------------------
-	//scratch registers
-	__m256d ymm0, ymm1, ymm2, ymm3;
-	__m256d ymm4, ymm6;
-	__m256d ymm8, ymm10;
-	__m256d ymm12, ymm14;
-	__m128d xmm0, xmm3;
+	// -------------------------------------------------------------------------
+	//scratch registers
+	__m256d ymm0, ymm1, ymm2, ymm3;
+	__m256d ymm4, ymm6;
+	__m256d ymm8, ymm10;
+	__m256d ymm12, ymm14;
+	__m128d xmm0, xmm3;
-	dcomplex *tA = a;
-	double *tAimag = &a->imag;
-	dcomplex *tB = b;
-	dcomplex *tC = c;
-	// clear scratch registers.
-	ymm4 = _mm256_setzero_pd();
-	ymm6 = _mm256_setzero_pd();
-	ymm8 = _mm256_setzero_pd();
-	ymm10 = _mm256_setzero_pd();
-	ymm12 = _mm256_setzero_pd();
-	ymm14 = _mm256_setzero_pd();
+	dcomplex *tA = a;
+	double *tAimag = &a->imag;
+	dcomplex *tB = b;
+	dcomplex *tC = c;
+	// clear scratch registers.
+	ymm4 = _mm256_setzero_pd();
+	ymm6 = _mm256_setzero_pd();
+	ymm8 = _mm256_setzero_pd();
+	ymm10 = _mm256_setzero_pd();
+	ymm12 = _mm256_setzero_pd();
+	ymm14 = _mm256_setzero_pd();
-	dim_t ta_inc_row = rs_a;
-	dim_t tb_inc_row = rs_b;
-	dim_t tc_inc_row = rs_c;
+	dim_t ta_inc_row = rs_a;
+	dim_t tb_inc_row = rs_b;
+	dim_t tc_inc_row = rs_c;
-	dim_t ta_inc_col = cs_a;
-	dim_t tc_inc_col = cs_c;
+	dim_t ta_inc_col = cs_a;
+	dim_t tc_inc_col = cs_c;
-	for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate
+	// alpha, beta multiplication.
-	ymm3 = _mm256_permute_pd(ymm4, 5);
-	ymm4 = _mm256_mul_pd(ymm0, ymm4);
-	ymm3 =_mm256_mul_pd(ymm1, ymm3);
-	ymm4 = _mm256_addsub_pd(ymm4, ymm3);
+	/* (ar + ai) x AB */
+	ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate
+	ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate
-	ymm3 = _mm256_permute_pd(ymm8, 5);
-	ymm8 = _mm256_mul_pd(ymm0, ymm8);
-	ymm3 = _mm256_mul_pd(ymm1, ymm3);
-	ymm8 = _mm256_addsub_pd(ymm8, ymm3);
+	ymm3 = _mm256_permute_pd(ymm4, 5);
+	ymm4 = _mm256_mul_pd(ymm0, ymm4);
+	ymm3 =_mm256_mul_pd(ymm1, ymm3);
+	ymm4 = _mm256_addsub_pd(ymm4, ymm3);
-	ymm3 = _mm256_permute_pd(ymm12, 5);
-	ymm12 = _mm256_mul_pd(ymm0, ymm12);
-	ymm3 = _mm256_mul_pd(ymm1, ymm3);
-	ymm12 = _mm256_addsub_pd(ymm12, ymm3);
+	ymm3 = _mm256_permute_pd(ymm8, 5);
+	ymm8 = _mm256_mul_pd(ymm0, ymm8);
+	ymm3 = _mm256_mul_pd(ymm1, ymm3);
+	ymm8 = _mm256_addsub_pd(ymm8, ymm3);
-	if(tc_inc_row == 1) //col stored
-	{
-		if(beta->real == 0.0 && beta->imag == 0.0)
-		{
-			//transpose left 3x2
-			_mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4));
-			_mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8));
-			_mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12));
-			tC += tc_inc_col;
+	ymm3 = _mm256_permute_pd(ymm12, 5);
+	ymm12 = _mm256_mul_pd(ymm0, ymm12);
+	ymm3 = _mm256_mul_pd(ymm1, ymm3);
+	ymm12 = _mm256_addsub_pd(ymm12, ymm3);
-			_mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1));
-			_mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1));
-			_mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1));
-		}
-		else{
-			ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate
-			ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate
-			//Multiply ymm4 with beta
-			xmm0 = _mm_loadu_pd((double *)(tC)) ;
-			xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ;
-			ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ;
-			ymm3 = _mm256_permute_pd(ymm0, 5);
-			ymm0 = _mm256_mul_pd(ymm1, ymm0);
-			ymm3 = _mm256_mul_pd(ymm2, ymm3);
-			ymm0 = _mm256_addsub_pd(ymm0, ymm3);
-			ymm4 = _mm256_add_pd(ymm4, ymm0);
-			//Multiply ymm8 with beta
-			xmm0 = _mm_loadu_pd((double *)(tC + 1)) ;
-			xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ;
-			ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ;
-			ymm3 = _mm256_permute_pd(ymm0, 5);
-			ymm0 = _mm256_mul_pd(ymm1, ymm0);
-			ymm3 = _mm256_mul_pd(ymm2, ymm3);
-			ymm0 = _mm256_addsub_pd(ymm0, ymm3);
-			ymm8 = _mm256_add_pd(ymm8, ymm0);
+	if(tc_inc_row == 1) //col stored
+	{
+		if(beta->real == 0.0 && beta->imag == 0.0)
+		{
+			//transpose left 3x2
+			_mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4));
+			_mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8));
+			_mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12));
+			tC += tc_inc_col;
-			//Multiply ymm12 with beta
-			xmm0 = _mm_loadu_pd((double *)(tC + 2)) ;
-			xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ;
-			ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ;
-			ymm3 = _mm256_permute_pd(ymm0, 5);
-			ymm0 = _mm256_mul_pd(ymm1, ymm0);
-			ymm3 = _mm256_mul_pd(ymm2, ymm3);
-			ymm0 = _mm256_addsub_pd(ymm0, ymm3);
-			ymm12 = _mm256_add_pd(ymm12, ymm0);
+			_mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1));
+			_mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1));
+			_mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1));
+		}
+		else{
+			ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate
+			ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate
+			//Multiply ymm4 with beta
+			xmm0 = _mm_loadu_pd((double *)(tC)) ;
+			xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ;
+			ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ;
+			ymm3 = _mm256_permute_pd(ymm0, 5);
+			ymm0 = _mm256_mul_pd(ymm1, ymm0);
+			ymm3 = _mm256_mul_pd(ymm2, ymm3);
+			ymm0 = _mm256_addsub_pd(ymm0, ymm3);
+			ymm4 = _mm256_add_pd(ymm4, ymm0);
+			//Multiply ymm8 with beta
+			xmm0 = _mm_loadu_pd((double *)(tC + 1)) ;
+			xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ;
+			ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ;
+			ymm3 = _mm256_permute_pd(ymm0, 5);
+			ymm0 = _mm256_mul_pd(ymm1, ymm0);
+			ymm3 = _mm256_mul_pd(ymm2, ymm3);
+			ymm0 = _mm256_addsub_pd(ymm0, ymm3);
+			ymm8 = _mm256_add_pd(ymm8, ymm0);
-			_mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4));
-			_mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8));
-			_mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12));
-			tC += tc_inc_col;
-			_mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1));
-			_mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1));
-			_mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1));
-		}
-	}
-	else
-	{
-		if(beta->real == 0.0 && beta->imag == 0.0)
-		{
-			_mm256_storeu_pd((double *)(tC), ymm4);
-			_mm256_storeu_pd((double *)(tC + tc_inc_row ), ymm8);
-			_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
-		}
-		else{
-			/* (br + bi) C + (ar + ai) AB */
-			ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate
-			ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate
+			//Multiply ymm12 with beta
+			xmm0 = _mm_loadu_pd((double *)(tC + 2)) ;
+			xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ;
+			ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ;
+			ymm3 = _mm256_permute_pd(ymm0, 5);
+			ymm0 = _mm256_mul_pd(ymm1, ymm0);
+			ymm3 = _mm256_mul_pd(ymm2, ymm3);
+			ymm0 = _mm256_addsub_pd(ymm0, ymm3);
+			ymm12 = _mm256_add_pd(ymm12, ymm0);
-			ymm2 = _mm256_loadu_pd((double const *)(tC));
-			ymm3 = _mm256_permute_pd(ymm2, 5);
-			ymm2 = _mm256_mul_pd(ymm0, ymm2);
-			ymm3 =_mm256_mul_pd(ymm1, ymm3);
-			ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3));
+			_mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4));
+			_mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8));
+			_mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12));
+			tC += tc_inc_col;
+			_mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1));
+			_mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1));
+			_mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1));
+		}
+	}
+	else
+	{
+		if(beta->real == 0.0 && beta->imag == 0.0)
+		{
+			_mm256_storeu_pd((double *)(tC), ymm4);
+			_mm256_storeu_pd((double *)(tC + tc_inc_row ), ymm8);
+			_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
+		}
+		else{
+			/* (br + bi) C + (ar + ai) AB */
+			ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate
+			ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate
-			ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row));
-			ymm3 = _mm256_permute_pd(ymm2, 5);
-			ymm2 = _mm256_mul_pd(ymm0, ymm2);
-			ymm3 = _mm256_mul_pd(ymm1, ymm3);
-			ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3));
+			ymm2 = _mm256_loadu_pd((double const *)(tC));
+			ymm3 = _mm256_permute_pd(ymm2, 5);
+			ymm2 = _mm256_mul_pd(ymm0, ymm2);
+			ymm3 =_mm256_mul_pd(ymm1, ymm3);
+			ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3));
-			ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2));
-			ymm3 = _mm256_permute_pd(ymm2, 5);
-			ymm2 = _mm256_mul_pd(ymm0, ymm2);
-			ymm3 = _mm256_mul_pd(ymm1, ymm3);
-			ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3));
+			ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row));
+			ymm3 = _mm256_permute_pd(ymm2, 5);
+			ymm2 = _mm256_mul_pd(ymm0, ymm2);
+			ymm3 = _mm256_mul_pd(ymm1, ymm3);
+			ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3));
-			_mm256_storeu_pd((double *)(tC), ymm4);
-			_mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8);
-			_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
-		}
-	}
+			ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2));
+			ymm3 = _mm256_permute_pd(ymm2, 5);
+			ymm2 = _mm256_mul_pd(ymm0, ymm2);
+			ymm3 = _mm256_mul_pd(ymm1, ymm3);
+			ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3));
+
+			_mm256_storeu_pd((double *)(tC), ymm4);
+			_mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8);
+			_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
+		}
+	}
 }
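Note on the arithmetic in the kernels above: _mm256_permute_pd(x, 5) swaps the (real, imag) pair inside each 128-bit lane, the two multiplies form ar*(xr, xi) and ai*(xi, xr), and _mm256_addsub_pd subtracts in the even (real) slots and adds in the odd (imag) slots, which yields the complex product (ar*xr - ai*xi, ar*xi + ai*xr). A minimal scalar sketch of the per-element store-path update C = beta*C + alpha*AB follows; it is illustrative only, not part of the patch, and the names dcomplex_ref, zmul_ref and zupdate_ref are hypothetical stand-ins for BLIS's dcomplex type (a {real, imag} pair of doubles) and the vectorized paths above.

/* Scalar reference for one micro-tile entry: c := beta*c + alpha*ab. */
typedef struct { double real; double imag; } dcomplex_ref;

static dcomplex_ref zmul_ref( dcomplex_ref x, dcomplex_ref y )
{
	dcomplex_ref r;
	r.real = x.real * y.real - x.imag * y.imag; /* even (real) slot: subtract */
	r.imag = x.real * y.imag + x.imag * y.real; /* odd (imag) slot: add       */
	return r;
}

static void zupdate_ref( dcomplex_ref* c, dcomplex_ref alpha, dcomplex_ref ab, dcomplex_ref beta )
{
	dcomplex_ref t = zmul_ref( alpha, ab );       /* (ar + ai) x AB */
	if ( beta.real != 0.0 || beta.imag != 0.0 )
	{
		dcomplex_ref bc = zmul_ref( beta, *c );   /* (br + bi) x C  */
		t.real += bc.real;
		t.imag += bc.imag;
	}
	*c = t; /* mirrors the separate beta == 0 and beta != 0 store paths */
}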