mirror of
https://github.com/amd/blis.git
synced 2026-05-12 01:59:59 +00:00
Merge "ZGEMM SUP: Removed unused assembly intructions" into amd-staging-milan-3.1
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2020-2021, Advanced Micro Devices, Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -1281,6 +1281,9 @@ void bli_cgemmsup_rv_zen_asm_3x4
|
||||
ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate
|
||||
ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate
|
||||
|
||||
xmm0 = _mm_setzero_ps();
|
||||
xmm3 = _mm_setzero_ps();
|
||||
|
||||
//Multiply ymm4 with beta
|
||||
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ;
|
||||
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col));
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2020-2021, Advanced Micro Devices, Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -38,8 +38,9 @@
|
||||
#define BLIS_ASM_SYNTAX_ATT
|
||||
#include "bli_x86_asm_macros.h"
|
||||
|
||||
// assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
|
||||
// outputs to ymm0
|
||||
/* Assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
|
||||
and store outputs to ymm0
|
||||
(creal,cimag)*(betar,beati) where c is stored in col major order*/
|
||||
#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \
|
||||
vmovupd(mem(rcx), xmm0) \
|
||||
vmovupd(mem(rcx, rsi, 1), xmm3) \
|
||||
@@ -49,6 +50,7 @@
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
//(creal,cimag)*(betar,beati) where c is stored in row major order
|
||||
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \
|
||||
vmovupd(mem(rcx), ymm0) \
|
||||
vpermilpd(imm(0x5), ymm0, ymm3) \
|
||||
@@ -59,15 +61,19 @@
|
||||
#define ZGEMM_OUTPUT_RS \
|
||||
vmovupd(ymm0, mem(rcx)) \
|
||||
|
||||
/*(cNextRowreal,cNextRowimag)*(betar,beati)
|
||||
where c is stored in row major order
|
||||
rsi = cs_c * sizeof((real +imag)dt)*numofElements
|
||||
numofElements = 2, 2 elements are processed at a time*/
|
||||
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \
|
||||
vmovupd(mem(rcx, rsi, 8), ymm0) \
|
||||
vmovupd(mem(rcx, rsi, 1), ymm0) \
|
||||
vpermilpd(imm(0x5), ymm0, ymm3) \
|
||||
vmulpd(ymm1, ymm0, ymm0) \
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
#define ZGEMM_OUTPUT_RS_NEXT \
|
||||
vmovupd(ymm0, mem(rcx, rsi, 8)) \
|
||||
vmovupd(ymm0, mem(rcx, rsi, 1)) \
|
||||
|
||||
|
||||
void bli_zgemmsup_rv_zen_asm_2x4
|
||||
@@ -100,7 +106,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
uint64_t rs_a = rs_a0;
|
||||
uint64_t cs_a = cs_a0;
|
||||
uint64_t rs_b = rs_b0;
|
||||
uint64_t cs_b = cs_b0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
@@ -116,8 +121,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
|
||||
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
|
||||
|
||||
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
mov(var(rs_b), r10) // load rs_b
|
||||
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
|
||||
lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt)
|
||||
@@ -141,40 +144,29 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
|
||||
mov(var(m_iter), r11) // ii = m_iter;
|
||||
|
||||
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
mov(var(b), rbx) // load address of b.
|
||||
//mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
mov(r14, rax) // reset rax to current upanel of a.
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLPFETCH) // jump to column storage case
|
||||
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
jz(.ZCOLPFETCH1) // jump to column storage case
|
||||
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
|
||||
|
||||
lea(mem(r12, rdi, 2), rdx) //
|
||||
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
|
||||
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.SCOLPFETCH) // column-stored pre-fetching c
|
||||
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
|
||||
label(.ZCOLPFETCH1) // column-stored pre-fetching c
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
|
||||
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
|
||||
lea(mem(r12, rsi, 2), rdx) //
|
||||
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
|
||||
|
||||
label(.SPOSTPFETCH) // done prefetching c
|
||||
|
||||
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
|
||||
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
|
||||
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
|
||||
label(.ZPOSTPFETCH1) // done prefetching c
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
||||
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
label(.SLOOPKITER) // MAIN LOOP
|
||||
label(.ZLOOPKITER1) // MAIN LOOP
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
@@ -251,8 +243,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
// ---------------------------------- iteration 3
|
||||
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
vmovupd(mem(rbx, 1*32), ymm1)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
@@ -277,16 +267,16 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKITER) // iterate again if i != 0.
|
||||
jne(.ZLOOPKITER1) // iterate again if i != 0.
|
||||
|
||||
label(.SCONSIDKLEFT)
|
||||
label(.ZCONSIDKLEFT1)
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
label(.SLOOPKLEFT) // EDGE LOOP
|
||||
label(.ZLOOPKLEFT1) // EDGE LOOP
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
vmovupd(mem(rbx, 1*32), ymm1)
|
||||
@@ -312,9 +302,9 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
||||
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
|
||||
|
||||
label(.SPOSTACCUM)
|
||||
label(.ZPOSTACCUM1)
|
||||
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
@@ -363,10 +353,8 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
|
||||
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
@@ -375,14 +363,13 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r13b, r15b) // set ZF if r13b & r15b == 1.
|
||||
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
jz(.ZCOLSTORED1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORED)
|
||||
label(.ZROWSTORED1)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) * numofElements
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
@@ -401,21 +388,15 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
vaddpd(ymm9, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
label(.ZCOLSTORED1)
|
||||
/*|--------| |-------|
|
||||
| | | |
|
||||
| 2x4 | | 4x2 |
|
||||
|--------| |-------|
|
||||
*/
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
ZGEMM_INPUT_SCALE_CS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm4)
|
||||
|
||||
@@ -438,9 +419,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
|
||||
|
||||
/****3x4 tile going to save into 4x2 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
/******************Transpose top tile 4x3***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
@@ -465,29 +443,27 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
vmovups(xmm5, mem(rcx))
|
||||
vmovups(xmm9, mem(rcx, 16))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SBETAZERO)
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
label(.ZBETAZERO1)
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLSTORBZ) // jump to column storage case
|
||||
jz(.ZCOLSTORBZ1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORBZ)
|
||||
label(.ZROWSTORBZ1)
|
||||
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) * numofElements
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
vmovupd(ymm5, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm5, mem(rcx, rsi, 1))
|
||||
add(rdi, rcx)
|
||||
|
||||
vmovupd(ymm8, mem(rcx))
|
||||
vmovupd(ymm9, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm9, mem(rcx, rsi, 1))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORBZ)
|
||||
label(.ZCOLSTORBZ1)
|
||||
/****2x4 tile going to save into 4x2 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
/******************Transpose tile 2x4***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
@@ -512,7 +488,7 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
vmovups(xmm5, mem(rcx))
|
||||
vmovups(xmm9, mem(rcx, 16))
|
||||
|
||||
label(.SDONE)
|
||||
label(.ZDONE1)
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -525,7 +501,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
[cs_a] "m" (cs_a),
|
||||
[b] "m" (b),
|
||||
[rs_b] "m" (rs_b),
|
||||
[cs_b] "m" (cs_b),
|
||||
[alpha] "m" (alpha),
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
@@ -534,7 +509,7 @@ void bli_zgemmsup_rv_zen_asm_2x4
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -577,7 +552,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
uint64_t rs_a = rs_a0;
|
||||
uint64_t cs_a = cs_a0;
|
||||
uint64_t rs_b = rs_b0;
|
||||
uint64_t cs_b = cs_b0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
@@ -593,8 +567,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
|
||||
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
|
||||
|
||||
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
mov(var(rs_b), r10) // load rs_b
|
||||
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
|
||||
lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt)
|
||||
@@ -618,8 +590,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
|
||||
mov(var(m_iter), r11) // ii = m_iter;
|
||||
|
||||
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
mov(var(b), rbx) // load address of b.
|
||||
@@ -627,31 +597,23 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
mov(r14, rax) // reset rax to current upanel of a.
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLPFETCH) // jump to column storage case
|
||||
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
jz(.ZCOLPFETCH1) // jump to column storage case
|
||||
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
|
||||
|
||||
lea(mem(r12, rdi, 2), rdx) //
|
||||
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
|
||||
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.SCOLPFETCH) // column-stored pre-fetching c
|
||||
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
|
||||
label(.ZCOLPFETCH1) // column-stored pre-fetching c
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
|
||||
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
|
||||
lea(mem(r12, rsi, 2), rdx) //
|
||||
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
|
||||
|
||||
label(.SPOSTPFETCH) // done prefetching c
|
||||
|
||||
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
|
||||
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
|
||||
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
|
||||
label(.ZPOSTPFETCH1) // done prefetching c
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
||||
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
label(.SLOOPKITER) // MAIN LOOP
|
||||
label(.ZLOOPKITER1) // MAIN LOOP
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
@@ -707,7 +669,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
// ---------------------------------- iteration 3
|
||||
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
vmovupd(mem(rbx, 1*32), ymm1)
|
||||
@@ -725,16 +686,16 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKITER) // iterate again if i != 0.
|
||||
jne(.ZLOOPKITER1) // iterate again if i != 0.
|
||||
|
||||
label(.SCONSIDKLEFT)
|
||||
label(.ZCONSIDKLEFT1)
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
label(.SLOOPKLEFT) // EDGE LOOP
|
||||
label(.ZLOOPKLEFT1) // EDGE LOOP
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
vmovupd(mem(rbx, 1*32), ymm1)
|
||||
@@ -751,9 +712,9 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
||||
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
|
||||
|
||||
label(.SPOSTACCUM)
|
||||
label(.ZPOSTACCUM1)
|
||||
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
@@ -787,10 +748,8 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
|
||||
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
@@ -799,14 +758,14 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r13b, r15b) // set ZF if r13b & r15b == 1.
|
||||
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
jz(.ZCOLSTORED1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORED)
|
||||
label(.ZROWSTORED1)
|
||||
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) * numofElements
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
@@ -816,21 +775,15 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
vaddpd(ymm5, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
label(.ZCOLSTORED1)
|
||||
/*|--------| |-------|
|
||||
| | | |
|
||||
| 1x4 | | 4x1 |
|
||||
|--------| |-------|
|
||||
*/
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
ZGEMM_INPUT_SCALE_CS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm4)
|
||||
|
||||
@@ -842,9 +795,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
/****1x4 tile going to save into 4x1 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
vmovups(xmm4, mem(rcx))
|
||||
|
||||
@@ -862,21 +812,21 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
vextractf128(imm(0x1), ymm5, xmm5)
|
||||
vmovups(xmm5, mem(rcx))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SBETAZERO)
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
label(.ZBETAZERO1)
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLSTORBZ) // jump to column storage case
|
||||
jz(.ZCOLSTORBZ1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORBZ)
|
||||
label(.ZROWSTORBZ1)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) * numofElements
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
vmovupd(ymm5, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm5, mem(rcx, rsi, 1))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORBZ)
|
||||
label(.ZCOLSTORBZ1)
|
||||
/****1x4 tile going to save into 4x1 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
@@ -898,7 +848,7 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
vextractf128(imm(0x1), ymm5, xmm5)
|
||||
vmovups(xmm5, mem(rcx))
|
||||
|
||||
label(.SDONE)
|
||||
label(.ZDONE1)
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -911,7 +861,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
[cs_a] "m" (cs_a),
|
||||
[b] "m" (b),
|
||||
[rs_b] "m" (rs_b),
|
||||
[cs_b] "m" (cs_b),
|
||||
[alpha] "m" (alpha),
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
@@ -920,7 +869,7 @@ void bli_zgemmsup_rv_zen_asm_1x4
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -962,7 +911,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
uint64_t rs_a = rs_a0;
|
||||
uint64_t cs_a = cs_a0;
|
||||
uint64_t rs_b = rs_b0;
|
||||
uint64_t cs_b = cs_b0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
@@ -1001,41 +949,29 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
|
||||
mov(var(m_iter), r11) // ii = m_iter;
|
||||
|
||||
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
mov(var(b), rbx) // load address of b.
|
||||
//mov(r12, rcx) // reset rcx to current utile of c.
|
||||
mov(r14, rax) // reset rax to current upanel of a.
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLPFETCH) // jump to column storage case
|
||||
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
jz(.ZCOLPFETCH1) // jump to column storage case
|
||||
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
|
||||
|
||||
lea(mem(r12, rdi, 2), rdx) //
|
||||
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
|
||||
|
||||
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.SCOLPFETCH) // column-stored pre-fetching c
|
||||
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
|
||||
label(.ZCOLPFETCH1) // column-stored pre-fetching c
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
|
||||
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
|
||||
lea(mem(r12, rsi, 2), rdx) //
|
||||
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
|
||||
|
||||
label(.SPOSTPFETCH) // done prefetching c
|
||||
|
||||
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
|
||||
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
|
||||
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
|
||||
label(.ZPOSTPFETCH1) // done prefetching c
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
||||
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
label(.SLOOPKITER) // MAIN LOOP
|
||||
label(.ZLOOPKITER1) // MAIN LOOP
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
@@ -1096,7 +1032,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
// ---------------------------------- iteration 3
|
||||
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
@@ -1117,16 +1052,16 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKITER) // iterate again if i != 0.
|
||||
jne(.ZLOOPKITER1) // iterate again if i != 0.
|
||||
|
||||
label(.SCONSIDKLEFT)
|
||||
label(.ZCONSIDKLEFT1)
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
label(.SLOOPKLEFT) // EDGE LOOP
|
||||
label(.ZLOOPKLEFT1) // EDGE LOOP
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
@@ -1147,9 +1082,9 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
||||
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
|
||||
|
||||
label(.SPOSTACCUM)
|
||||
label(.ZPOSTACCUM1)
|
||||
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
@@ -1183,12 +1118,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
|
||||
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
|
||||
@@ -1196,13 +1125,12 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r13b, r15b) // set ZF if r13b & r15b == 1.
|
||||
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
jz(.ZCOLSTORED1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORED)
|
||||
label(.ZROWSTORED1)
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
@@ -1214,9 +1142,9 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
vaddpd(ymm8, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
label(.ZCOLSTORED1)
|
||||
/*|--------| |-------|
|
||||
| | | |
|
||||
| 2x2 | | 2x2 |
|
||||
@@ -1227,8 +1155,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
ZGEMM_INPUT_SCALE_CS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm4)
|
||||
|
||||
@@ -1239,9 +1165,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
/****2x2 tile going to save into 2x2 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
vmovups(xmm4, mem(rcx))
|
||||
vmovups(xmm8, mem(rcx, 16))
|
||||
@@ -1253,23 +1176,23 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
vmovups(xmm4, mem(rcx))
|
||||
vmovups(xmm8, mem(rcx, 16))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SBETAZERO)
|
||||
label(.ZBETAZERO1)
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8.
|
||||
jz(.SCOLSTORBZ) // jump to column storage case
|
||||
jz(.ZCOLSTORBZ1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORBZ)
|
||||
label(.ZROWSTORBZ1)
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
|
||||
vmovupd(ymm8, mem(rcx))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORBZ)
|
||||
label(.ZCOLSTORBZ1)
|
||||
|
||||
/****2x2 tile going to save into 2x2 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
@@ -1286,7 +1209,7 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
vmovups(xmm4, mem(rcx))
|
||||
vmovups(xmm8, mem(rcx, 16))
|
||||
|
||||
label(.SDONE)
|
||||
label(.ZDONE1)
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -1299,7 +1222,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
[cs_a] "m" (cs_a),
|
||||
[b] "m" (b),
|
||||
[rs_b] "m" (rs_b),
|
||||
[cs_b] "m" (cs_b),
|
||||
[alpha] "m" (alpha),
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
@@ -1308,7 +1230,7 @@ void bli_zgemmsup_rv_zen_asm_2x2
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
@@ -1350,7 +1272,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
uint64_t rs_a = rs_a0;
|
||||
uint64_t cs_a = cs_a0;
|
||||
uint64_t rs_b = rs_b0;
|
||||
uint64_t cs_b = cs_b0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
@@ -1366,7 +1287,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
|
||||
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
|
||||
|
||||
// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
mov(var(rs_b), r10) // load rs_b
|
||||
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
|
||||
@@ -1391,8 +1311,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
|
||||
mov(var(m_iter), r11) // ii = m_iter;
|
||||
|
||||
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
mov(var(b), rbx) // load address of b.
|
||||
@@ -1400,32 +1318,23 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
mov(r14, rax) // reset rax to current upanel of a.
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLPFETCH) // jump to column storage case
|
||||
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
jz(.ZCOLPFETCH1) // jump to column storage case
|
||||
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
|
||||
|
||||
lea(mem(r12, rdi, 2), rdx) //
|
||||
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
|
||||
|
||||
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.SCOLPFETCH) // column-stored pre-fetching c
|
||||
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
|
||||
label(.ZCOLPFETCH1) // column-stored pre-fetching c
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
|
||||
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
|
||||
lea(mem(r12, rsi, 2), rdx) //
|
||||
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
|
||||
|
||||
label(.SPOSTPFETCH) // done prefetching c
|
||||
|
||||
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
|
||||
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
|
||||
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
|
||||
label(.ZPOSTPFETCH1) // done prefetching c
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
||||
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
label(.SLOOPKITER) // MAIN LOOP
|
||||
label(.ZLOOPKITER1) // MAIN LOOP
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
@@ -1469,8 +1378,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
// ---------------------------------- iteration 3
|
||||
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
|
||||
@@ -1483,16 +1390,16 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKITER) // iterate again if i != 0.
|
||||
jne(.ZLOOPKITER1) // iterate again if i != 0.
|
||||
|
||||
label(.SCONSIDKLEFT)
|
||||
label(.ZCONSIDKLEFT1)
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
label(.SLOOPKLEFT) // EDGE LOOP
|
||||
label(.ZLOOPKLEFT1) // EDGE LOOP
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
@@ -1506,9 +1413,9 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
||||
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
|
||||
|
||||
label(.SPOSTACCUM)
|
||||
label(.ZPOSTACCUM1)
|
||||
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
@@ -1534,12 +1441,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
|
||||
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
|
||||
@@ -1547,21 +1448,20 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r13b, r15b) // set ZF if r13b & r15b == 1.
|
||||
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
jz(.ZCOLSTORED1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORED)
|
||||
label(.ZROWSTORED1)
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
label(.ZCOLSTORED1)
|
||||
/*|--------| |-------|
|
||||
| | | |
|
||||
| 1x2 | | 2x1 |
|
||||
@@ -1572,17 +1472,11 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
ZGEMM_INPUT_SCALE_CS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm4)
|
||||
|
||||
|
||||
/****3x4 tile going to save into 4x3 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
/******************Transpose tile 1x2***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
add(rsi, rcx)
|
||||
@@ -1591,22 +1485,22 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
vmovups(xmm4, mem(rcx))
|
||||
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SBETAZERO)
|
||||
label(.ZBETAZERO1)
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8.
|
||||
jz(.SCOLSTORBZ) // jump to column storage case
|
||||
jz(.ZCOLSTORBZ1) // jump to column storage case
|
||||
|
||||
label(.SROWSTORBZ)
|
||||
label(.ZROWSTORBZ1)
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE1) // jump to end.
|
||||
|
||||
label(.SCOLSTORBZ)
|
||||
label(.ZCOLSTORBZ1)
|
||||
|
||||
/****1x2 tile going to save into 2x1 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
@@ -1622,7 +1516,7 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
vextractf128(imm(0x1), ymm4, xmm4)
|
||||
vmovups(xmm4, mem(rcx))
|
||||
|
||||
label(.SDONE)
|
||||
label(.ZDONE1)
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -1635,7 +1529,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
[cs_a] "m" (cs_a),
|
||||
[b] "m" (b),
|
||||
[rs_b] "m" (rs_b),
|
||||
[cs_b] "m" (cs_b),
|
||||
[alpha] "m" (alpha),
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
@@ -1644,7 +1537,7 @@ void bli_zgemmsup_rv_zen_asm_1x2
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
|
||||
@@ -38,8 +38,9 @@
|
||||
#define BLIS_ASM_SYNTAX_ATT
|
||||
#include "bli_x86_asm_macros.h"
|
||||
|
||||
// assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
|
||||
// outputs to ymm0
|
||||
/* Assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
|
||||
and store outputs to ymm0
|
||||
(creal,cimag)*(betar,beati) where c is stored in col major order*/
|
||||
#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \
|
||||
vmovupd(mem(rcx), xmm0) \
|
||||
vmovupd(mem(rcx, rsi, 1), xmm3) \
|
||||
@@ -49,6 +50,7 @@
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
//(creal,cimag)*(betar,beati) where c is stored in row major order
|
||||
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \
|
||||
vmovupd(mem(rcx), ymm0) \
|
||||
vpermilpd(imm(0x5), ymm0, ymm3) \
|
||||
@@ -62,18 +64,22 @@
|
||||
#define ZGEMM_OUTPUT_RS \
|
||||
vmovupd(ymm0, mem(rcx)) \
|
||||
|
||||
/*(cNextRowreal,cNextRowimag)*(betar,beati)
|
||||
where c is stored in row major order
|
||||
rsi = cs_c * sizeof((real +imag)dt)*numofElements
|
||||
numofElements = 2, 2 elements are processed at a time*/
|
||||
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \
|
||||
vmovupd(mem(rcx, rsi, 8), ymm0) \
|
||||
vmovupd(mem(rcx, rsi, 1), ymm0) \
|
||||
vpermilpd(imm(0x5), ymm0, ymm3) \
|
||||
vmulpd(ymm1, ymm0, ymm0) \
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
#define ZGEMM_INPUT_RS_BETA_ONE_NEXT \
|
||||
vmovupd(mem(rcx, rsi, 8), ymm0)
|
||||
vmovupd(mem(rcx, rsi, 1), ymm0)
|
||||
|
||||
#define ZGEMM_OUTPUT_RS_NEXT \
|
||||
vmovupd(ymm0, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm0, mem(rcx, rsi, 1))
|
||||
|
||||
/*
|
||||
rrr:
|
||||
@@ -174,7 +180,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
uint64_t rs_a = rs_a0;
|
||||
uint64_t cs_a = cs_a0;
|
||||
uint64_t rs_b = rs_b0;
|
||||
uint64_t cs_b = cs_b0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
@@ -227,8 +232,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt)
|
||||
lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt)
|
||||
|
||||
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
mov(var(rs_b), r10) // load rs_b
|
||||
lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt)
|
||||
lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt)
|
||||
@@ -252,7 +255,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
|
||||
mov(var(m_iter), r11) // ii = m_iter;
|
||||
|
||||
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
@@ -260,31 +263,22 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
mov(r14, rax) // reset rax to current upanel of a.
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLPFETCH) // jump to column storage case
|
||||
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
jz(.ZCOLPFETCH) // jump to column storage case
|
||||
label(.ZROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
|
||||
lea(mem(r12, rdi, 2), rdx) //
|
||||
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
|
||||
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.SCOLPFETCH) // column-stored pre-fetching c
|
||||
jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.ZCOLPFETCH) // column-stored pre-fetching c
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
|
||||
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
|
||||
lea(mem(r12, rsi, 2), rdx) //
|
||||
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
|
||||
|
||||
label(.SPOSTPFETCH) // done prefetching c
|
||||
|
||||
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
|
||||
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
|
||||
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
|
||||
label(.ZPOSTPFETCH) // done prefetching c
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
||||
je(.ZCONSIDKLEFT) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
label(.SLOOPKITER) // MAIN LOOP
|
||||
label(.ZLOOPKITER) // MAIN LOOP
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
@@ -383,8 +377,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
// ---------------------------------- iteration 3
|
||||
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
vmovupd(mem(rbx, 1*32), ymm1)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
@@ -416,16 +408,16 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKITER) // iterate again if i != 0.
|
||||
jne(.ZLOOPKITER) // iterate again if i != 0.
|
||||
|
||||
label(.SCONSIDKLEFT)
|
||||
label(.ZCONSIDKLEFT)
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
label(.SLOOPKLEFT) // EDGE LOOP
|
||||
label(.ZLOOPKLEFT) // EDGE LOOP
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
vmovupd(mem(rbx, 1*32), ymm1)
|
||||
@@ -458,9 +450,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
||||
jne(.ZLOOPKLEFT) // iterate again if i != 0.
|
||||
|
||||
label(.SPOSTACCUM)
|
||||
label(.ZPOSTACCUM)
|
||||
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
@@ -483,6 +475,10 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
vaddsubpd(ymm14, ymm12, ymm12)
|
||||
vaddsubpd(ymm15, ymm13, ymm13)
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt)
|
||||
|
||||
//if(alpha_mul_type == BLIS_MUL_MINUS_ONE)
|
||||
mov(var(alpha_mul_type), al)
|
||||
cmp(imm(0xFF), al)
|
||||
@@ -541,19 +537,18 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
label(.ALPHA_REAL_ONE)
|
||||
// Beta multiplication
|
||||
/* (br + bi)x C + ((ar + ai) x AB) */
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
mov(var(beta_mul_type), al)
|
||||
cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO)
|
||||
je(.SBETAZERO) //jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
je(.ZBETAZERO) //jump to beta == 0 case
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
jz(.ZCOLSTORED) // jump to column storage case
|
||||
|
||||
label(.ZROWSTORED)
|
||||
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements
|
||||
|
||||
label(.SROWSTORED)
|
||||
cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT)
|
||||
je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication.
|
||||
|
||||
@@ -586,7 +581,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
jmp(.SDONE)
|
||||
jmp(.ZDONE)
|
||||
|
||||
|
||||
//CASE 2: beta is real = -1
|
||||
@@ -616,7 +611,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vsubpd(ymm0, ymm13, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
jmp(.SDONE)
|
||||
jmp(.ZDONE)
|
||||
|
||||
|
||||
//CASE 3: Default case with multiplication
|
||||
@@ -651,9 +646,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
label(.ZCOLSTORED)
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
@@ -663,11 +658,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
|--------| |-------|
|
||||
*/
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt)
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
ZGEMM_INPUT_SCALE_CS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm4)
|
||||
|
||||
@@ -696,9 +686,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
|
||||
|
||||
/****3x4 tile going to save into 4x3 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt)
|
||||
|
||||
/******************Transpose top tile 4x3***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
@@ -729,34 +716,34 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
vmovups(xmm9, mem(rcx, 16))
|
||||
vmovups(xmm13,mem(rcx,32))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
label(.SBETAZERO)
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
label(.ZBETAZERO)
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLSTORBZ) // jump to column storage case
|
||||
jz(.ZCOLSTORBZ) // jump to column storage case
|
||||
|
||||
label(.SROWSTORBZ)
|
||||
label(.ZROWSTORBZ)
|
||||
/* Store 3x4 elements to C matrix where is C row major order*/
|
||||
|
||||
// rsi = cs_c * sizeof((real +imag)dt) *numofElements
|
||||
lea(mem(, rsi, 2), rsi)
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
vmovupd(ymm5, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm5, mem(rcx, rsi, 1))
|
||||
add(rdi, rcx)
|
||||
|
||||
vmovupd(ymm8, mem(rcx))
|
||||
vmovupd(ymm9, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm9, mem(rcx, rsi, 1))
|
||||
add(rdi, rcx)
|
||||
|
||||
vmovupd(ymm12, mem(rcx))
|
||||
vmovupd(ymm13, mem(rcx, rsi, 8))
|
||||
vmovupd(ymm13, mem(rcx, rsi, 1))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
label(.SCOLSTORBZ)
|
||||
label(.ZCOLSTORBZ)
|
||||
|
||||
/****3x4 tile going to save into 4x3 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
/******************Transpose top tile 4x3***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
@@ -787,7 +774,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
vmovups(xmm9, mem(rcx, 16))
|
||||
vmovups(xmm13,mem(rcx,32))
|
||||
|
||||
label(.SDONE)
|
||||
label(.ZDONE)
|
||||
|
||||
lea(mem(r12, rdi, 2), r12)
|
||||
lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c
|
||||
@@ -796,9 +783,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a
|
||||
|
||||
dec(r11) // ii -= 1;
|
||||
jne(.SLOOP3X8I) // iterate again if ii != 0.
|
||||
jne(.ZLOOP3X4I) // iterate again if ii != 0.
|
||||
|
||||
label(.SRETURN)
|
||||
label(.ZRETURN)
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -813,7 +800,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
[cs_a] "m" (cs_a),
|
||||
[b] "m" (b),
|
||||
[rs_b] "m" (rs_b),
|
||||
[cs_b] "m" (cs_b),
|
||||
[alpha] "m" (alpha),
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
@@ -822,8 +808,8 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"rax", "rbx", "rcx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
"xmm8", "xmm9", "xmm10", "xmm11",
|
||||
@@ -896,7 +882,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
uint64_t rs_a = rs_a0;
|
||||
uint64_t cs_a = cs_a0;
|
||||
uint64_t rs_b = rs_b0;
|
||||
uint64_t cs_b = cs_b0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
@@ -914,8 +899,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
|
||||
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
|
||||
|
||||
// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
mov(var(rs_b), r10) // load rs_b
|
||||
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
|
||||
lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt)
|
||||
@@ -939,41 +922,31 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
|
||||
mov(var(m_iter), r11) // ii = m_iter;
|
||||
|
||||
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ]
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
mov(var(b), rbx) // load address of b.
|
||||
//mov(r12, rcx) // reset rcx to current utile of c.
|
||||
mov(r14, rax) // reset rax to current upanel of a.
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLPFETCH) // jump to column storage case
|
||||
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
jz(.ZCOLPFETCH) // jump to column storage case
|
||||
label(.ZROWPFETCH) // row-stored pre-fetching on c // not used
|
||||
|
||||
lea(mem(r12, rdi, 2), rdx) //
|
||||
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
|
||||
|
||||
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.SCOLPFETCH) // column-stored pre-fetching c
|
||||
jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c
|
||||
label(.ZCOLPFETCH) // column-stored pre-fetching c
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
|
||||
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
|
||||
lea(mem(r12, rsi, 2), rdx) //
|
||||
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
|
||||
|
||||
label(.SPOSTPFETCH) // done prefetching c
|
||||
|
||||
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
|
||||
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
|
||||
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
|
||||
label(.ZPOSTPFETCH) // done prefetching c
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SCONSIDKLEFT) // if i == 0, jump to code that
|
||||
je(.ZCONSIDKLEFT) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
label(.SLOOPKITER) // MAIN LOOP
|
||||
label(.ZLOOPKITER) // MAIN LOOP
|
||||
|
||||
// ---------------------------------- iteration 0
|
||||
|
||||
@@ -1051,8 +1024,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
// ---------------------------------- iteration 3
|
||||
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
|
||||
@@ -1077,16 +1048,16 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKITER) // iterate again if i != 0.
|
||||
jne(.ZLOOPKITER) // iterate again if i != 0.
|
||||
|
||||
label(.SCONSIDKLEFT)
|
||||
label(.ZCONSIDKLEFT)
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
label(.SLOOPKLEFT) // EDGE LOOP
|
||||
label(.ZLOOPKLEFT) // EDGE LOOP
|
||||
|
||||
vmovupd(mem(rbx, 0*32), ymm0)
|
||||
add(r10, rbx) // b += rs_b;
|
||||
@@ -1112,9 +1083,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
add(r9, rax) // a += cs_a;
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.SLOOPKLEFT) // iterate again if i != 0.
|
||||
jne(.ZLOOPKLEFT) // iterate again if i != 0.
|
||||
|
||||
label(.SPOSTACCUM)
|
||||
label(.ZPOSTACCUM)
|
||||
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
@@ -1126,9 +1097,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
|
||||
// subtract/add even/odd elements
|
||||
vaddsubpd(ymm6, ymm4, ymm4)
|
||||
|
||||
vaddsubpd(ymm10, ymm8, ymm8)
|
||||
|
||||
vaddsubpd(ymm14, ymm12, ymm12)
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
@@ -1156,12 +1125,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
|
||||
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
|
||||
@@ -1169,13 +1132,12 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r13b, r15b) // set ZF if r13b & r15b == 1.
|
||||
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
jz(.ZCOLSTORED) // jump to column storage case
|
||||
|
||||
label(.SROWSTORED)
|
||||
label(.ZROWSTORED)
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
@@ -1193,9 +1155,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
vaddpd(ymm12, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
label(.ZCOLSTORED)
|
||||
/*|--------| |-------|
|
||||
| | | |
|
||||
| 3x2 | | 2x3 |
|
||||
@@ -1206,8 +1168,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
ZGEMM_INPUT_SCALE_CS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm4)
|
||||
|
||||
@@ -1222,9 +1182,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
mov(r12, rcx) // reset rcx to current utile of c.
|
||||
|
||||
/****3x2 tile going to save into 2x3 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
/******************Transpose top tile 2x3***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
@@ -1241,14 +1198,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
vmovups(xmm12, mem(rcx,32))
|
||||
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
label(.SBETAZERO)
|
||||
label(.ZBETAZERO)
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8.
|
||||
jz(.SCOLSTORBZ) // jump to column storage case
|
||||
jz(.ZCOLSTORBZ) // jump to column storage case
|
||||
|
||||
label(.SROWSTORBZ)
|
||||
label(.ZROWSTORBZ)
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
add(rdi, rcx)
|
||||
@@ -1258,14 +1215,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
|
||||
vmovupd(ymm12, mem(rcx))
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
label(.SCOLSTORBZ)
|
||||
label(.ZCOLSTORBZ)
|
||||
|
||||
/****3x2 tile going to save into 2x3 tile in C*****/
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
|
||||
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
|
||||
|
||||
/******************Transpose tile 3x2***************************/
|
||||
vmovups(xmm4, mem(rcx))
|
||||
@@ -1281,7 +1238,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
vmovups(xmm8, mem(rcx, 16))
|
||||
vmovups(xmm12, mem(rcx,32))
|
||||
|
||||
label(.SDONE)
|
||||
label(.ZDONE)
|
||||
|
||||
lea(mem(r12, rdi, 2), r12)
|
||||
lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c
|
||||
@@ -1290,9 +1247,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a
|
||||
|
||||
dec(r11) // ii -= 1;
|
||||
jne(.SLOOP3X8I) // iterate again if ii != 0.
|
||||
jne(.ZLOOP3X2I) // iterate again if ii != 0.
|
||||
|
||||
label(.SRETURN)
|
||||
label(.ZRETURN)
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
@@ -1305,7 +1262,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
[cs_a] "m" (cs_a),
|
||||
[b] "m" (b),
|
||||
[rs_b] "m" (rs_b),
|
||||
[cs_b] "m" (cs_b),
|
||||
[alpha] "m" (alpha),
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
@@ -1314,7 +1270,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
[a_next] "m" (a_next),
|
||||
[b_next] "m" (b_next)*/
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
|
||||
"rax", "rbx", "rcx", "rsi", "rdi",
|
||||
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
|
||||
"xmm0", "xmm1", "xmm2", "xmm3",
|
||||
"xmm4", "xmm5", "xmm6", "xmm7",
|
||||
|
||||
@@ -767,7 +767,7 @@ void bli_zgemmsup_rv_zen_asm_2x4n
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC));
|
||||
ymm3 = _mm256_permute_pd(ymm2, 5);
|
||||
ymm2 = _mm256_mul_pd(ymm0, ymm2);
|
||||
ymm3 =_mm256_mul_pd(ymm1, ymm3);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3));
|
||||
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+2));
|
||||
|
||||
Reference in New Issue
Block a user