diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c index 6c68707e1..a21c9b5ed 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -6,7 +6,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1281,6 +1281,9 @@ void bli_cgemmsup_rv_zen_asm_3x4 ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); + //Multiply ymm4 with beta xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ; xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col)); diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c index 1638eaba0..787d3f772 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,8 +38,9 @@ #define BLIS_ASM_SYNTAX_ATT #include "bli_x86_asm_macros.h" -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 +/* Assumes beta.r, beta.i have been broadcast into ymm1, ymm2. 
+ and store outputs to ymm0 + (creal,cimag)*(betar,beati) where c is stored in col major order*/ #define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ vmovupd(mem(rcx), xmm0) \ vmovupd(mem(rcx, rsi, 1), xmm3) \ @@ -49,6 +50,7 @@ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) +//(creal,cimag)*(betar,beati) where c is stored in row major order #define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ vmovupd(mem(rcx), ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ @@ -59,15 +61,19 @@ #define ZGEMM_OUTPUT_RS \ vmovupd(ymm0, mem(rcx)) \ +/*(cNextRowreal,cNextRowimag)*(betar,beati) + where c is stored in row major order + rsi = cs_c * sizeof((real +imag)dt)*numofElements + numofElements = 2, 2 elements are processed at a time*/ #define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ - vmovupd(mem(rcx, rsi, 8), ymm0) \ + vmovupd(mem(rcx, rsi, 1), ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) #define ZGEMM_OUTPUT_RS_NEXT \ - vmovupd(ymm0, mem(rcx, rsi, 8)) \ + vmovupd(ymm0, mem(rcx, rsi, 1)) \ void bli_zgemmsup_rv_zen_asm_2x4 @@ -100,7 +106,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; @@ -116,8 +121,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) - //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - mov(var(rs_b), r10) // load rs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) @@ -141,40 +144,29 @@ void bli_zgemmsup_rv_zen_asm_2x4 mov(var(m_iter), r11) // ii = m_iter; - label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - vzeroall() // zero all xmm/ymm registers. mov(var(b), rbx) // load address of b. - //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used + jz(.ZCOLPFETCH1) // jump to column storage case + label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c + label(.ZCOLPFETCH1) // column-stored pre-fetching c mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + label(.ZPOSTPFETCH1) // done prefetching c mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that + je(.ZCONSIDKLEFT1) // if i == 0, jump to code that // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER1) // MAIN LOOP // ---------------------------------- iteration 0 @@ -251,8 +243,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -277,16 +267,16 @@ void bli_zgemmsup_rv_zen_asm_2x4 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. + jne(.ZLOOPKITER1) // iterate again if i != 0. - label(.SCONSIDKLEFT) + label(.ZCONSIDKLEFT1) mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. 
+ je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - label(.SLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT1) // EDGE LOOP vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -312,9 +302,9 @@ void bli_zgemmsup_rv_zen_asm_2x4 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. + jne(.ZLOOPKLEFT1) // iterate again if i != 0. - label(.SPOSTACCUM) + label(.ZPOSTACCUM1) mov(r12, rcx) // reset rcx to current utile of c. @@ -363,10 +353,8 @@ void bli_zgemmsup_rv_zen_asm_2x4 vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. @@ -375,14 +363,13 @@ void bli_zgemmsup_rv_zen_asm_2x4 vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. - jz(.SCOLSTORED) // jump to column storage case + jz(.ZCOLSTORED1) // jump to column storage case - label(.SROWSTORED) + label(.ZROWSTORED1) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) * numofElements ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) @@ -401,21 +388,15 @@ void bli_zgemmsup_rv_zen_asm_2x4 vaddpd(ymm9, ymm0, ymm0) ZGEMM_OUTPUT_RS_NEXT - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORED) + label(.ZCOLSTORED1) /*|--------| |-------| | | | | | 2x4 | | 4x2 | |--------| |-------| */ - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a - ZGEMM_INPUT_SCALE_CS_BETA_NZ vaddpd(ymm4, ymm0, ymm4) @@ -438,9 +419,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 /****3x4 tile going to save into 4x2 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) /******************Transpose top tile 4x3***************************/ vmovups(xmm4, mem(rcx)) @@ -465,29 +443,27 @@ void bli_zgemmsup_rv_zen_asm_2x4 vmovups(xmm5, mem(rcx)) vmovups(xmm9, mem(rcx, 16)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. - label(.SBETAZERO) - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + label(.ZBETAZERO1) cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.SCOLSTORBZ) // jump to column storage case + jz(.ZCOLSTORBZ1) // jump to column storage case - label(.SROWSTORBZ) + label(.ZROWSTORBZ1) + + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) * numofElements vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 8)) + vmovupd(ymm5, mem(rcx, rsi, 1)) add(rdi, rcx) vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 8)) + vmovupd(ymm9, mem(rcx, rsi, 1)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORBZ) + label(.ZCOLSTORBZ1) /****2x4 tile going to save into 4x2 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) /******************Transpose tile 2x4***************************/ vmovups(xmm4, mem(rcx)) @@ -512,7 +488,7 @@ void bli_zgemmsup_rv_zen_asm_2x4 vmovups(xmm5, mem(rcx)) vmovups(xmm9, mem(rcx, 16)) - label(.SDONE) + label(.ZDONE1) end_asm( : // output operands (none) @@ -525,7 +501,6 @@ void bli_zgemmsup_rv_zen_asm_2x4 [cs_a] "m" (cs_a), [b] "m" (b), [rs_b] "m" (rs_b), - [cs_b] "m" (cs_b), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -534,7 +509,7 @@ void bli_zgemmsup_rv_zen_asm_2x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -577,7 +552,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; @@ -593,8 +567,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) - //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - mov(var(rs_b), r10) // load rs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) @@ -618,8 +590,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 mov(var(m_iter), r11) // ii = m_iter; - label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - vzeroall() // zero all xmm/ymm registers. mov(var(b), rbx) // load address of b. @@ -627,31 +597,23 @@ void bli_zgemmsup_rv_zen_asm_1x4 mov(r14, rax) // reset rax to current upanel of a. cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used + jz(.ZCOLPFETCH1) // jump to column storage case + label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c + label(.ZCOLPFETCH1) // column-stored pre-fetching c mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + label(.ZPOSTPFETCH1) // done prefetching c mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that + je(.ZCONSIDKLEFT1) // if i == 0, jump to code that // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER1) // MAIN LOOP // ---------------------------------- iteration 0 @@ -707,7 +669,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -725,16 +686,16 @@ void bli_zgemmsup_rv_zen_asm_1x4 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. + jne(.ZLOOPKITER1) // iterate again if i != 0. - label(.SCONSIDKLEFT) + label(.ZCONSIDKLEFT1) mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end. 
// else, we prepare to enter k_left loop. - label(.SLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT1) // EDGE LOOP vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -751,9 +712,9 @@ void bli_zgemmsup_rv_zen_asm_1x4 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. + jne(.ZLOOPKLEFT1) // iterate again if i != 0. - label(.SPOSTACCUM) + label(.ZPOSTACCUM1) mov(r12, rcx) // reset rcx to current utile of c. @@ -787,10 +748,8 @@ void bli_zgemmsup_rv_zen_asm_1x4 vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. @@ -799,14 +758,14 @@ void bli_zgemmsup_rv_zen_asm_1x4 vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. - jz(.SCOLSTORED) // jump to column storage case + jz(.ZCOLSTORED1) // jump to column storage case - label(.SROWSTORED) + label(.ZROWSTORED1) + + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) * numofElements ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) @@ -816,21 +775,15 @@ void bli_zgemmsup_rv_zen_asm_1x4 vaddpd(ymm5, ymm0, ymm0) ZGEMM_OUTPUT_RS_NEXT - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORED) + label(.ZCOLSTORED1) /*|--------| |-------| | | | | | 1x4 | | 4x1 | |--------| |-------| */ - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a - ZGEMM_INPUT_SCALE_CS_BETA_NZ vaddpd(ymm4, ymm0, ymm4) @@ -842,9 +795,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 mov(r12, rcx) // reset rcx to current utile of c. /****1x4 tile going to save into 4x1 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) vmovups(xmm4, mem(rcx)) @@ -862,21 +812,21 @@ void bli_zgemmsup_rv_zen_asm_1x4 vextractf128(imm(0x1), ymm5, xmm5) vmovups(xmm5, mem(rcx)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. - label(.SBETAZERO) - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + label(.ZBETAZERO1) cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.SCOLSTORBZ) // jump to column storage case + jz(.ZCOLSTORBZ1) // jump to column storage case - label(.SROWSTORBZ) + label(.ZROWSTORBZ1) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) * numofElements vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 8)) + vmovupd(ymm5, mem(rcx, rsi, 1)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORBZ) + label(.ZCOLSTORBZ1) /****1x4 tile going to save into 4x1 tile in C*****/ mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) @@ -898,7 +848,7 @@ void bli_zgemmsup_rv_zen_asm_1x4 vextractf128(imm(0x1), ymm5, xmm5) vmovups(xmm5, mem(rcx)) - label(.SDONE) + label(.ZDONE1) end_asm( : // output operands (none) @@ -911,7 +861,6 @@ void bli_zgemmsup_rv_zen_asm_1x4 [cs_a] "m" (cs_a), [b] "m" (b), [rs_b] "m" (rs_b), - [cs_b] "m" (cs_b), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -920,7 +869,7 @@ void bli_zgemmsup_rv_zen_asm_1x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -962,7 +911,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; @@ -1001,41 +949,29 @@ void bli_zgemmsup_rv_zen_asm_2x2 mov(var(m_iter), r11) // ii = m_iter; - label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - vzeroall() // zero all xmm/ymm registers. mov(var(b), rbx) // load address of b. - //mov(r12, rcx) // reset rcx to current utile of c. mov(r14, rax) // reset rax to current upanel of a. cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used + jz(.ZCOLPFETCH1) // jump to column storage case + label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c + label(.ZCOLPFETCH1) // column-stored pre-fetching c mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + label(.ZPOSTPFETCH1) // done prefetching c mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that + je(.ZCONSIDKLEFT1) // if i == 0, jump to code that // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER1) // MAIN LOOP // ---------------------------------- iteration 0 @@ -1096,7 +1032,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1117,16 +1052,16 @@ void bli_zgemmsup_rv_zen_asm_2x2 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. + jne(.ZLOOPKITER1) // iterate again if i != 0. - label(.SCONSIDKLEFT) + label(.ZCONSIDKLEFT1) mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end. 
// else, we prepare to enter k_left loop. - label(.SLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT1) // EDGE LOOP vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1147,9 +1082,9 @@ void bli_zgemmsup_rv_zen_asm_2x2 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. + jne(.ZLOOPKLEFT1) // iterate again if i != 0. - label(.SPOSTACCUM) + label(.ZPOSTACCUM1) mov(r12, rcx) // reset rcx to current utile of c. @@ -1183,12 +1118,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -1196,13 +1125,12 @@ void bli_zgemmsup_rv_zen_asm_2x2 vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.SCOLSTORED) // jump to column storage case + jz(.ZCOLSTORED1) // jump to column storage case - label(.SROWSTORED) + label(.ZROWSTORED1) ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) @@ -1214,9 +1142,9 @@ void bli_zgemmsup_rv_zen_asm_2x2 vaddpd(ymm8, ymm0, ymm0) ZGEMM_OUTPUT_RS - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORED) + label(.ZCOLSTORED1) /*|--------| |-------| | | | | | 2x2 | | 2x2 | @@ -1227,8 +1155,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a - ZGEMM_INPUT_SCALE_CS_BETA_NZ vaddpd(ymm4, ymm0, ymm4) @@ -1239,9 +1165,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 mov(r12, rcx) // reset rcx to current utile of c. /****2x2 tile going to save into 2x2 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) vmovups(xmm4, mem(rcx)) vmovups(xmm8, mem(rcx, 16)) @@ -1253,23 +1176,23 @@ void bli_zgemmsup_rv_zen_asm_2x2 vmovups(xmm4, mem(rcx)) vmovups(xmm8, mem(rcx, 16)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. - label(.SBETAZERO) + label(.ZBETAZERO1) cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORBZ) // jump to column storage case + jz(.ZCOLSTORBZ1) // jump to column storage case - label(.SROWSTORBZ) + label(.ZROWSTORBZ1) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm8, mem(rcx)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORBZ) + label(.ZCOLSTORBZ1) /****2x2 tile going to save into 2x2 tile in C*****/ mov(var(cs_c), rsi) // load cs_c @@ -1286,7 +1209,7 @@ void bli_zgemmsup_rv_zen_asm_2x2 vmovups(xmm4, mem(rcx)) vmovups(xmm8, mem(rcx, 16)) - label(.SDONE) + label(.ZDONE1) end_asm( : // output operands (none) @@ -1299,7 +1222,6 @@ void bli_zgemmsup_rv_zen_asm_2x2 [cs_a] "m" (cs_a), [b] "m" (b), [rs_b] "m" (rs_b), - [cs_b] "m" (cs_b), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -1308,7 +1230,7 @@ void bli_zgemmsup_rv_zen_asm_2x2 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1350,7 +1272,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; @@ -1366,7 +1287,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) -// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a mov(var(rs_b), r10) // load rs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) @@ -1391,8 +1311,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 mov(var(m_iter), r11) // ii = m_iter; - label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - vzeroall() // zero all xmm/ymm registers. mov(var(b), rbx) // load address of b. @@ -1400,32 +1318,23 @@ void bli_zgemmsup_rv_zen_asm_1x2 mov(r14, rax) // reset rax to current upanel of a. cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used + jz(.ZCOLPFETCH1) // jump to column storage case + label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c + label(.ZCOLPFETCH1) // column-stored pre-fetching c mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + label(.ZPOSTPFETCH1) // done prefetching c mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that + je(.ZCONSIDKLEFT1) // if i == 0, jump to code that // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER1) // MAIN LOOP // ---------------------------------- iteration 0 @@ -1469,8 +1378,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1483,16 +1390,16 @@ void bli_zgemmsup_rv_zen_asm_1x2 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. + jne(.ZLOOPKITER1) // iterate again if i != 0. - label(.SCONSIDKLEFT) + label(.ZCONSIDKLEFT1) mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end. 
// else, we prepare to enter k_left loop. - label(.SLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT1) // EDGE LOOP vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1506,9 +1413,9 @@ void bli_zgemmsup_rv_zen_asm_1x2 add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. + jne(.ZLOOPKLEFT1) // iterate again if i != 0. - label(.SPOSTACCUM) + label(.ZPOSTACCUM1) mov(r12, rcx) // reset rcx to current utile of c. @@ -1534,12 +1441,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -1547,21 +1448,20 @@ void bli_zgemmsup_rv_zen_asm_1x2 vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.SCOLSTORED) // jump to column storage case + jz(.ZCOLSTORED1) // jump to column storage case - label(.SROWSTORED) + label(.ZROWSTORED1) ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) ZGEMM_OUTPUT_RS - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORED) + label(.ZCOLSTORED1) /*|--------| |-------| | | | | | 1x2 | | 2x1 | @@ -1572,17 +1472,11 @@ void bli_zgemmsup_rv_zen_asm_1x2 lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a - ZGEMM_INPUT_SCALE_CS_BETA_NZ vaddpd(ymm4, ymm0, ymm4) /****3x4 tile going to save into 4x3 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) - /******************Transpose tile 1x2***************************/ vmovups(xmm4, mem(rcx)) add(rsi, rcx) @@ -1591,22 +1485,22 @@ void bli_zgemmsup_rv_zen_asm_1x2 vmovups(xmm4, mem(rcx)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. - label(.SBETAZERO) + label(.ZBETAZERO1) cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORBZ) // jump to column storage case + jz(.ZCOLSTORBZ1) // jump to column storage case - label(.SROWSTORBZ) + label(.ZROWSTORBZ1) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - jmp(.SDONE) // jump to end. + jmp(.ZDONE1) // jump to end. 
- label(.SCOLSTORBZ) + label(.ZCOLSTORBZ1) /****1x2 tile going to save into 2x1 tile in C*****/ mov(var(cs_c), rsi) // load cs_c @@ -1622,7 +1516,7 @@ void bli_zgemmsup_rv_zen_asm_1x2 vextractf128(imm(0x1), ymm4, xmm4) vmovups(xmm4, mem(rcx)) - label(.SDONE) + label(.ZDONE1) end_asm( : // output operands (none) @@ -1635,7 +1529,6 @@ void bli_zgemmsup_rv_zen_asm_1x2 [cs_a] "m" (cs_a), [b] "m" (b), [rs_b] "m" (rs_b), - [cs_b] "m" (cs_b), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -1644,7 +1537,7 @@ void bli_zgemmsup_rv_zen_asm_1x2 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c index f5b42c623..1e9bacd9a 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -38,8 +38,9 @@ #define BLIS_ASM_SYNTAX_ATT #include "bli_x86_asm_macros.h" -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 +/* Assumes beta.r, beta.i have been broadcast into ymm1, ymm2. 
+ and store outputs to ymm0 + (creal,cimag)*(betar,beati) where c is stored in col major order*/ #define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ vmovupd(mem(rcx), xmm0) \ vmovupd(mem(rcx, rsi, 1), xmm3) \ @@ -49,6 +50,7 @@ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) +//(creal,cimag)*(betar,beati) where c is stored in row major order #define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ vmovupd(mem(rcx), ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ @@ -62,18 +64,22 @@ #define ZGEMM_OUTPUT_RS \ vmovupd(ymm0, mem(rcx)) \ +/*(cNextRowreal,cNextRowimag)*(betar,beati) + where c is stored in row major order + rsi = cs_c * sizeof((real +imag)dt)*numofElements + numofElements = 2, 2 elements are processed at a time*/ #define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ - vmovupd(mem(rcx, rsi, 8), ymm0) \ + vmovupd(mem(rcx, rsi, 1), ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) #define ZGEMM_INPUT_RS_BETA_ONE_NEXT \ - vmovupd(mem(rcx, rsi, 8), ymm0) + vmovupd(mem(rcx, rsi, 1), ymm0) #define ZGEMM_OUTPUT_RS_NEXT \ - vmovupd(ymm0, mem(rcx, rsi, 8)) + vmovupd(ymm0, mem(rcx, rsi, 1)) /* rrr: @@ -174,7 +180,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; @@ -227,8 +232,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt) lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt) - //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - mov(var(rs_b), r10) // load rs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt) lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt) @@ -252,7 +255,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m mov(var(m_iter), r11) // ii = m_iter; - label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] vzeroall() // zero all xmm/ymm registers. 
@@ -260,31 +263,22 @@ void bli_zgemmsup_rv_zen_asm_3x4m mov(r14, rax) // reset rax to current upanel of a. cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used + jz(.ZCOLPFETCH) // jump to column storage case + label(.ZROWPFETCH) // row-stored pre-fetching on c // not used - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c + label(.ZCOLPFETCH) // column-stored pre-fetching c mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + label(.ZPOSTPFETCH) // done prefetching c mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that + je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 @@ -383,8 +377,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -416,16 +408,16 @@ void bli_zgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. + jne(.ZLOOPKITER) // iterate again if i != 0. 
- label(.SCONSIDKLEFT) + label(.ZCONSIDKLEFT) mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - label(.SLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT) // EDGE LOOP vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) @@ -458,9 +450,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. + jne(.ZLOOPKLEFT) // iterate again if i != 0. - label(.SPOSTACCUM) + label(.ZPOSTACCUM) mov(r12, rcx) // reset rcx to current utile of c. @@ -483,6 +475,10 @@ void bli_zgemmsup_rv_zen_asm_3x4m vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) mov(var(alpha_mul_type), al) cmp(imm(0xFF), al) @@ -541,19 +537,18 @@ void bli_zgemmsup_rv_zen_asm_3x4m label(.ALPHA_REAL_ONE) // Beta multiplication /* (br + bi)x C + ((ar + ai) x AB) */ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) mov(var(beta_mul_type), al) cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) - je(.SBETAZERO) //jump to beta == 0 case - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + je(.ZBETAZERO) //jump to beta == 0 case cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. - jz(.SCOLSTORED) // jump to column storage case + jz(.ZCOLSTORED) // jump to column storage case + + label(.ZROWSTORED) + + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements - label(.SROWSTORED) cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. 
@@ -586,7 +581,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m ZGEMM_INPUT_RS_BETA_ONE_NEXT vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_RS_NEXT - jmp(.SDONE) + jmp(.ZDONE) //CASE 2: beta is real = -1 @@ -616,7 +611,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m ZGEMM_INPUT_RS_BETA_ONE_NEXT vsubpd(ymm0, ymm13, ymm0) ZGEMM_OUTPUT_RS_NEXT - jmp(.SDONE) + jmp(.ZDONE) //CASE 3: Default case with multiplication @@ -651,9 +646,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_RS_NEXT - jmp(.SDONE) // jump to end. + jmp(.ZDONE) // jump to end. - label(.SCOLSTORED) + label(.ZCOLSTORED) mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate @@ -663,11 +658,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m |--------| |-------| */ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a - ZGEMM_INPUT_SCALE_CS_BETA_NZ vaddpd(ymm4, ymm0, ymm4) @@ -696,9 +686,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m /****3x4 tile going to save into 4x3 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) /******************Transpose top tile 4x3***************************/ vmovups(xmm4, mem(rcx)) @@ -729,34 +716,34 @@ void bli_zgemmsup_rv_zen_asm_3x4m vmovups(xmm9, mem(rcx, 16)) vmovups(xmm13,mem(rcx,32)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE) // jump to end. - label(.SBETAZERO) - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + label(.ZBETAZERO) cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.SCOLSTORBZ) // jump to column storage case
+ jz(.ZCOLSTORBZ) // jump to column storage case
 
- label(.SROWSTORBZ)
+ label(.ZROWSTORBZ)
+ /* Store 3x4 elements to C matrix where C is stored in row major order*/
+
+ // rsi = cs_c * sizeof((real +imag)dt) *numofElements
+ lea(mem(, rsi, 2), rsi)
 
 vmovupd(ymm4, mem(rcx))
- vmovupd(ymm5, mem(rcx, rsi, 8))
+ vmovupd(ymm5, mem(rcx, rsi, 1))
 add(rdi, rcx)
 
 vmovupd(ymm8, mem(rcx))
- vmovupd(ymm9, mem(rcx, rsi, 8))
+ vmovupd(ymm9, mem(rcx, rsi, 1))
 add(rdi, rcx)
 
 vmovupd(ymm12, mem(rcx))
- vmovupd(ymm13, mem(rcx, rsi, 8))
+ vmovupd(ymm13, mem(rcx, rsi, 1))
 
- jmp(.SDONE) // jump to end.
+ jmp(.ZDONE) // jump to end.
 
- label(.SCOLSTORBZ)
+ label(.ZCOLSTORBZ)
 
 /****3x4 tile going to save into 4x3 tile in C*****/
- mov(var(cs_c), rsi) // load cs_c
- lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
- lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
 
 /******************Transpose top tile 4x3***************************/
 vmovups(xmm4, mem(rcx))
@@ -787,7 +774,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
 vmovups(xmm9, mem(rcx, 16))
 vmovups(xmm13,mem(rcx,32))
 
- label(.SDONE)
+ label(.ZDONE)
 
 lea(mem(r12, rdi, 2), r12)
 lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c
@@ -796,9 +783,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
 lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a
 
 dec(r11) // ii -= 1;
- jne(.SLOOP3X8I) // iterate again if ii != 0.
+ jne(.ZLOOP3X4I) // iterate again if ii != 0. 
- label(.SRETURN) + label(.ZRETURN) end_asm( : // output operands (none) @@ -813,7 +800,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m [cs_a] "m" (cs_a), [b] "m" (b), [rs_b] "m" (rs_b), - [cs_b] "m" (cs_b), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -822,8 +808,8 @@ void bli_zgemmsup_rv_zen_asm_3x4m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "rax", "rbx", "rcx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", @@ -896,7 +882,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; @@ -914,8 +899,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) -// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - mov(var(rs_b), r10) // load rs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) @@ -939,41 +922,31 @@ void bli_zgemmsup_rv_zen_asm_3x2m mov(var(m_iter), r11) // ii = m_iter; - label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] vzeroall() // zero all xmm/ymm registers. mov(var(b), rbx) // load address of b. - //mov(r12, rcx) // reset rcx to current utile of c. mov(r14, rax) // reset rax to current upanel of a. cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. 
- jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used + jz(.ZCOLPFETCH) // jump to column storage case + label(.ZROWPFETCH) // row-stored pre-fetching on c // not used - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c + jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c + label(.ZCOLPFETCH) // column-stored pre-fetching c mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + label(.ZPOSTPFETCH) // done prefetching c mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that + je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP + label(.ZLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 @@ -1051,8 +1024,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; // ---------------------------------- iteration 3 - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1077,16 +1048,16 @@ void bli_zgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. + jne(.ZLOOPKITER) // iterate again if i != 0. - label(.SCONSIDKLEFT) + label(.ZCONSIDKLEFT) mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. 
// else, we prepare to enter k_left loop. - label(.SLOOPKLEFT) // EDGE LOOP + label(.ZLOOPKLEFT) // EDGE LOOP vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1112,9 +1083,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m add(r9, rax) // a += cs_a; dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. + jne(.ZLOOPKLEFT) // iterate again if i != 0. - label(.SPOSTACCUM) + label(.ZPOSTACCUM) mov(r12, rcx) // reset rcx to current utile of c. @@ -1126,9 +1097,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) - vaddsubpd(ymm10, ymm8, ymm8) - vaddsubpd(ymm14, ymm12, ymm12) /* (ar + ai) x AB */ @@ -1156,12 +1125,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -1169,13 +1132,12 @@ void bli_zgemmsup_rv_zen_asm_3x2m vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.SCOLSTORED) // jump to column storage case + jz(.ZCOLSTORED) // jump to column storage case - label(.SROWSTORED) + label(.ZROWSTORED) ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) @@ -1193,9 +1155,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m vaddpd(ymm12, ymm0, ymm0) ZGEMM_OUTPUT_RS - jmp(.SDONE) // jump to end. + jmp(.ZDONE) // jump to end. 
- label(.SCOLSTORED) + label(.ZCOLSTORED) /*|--------| |-------| | | | | | 3x2 | | 2x3 | @@ -1206,8 +1168,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a - ZGEMM_INPUT_SCALE_CS_BETA_NZ vaddpd(ymm4, ymm0, ymm4) @@ -1222,9 +1182,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m mov(r12, rcx) // reset rcx to current utile of c. /****3x2 tile going to save into 2x3 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) /******************Transpose top tile 2x3***************************/ vmovups(xmm4, mem(rcx)) @@ -1241,14 +1198,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m vmovups(xmm12, mem(rcx,32)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE) // jump to end. - label(.SBETAZERO) + label(.ZBETAZERO) cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORBZ) // jump to column storage case + jz(.ZCOLSTORBZ) // jump to column storage case - label(.SROWSTORBZ) + label(.ZROWSTORBZ) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) @@ -1258,14 +1215,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m vmovupd(ymm12, mem(rcx)) - jmp(.SDONE) // jump to end. + jmp(.ZDONE) // jump to end. 
- label(.SCOLSTORBZ) + label(.ZCOLSTORBZ) /****3x2 tile going to save into 2x3 tile in C*****/ mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) /******************Transpose tile 3x2***************************/ vmovups(xmm4, mem(rcx)) @@ -1281,7 +1238,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m vmovups(xmm8, mem(rcx, 16)) vmovups(xmm12, mem(rcx,32)) - label(.SDONE) + label(.ZDONE) lea(mem(r12, rdi, 2), r12) lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c @@ -1290,9 +1247,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a dec(r11) // ii -= 1; - jne(.SLOOP3X8I) // iterate again if ii != 0. + jne(.ZLOOP3X2I) // iterate again if ii != 0. - label(.SRETURN) + label(.ZRETURN) end_asm( : // output operands (none) @@ -1305,7 +1262,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m [cs_a] "m" (cs_a), [b] "m" (b), [rs_b] "m" (rs_b), - [cs_b] "m" (cs_b), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -1314,7 +1270,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c index 072f5262c..44b43e741 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -767,7 +767,7 @@ void bli_zgemmsup_rv_zen_asm_2x4n ymm2 = _mm256_loadu_pd((double const *)(tC)); ymm3 = _mm256_permute_pd(ymm2, 5); ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm3 = _mm256_mul_pd(ymm1, ymm3); ymm4 = _mm256_add_pd(ymm4, 
_mm256_addsub_pd(ymm2, ymm3)); ymm2 = _mm256_loadu_pd((double const *)(tC+2));