Merge "ZGEMM SUP: Removed unused assembly instructions" into amd-staging-milan-3.1

This commit is contained in:
Mangala V
2021-04-19 03:08:31 -04:00
committed by Gerrit Code Review
4 changed files with 222 additions and 370 deletions

View File

@@ -6,7 +6,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020-2021, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -1281,6 +1281,9 @@ void bli_cgemmsup_rv_zen_asm_3x4
ymm1 = _mm256_broadcast_ss((float const *)(beta)); // load alpha_r and duplicate
ymm2 = _mm256_broadcast_ss((float const *)(&beta->imag)); // load alpha_i and duplicate
xmm0 = _mm_setzero_ps();
xmm3 = _mm_setzero_ps();
//Multiply ymm4 with beta
xmm0 = _mm_loadl_pi(xmm0, (__m64 const *) (tC)) ;
xmm0 = _mm_loadh_pi(xmm0, (__m64 const *) (tC + tc_inc_col));

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020-2021, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -38,8 +38,9 @@
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
// assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
// outputs to ymm0
/* Assumes beta.r and beta.i have been broadcast into ymm1 and ymm2,
and stores the output in ymm0.
(creal,cimag)*(betar,betai) where c is stored in col major order*/
#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \
vmovupd(mem(rcx), xmm0) \
vmovupd(mem(rcx, rsi, 1), xmm3) \
@@ -49,6 +50,7 @@
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
//(creal,cimag)*(betar,betai) where c is stored in row major order
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \
vmovupd(mem(rcx), ymm0) \
vpermilpd(imm(0x5), ymm0, ymm3) \
@@ -59,15 +61,19 @@
#define ZGEMM_OUTPUT_RS \
vmovupd(ymm0, mem(rcx)) \
/*(cNextRowreal,cNextRowimag)*(betar,betai)
where c is stored in row major order
rsi = cs_c * sizeof((real+imag) dt) * numofElements
numofElements = 2, since 2 elements are processed at a time*/
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \
vmovupd(mem(rcx, rsi, 8), ymm0) \
vmovupd(mem(rcx, rsi, 1), ymm0) \
vpermilpd(imm(0x5), ymm0, ymm3) \
vmulpd(ymm1, ymm0, ymm0) \
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
#define ZGEMM_OUTPUT_RS_NEXT \
vmovupd(ymm0, mem(rcx, rsi, 8)) \
vmovupd(ymm0, mem(rcx, rsi, 1)) \
void bli_zgemmsup_rv_zen_asm_2x4
@@ -100,7 +106,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
uint64_t rs_b = rs_b0;
uint64_t cs_b = cs_b0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
@@ -116,8 +121,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
mov(var(rs_b), r10) // load rs_b
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt)
@@ -141,40 +144,29 @@ void bli_zgemmsup_rv_zen_asm_2x4
mov(var(m_iter), r11) // ii = m_iter;
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
vzeroall() // zero all xmm/ymm registers.
mov(var(b), rbx) // load address of b.
//mov(r12, rcx) // reset rcx to current utile of c.
mov(r14, rax) // reset rax to current upanel of a.
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLPFETCH) // jump to column storage case
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
jz(.ZCOLPFETCH1) // jump to column storage case
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
label(.ZCOLPFETCH1) // column-stored pre-fetching c
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
label(.SPOSTPFETCH) // done prefetching c
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
label(.ZPOSTPFETCH1) // done prefetching c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
label(.ZLOOPKITER1) // MAIN LOOP
// ---------------------------------- iteration 0
@@ -251,8 +243,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
add(r10, rbx) // b += rs_b;
@@ -277,16 +267,16 @@ void bli_zgemmsup_rv_zen_asm_2x4
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
jne(.ZLOOPKITER1) // iterate again if i != 0.
label(.SCONSIDKLEFT)
label(.ZCONSIDKLEFT1)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
label(.ZLOOPKLEFT1) // EDGE LOOP
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -312,9 +302,9 @@ void bli_zgemmsup_rv_zen_asm_2x4
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
label(.SPOSTACCUM)
label(.ZPOSTACCUM1)
mov(r12, rcx) // reset rcx to current utile of c.
@@ -363,10 +353,8 @@ void bli_zgemmsup_rv_zen_asm_2x4
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
@@ -375,14 +363,13 @@ void bli_zgemmsup_rv_zen_asm_2x4
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
and(r13b, r15b) // set ZF if r13b & r15b == 1.
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
jz(.SCOLSTORED) // jump to column storage case
jz(.ZCOLSTORED1) // jump to column storage case
label(.SROWSTORED)
label(.ZROWSTORED1)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) * numofElements
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
@@ -401,21 +388,15 @@ void bli_zgemmsup_rv_zen_asm_2x4
vaddpd(ymm9, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORED)
label(.ZCOLSTORED1)
/*|--------| |-------|
| | | |
| 2x4 | | 4x2 |
|--------| |-------|
*/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
ZGEMM_INPUT_SCALE_CS_BETA_NZ
vaddpd(ymm4, ymm0, ymm4)
@@ -438,9 +419,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
/****3x4 tile going to save into 4x2 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
/******************Transpose top tile 4x3***************************/
vmovups(xmm4, mem(rcx))
@@ -465,29 +443,27 @@ void bli_zgemmsup_rv_zen_asm_2x4
vmovups(xmm5, mem(rcx))
vmovups(xmm9, mem(rcx, 16))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SBETAZERO)
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
label(.ZBETAZERO1)
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORBZ) // jump to column storage case
jz(.ZCOLSTORBZ1) // jump to column storage case
label(.SROWSTORBZ)
label(.ZROWSTORBZ1)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) * numofElements
vmovupd(ymm4, mem(rcx))
vmovupd(ymm5, mem(rcx, rsi, 8))
vmovupd(ymm5, mem(rcx, rsi, 1))
add(rdi, rcx)
vmovupd(ymm8, mem(rcx))
vmovupd(ymm9, mem(rcx, rsi, 8))
vmovupd(ymm9, mem(rcx, rsi, 1))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORBZ)
label(.ZCOLSTORBZ1)
/****2x4 tile going to save into 4x2 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
/******************Transpose tile 2x4***************************/
vmovups(xmm4, mem(rcx))
@@ -512,7 +488,7 @@ void bli_zgemmsup_rv_zen_asm_2x4
vmovups(xmm5, mem(rcx))
vmovups(xmm9, mem(rcx, 16))
label(.SDONE)
label(.ZDONE1)
end_asm(
: // output operands (none)
@@ -525,7 +501,6 @@ void bli_zgemmsup_rv_zen_asm_2x4
[cs_a] "m" (cs_a),
[b] "m" (b),
[rs_b] "m" (rs_b),
[cs_b] "m" (cs_b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
@@ -534,7 +509,7 @@ void bli_zgemmsup_rv_zen_asm_2x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -577,7 +552,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
uint64_t rs_b = rs_b0;
uint64_t cs_b = cs_b0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
@@ -593,8 +567,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
mov(var(rs_b), r10) // load rs_b
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt)
@@ -618,8 +590,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
mov(var(m_iter), r11) // ii = m_iter;
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
vzeroall() // zero all xmm/ymm registers.
mov(var(b), rbx) // load address of b.
@@ -627,31 +597,23 @@ void bli_zgemmsup_rv_zen_asm_1x4
mov(r14, rax) // reset rax to current upanel of a.
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLPFETCH) // jump to column storage case
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
jz(.ZCOLPFETCH1) // jump to column storage case
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
label(.ZCOLPFETCH1) // column-stored pre-fetching c
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
label(.SPOSTPFETCH) // done prefetching c
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
label(.ZPOSTPFETCH1) // done prefetching c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
label(.ZLOOPKITER1) // MAIN LOOP
// ---------------------------------- iteration 0
@@ -707,7 +669,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -725,16 +686,16 @@ void bli_zgemmsup_rv_zen_asm_1x4
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
jne(.ZLOOPKITER1) // iterate again if i != 0.
label(.SCONSIDKLEFT)
label(.ZCONSIDKLEFT1)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
label(.ZLOOPKLEFT1) // EDGE LOOP
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -751,9 +712,9 @@ void bli_zgemmsup_rv_zen_asm_1x4
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
label(.SPOSTACCUM)
label(.ZPOSTACCUM1)
mov(r12, rcx) // reset rcx to current utile of c.
@@ -787,10 +748,8 @@ void bli_zgemmsup_rv_zen_asm_1x4
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
@@ -799,14 +758,14 @@ void bli_zgemmsup_rv_zen_asm_1x4
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
and(r13b, r15b) // set ZF if r13b & r15b == 1.
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
jz(.SCOLSTORED) // jump to column storage case
jz(.ZCOLSTORED1) // jump to column storage case
label(.SROWSTORED)
label(.ZROWSTORED1)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) * numofElements
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
@@ -816,21 +775,15 @@ void bli_zgemmsup_rv_zen_asm_1x4
vaddpd(ymm5, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORED)
label(.ZCOLSTORED1)
/*|--------| |-------|
| | | |
| 1x4 | | 4x1 |
|--------| |-------|
*/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
ZGEMM_INPUT_SCALE_CS_BETA_NZ
vaddpd(ymm4, ymm0, ymm4)
@@ -842,9 +795,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
mov(r12, rcx) // reset rcx to current utile of c.
/****1x4 tile going to save into 4x1 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
vmovups(xmm4, mem(rcx))
@@ -862,21 +812,21 @@ void bli_zgemmsup_rv_zen_asm_1x4
vextractf128(imm(0x1), ymm5, xmm5)
vmovups(xmm5, mem(rcx))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SBETAZERO)
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
label(.ZBETAZERO1)
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORBZ) // jump to column storage case
jz(.ZCOLSTORBZ1) // jump to column storage case
label(.SROWSTORBZ)
label(.ZROWSTORBZ1)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt) * numofElements
vmovupd(ymm4, mem(rcx))
vmovupd(ymm5, mem(rcx, rsi, 8))
vmovupd(ymm5, mem(rcx, rsi, 1))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORBZ)
label(.ZCOLSTORBZ1)
/****1x4 tile going to save into 4x1 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
@@ -898,7 +848,7 @@ void bli_zgemmsup_rv_zen_asm_1x4
vextractf128(imm(0x1), ymm5, xmm5)
vmovups(xmm5, mem(rcx))
label(.SDONE)
label(.ZDONE1)
end_asm(
: // output operands (none)
@@ -911,7 +861,6 @@ void bli_zgemmsup_rv_zen_asm_1x4
[cs_a] "m" (cs_a),
[b] "m" (b),
[rs_b] "m" (rs_b),
[cs_b] "m" (cs_b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
@@ -920,7 +869,7 @@ void bli_zgemmsup_rv_zen_asm_1x4
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -962,7 +911,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
uint64_t rs_b = rs_b0;
uint64_t cs_b = cs_b0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
@@ -1001,41 +949,29 @@ void bli_zgemmsup_rv_zen_asm_2x2
mov(var(m_iter), r11) // ii = m_iter;
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
vzeroall() // zero all xmm/ymm registers.
mov(var(b), rbx) // load address of b.
//mov(r12, rcx) // reset rcx to current utile of c.
mov(r14, rax) // reset rax to current upanel of a.
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLPFETCH) // jump to column storage case
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
jz(.ZCOLPFETCH1) // jump to column storage case
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
label(.ZCOLPFETCH1) // column-stored pre-fetching c
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
label(.SPOSTPFETCH) // done prefetching c
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
label(.ZPOSTPFETCH1) // done prefetching c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
label(.ZLOOPKITER1) // MAIN LOOP
// ---------------------------------- iteration 0
@@ -1096,7 +1032,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1117,16 +1052,16 @@ void bli_zgemmsup_rv_zen_asm_2x2
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
jne(.ZLOOPKITER1) // iterate again if i != 0.
label(.SCONSIDKLEFT)
label(.ZCONSIDKLEFT1)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
label(.ZLOOPKLEFT1) // EDGE LOOP
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1147,9 +1082,9 @@ void bli_zgemmsup_rv_zen_asm_2x2
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
label(.SPOSTACCUM)
label(.ZPOSTACCUM1)
mov(r12, rcx) // reset rcx to current utile of c.
@@ -1183,12 +1118,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
@@ -1196,13 +1125,12 @@ void bli_zgemmsup_rv_zen_asm_2x2
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
and(r13b, r15b) // set ZF if r13b & r15b == 1.
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORED) // jump to column storage case
jz(.ZCOLSTORED1) // jump to column storage case
label(.SROWSTORED)
label(.ZROWSTORED1)
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
@@ -1214,9 +1142,9 @@ void bli_zgemmsup_rv_zen_asm_2x2
vaddpd(ymm8, ymm0, ymm0)
ZGEMM_OUTPUT_RS
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORED)
label(.ZCOLSTORED1)
/*|--------| |-------|
| | | |
| 2x2 | | 2x2 |
@@ -1227,8 +1155,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
ZGEMM_INPUT_SCALE_CS_BETA_NZ
vaddpd(ymm4, ymm0, ymm4)
@@ -1239,9 +1165,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
mov(r12, rcx) // reset rcx to current utile of c.
/****2x2 tile going to save into 2x2 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
vmovups(xmm4, mem(rcx))
vmovups(xmm8, mem(rcx, 16))
@@ -1253,23 +1176,23 @@ void bli_zgemmsup_rv_zen_asm_2x2
vmovups(xmm4, mem(rcx))
vmovups(xmm8, mem(rcx, 16))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SBETAZERO)
label(.ZBETAZERO1)
cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8.
jz(.SCOLSTORBZ) // jump to column storage case
jz(.ZCOLSTORBZ1) // jump to column storage case
label(.SROWSTORBZ)
label(.ZROWSTORBZ1)
vmovupd(ymm4, mem(rcx))
add(rdi, rcx)
vmovupd(ymm8, mem(rcx))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORBZ)
label(.ZCOLSTORBZ1)
/****2x2 tile going to save into 2x2 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
@@ -1286,7 +1209,7 @@ void bli_zgemmsup_rv_zen_asm_2x2
vmovups(xmm4, mem(rcx))
vmovups(xmm8, mem(rcx, 16))
label(.SDONE)
label(.ZDONE1)
end_asm(
: // output operands (none)
@@ -1299,7 +1222,6 @@ void bli_zgemmsup_rv_zen_asm_2x2
[cs_a] "m" (cs_a),
[b] "m" (b),
[rs_b] "m" (rs_b),
[cs_b] "m" (cs_b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
@@ -1308,7 +1230,7 @@ void bli_zgemmsup_rv_zen_asm_2x2
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
@@ -1350,7 +1272,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
uint64_t rs_b = rs_b0;
uint64_t cs_b = cs_b0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
@@ -1366,7 +1287,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt)
lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt)
// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
mov(var(rs_b), r10) // load rs_b
lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt)
@@ -1391,8 +1311,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
mov(var(m_iter), r11) // ii = m_iter;
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
vzeroall() // zero all xmm/ymm registers.
mov(var(b), rbx) // load address of b.
@@ -1400,32 +1318,23 @@ void bli_zgemmsup_rv_zen_asm_1x2
mov(r14, rax) // reset rax to current upanel of a.
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLPFETCH) // jump to column storage case
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
jz(.ZCOLPFETCH1) // jump to column storage case
label(.ZROWPFETCH1) // row-stored pre-fetching on c // not used
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
jmp(.ZPOSTPFETCH1) // jump to end of pre-fetching c
label(.ZCOLPFETCH1) // column-stored pre-fetching c
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
label(.SPOSTPFETCH) // done prefetching c
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
label(.ZPOSTPFETCH1) // done prefetching c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
je(.ZCONSIDKLEFT1) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
label(.ZLOOPKITER1) // MAIN LOOP
// ---------------------------------- iteration 0
@@ -1469,8 +1378,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1483,16 +1390,16 @@ void bli_zgemmsup_rv_zen_asm_1x2
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
jne(.ZLOOPKITER1) // iterate again if i != 0.
label(.SCONSIDKLEFT)
label(.ZCONSIDKLEFT1)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
je(.ZPOSTACCUM1) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
label(.ZLOOPKLEFT1) // EDGE LOOP
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1506,9 +1413,9 @@ void bli_zgemmsup_rv_zen_asm_1x2
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
jne(.ZLOOPKLEFT1) // iterate again if i != 0.
label(.SPOSTACCUM)
label(.ZPOSTACCUM1)
mov(r12, rcx) // reset rcx to current utile of c.
@@ -1534,12 +1441,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
@@ -1547,21 +1448,20 @@ void bli_zgemmsup_rv_zen_asm_1x2
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
and(r13b, r15b) // set ZF if r13b & r15b == 1.
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
jne(.ZBETAZERO1) // if ZF = 1, jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORED) // jump to column storage case
jz(.ZCOLSTORED1) // jump to column storage case
label(.SROWSTORED)
label(.ZROWSTORED1)
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
ZGEMM_OUTPUT_RS
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORED)
label(.ZCOLSTORED1)
/*|--------| |-------|
| | | |
| 1x2 | | 2x1 |
@@ -1572,17 +1472,11 @@ void bli_zgemmsup_rv_zen_asm_1x2
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
ZGEMM_INPUT_SCALE_CS_BETA_NZ
vaddpd(ymm4, ymm0, ymm4)
/****3x4 tile going to save into 4x3 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
/******************Transpose tile 1x2***************************/
vmovups(xmm4, mem(rcx))
add(rsi, rcx)
@@ -1591,22 +1485,22 @@ void bli_zgemmsup_rv_zen_asm_1x2
vmovups(xmm4, mem(rcx))
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SBETAZERO)
label(.ZBETAZERO1)
cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8.
jz(.SCOLSTORBZ) // jump to column storage case
jz(.ZCOLSTORBZ1) // jump to column storage case
label(.SROWSTORBZ)
label(.ZROWSTORBZ1)
vmovupd(ymm4, mem(rcx))
add(rdi, rcx)
jmp(.SDONE) // jump to end.
jmp(.ZDONE1) // jump to end.
label(.SCOLSTORBZ)
label(.ZCOLSTORBZ1)
/****1x2 tile going to save into 2x1 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
@@ -1622,7 +1516,7 @@ void bli_zgemmsup_rv_zen_asm_1x2
vextractf128(imm(0x1), ymm4, xmm4)
vmovups(xmm4, mem(rcx))
label(.SDONE)
label(.ZDONE1)
end_asm(
: // output operands (none)
@@ -1635,7 +1529,6 @@ void bli_zgemmsup_rv_zen_asm_1x2
[cs_a] "m" (cs_a),
[b] "m" (b),
[rs_b] "m" (rs_b),
[cs_b] "m" (cs_b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
@@ -1644,7 +1537,7 @@ void bli_zgemmsup_rv_zen_asm_1x2
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",

View File

@@ -38,8 +38,9 @@
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
// assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
// outputs to ymm0
/* Assumes beta.r and beta.i have been broadcast into ymm1 and ymm2,
and stores the output in ymm0.
(creal,cimag)*(betar,betai) where c is stored in col major order*/
#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \
vmovupd(mem(rcx), xmm0) \
vmovupd(mem(rcx, rsi, 1), xmm3) \
@@ -49,6 +50,7 @@
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
//(creal,cimag)*(betar,betai) where c is stored in row major order
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \
vmovupd(mem(rcx), ymm0) \
vpermilpd(imm(0x5), ymm0, ymm3) \
@@ -62,18 +64,22 @@
#define ZGEMM_OUTPUT_RS \
vmovupd(ymm0, mem(rcx)) \
/*(cNextRowreal,cNextRowimag)*(betar,betai)
where c is stored in row major order
rsi = cs_c * sizeof((real+imag) dt) * numofElements
numofElements = 2, since 2 elements are processed at a time*/
#define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \
vmovupd(mem(rcx, rsi, 8), ymm0) \
vmovupd(mem(rcx, rsi, 1), ymm0) \
vpermilpd(imm(0x5), ymm0, ymm3) \
vmulpd(ymm1, ymm0, ymm0) \
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
#define ZGEMM_INPUT_RS_BETA_ONE_NEXT \
vmovupd(mem(rcx, rsi, 8), ymm0)
vmovupd(mem(rcx, rsi, 1), ymm0)
#define ZGEMM_OUTPUT_RS_NEXT \
vmovupd(ymm0, mem(rcx, rsi, 8))
vmovupd(ymm0, mem(rcx, rsi, 1))
/*
rrr:
@@ -174,7 +180,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
uint64_t rs_b = rs_b0;
uint64_t cs_b = cs_b0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
@@ -227,8 +232,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt)
lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt)
//lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
mov(var(rs_b), r10) // load rs_b
lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt)
lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt)
@@ -252,7 +255,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
mov(var(m_iter), r11) // ii = m_iter;
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ]
vzeroall() // zero all xmm/ymm registers.
@@ -260,31 +263,22 @@ void bli_zgemmsup_rv_zen_asm_3x4m
mov(r14, rax) // reset rax to current upanel of a.
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLPFETCH) // jump to column storage case
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
jz(.ZCOLPFETCH) // jump to column storage case
label(.ZROWPFETCH) // row-stored pre-fetching on c // not used
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c
label(.ZCOLPFETCH) // column-stored pre-fetching c
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt)
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
label(.SPOSTPFETCH) // done prefetching c
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
label(.ZPOSTPFETCH) // done prefetching c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
je(.ZCONSIDKLEFT) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
label(.ZLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
@@ -383,8 +377,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
add(r10, rbx) // b += rs_b;
@@ -416,16 +408,16 @@ void bli_zgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
jne(.ZLOOPKITER) // iterate again if i != 0.
label(.SCONSIDKLEFT)
label(.ZCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
label(.ZLOOPKLEFT) // EDGE LOOP
vmovupd(mem(rbx, 0*32), ymm0)
vmovupd(mem(rbx, 1*32), ymm1)
@@ -458,9 +450,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
jne(.ZLOOPKLEFT) // iterate again if i != 0.
label(.SPOSTACCUM)
label(.ZPOSTACCUM)
mov(r12, rcx) // reset rcx to current utile of c.
@@ -483,6 +475,10 @@ void bli_zgemmsup_rv_zen_asm_3x4m
vaddsubpd(ymm14, ymm12, ymm12)
vaddsubpd(ymm15, ymm13, ymm13)
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt)
//if(alpha_mul_type == BLIS_MUL_MINUS_ONE)
mov(var(alpha_mul_type), al)
cmp(imm(0xFF), al)
@@ -541,19 +537,18 @@ void bli_zgemmsup_rv_zen_asm_3x4m
label(.ALPHA_REAL_ONE)
// Beta multiplication
/* (br + bi)x C + ((ar + ai) x AB) */
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
mov(var(beta_mul_type), al)
cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO)
je(.SBETAZERO) //jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
je(.ZBETAZERO) //jump to beta == 0 case
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
jz(.SCOLSTORED) // jump to column storage case
jz(.ZCOLSTORED) // jump to column storage case
label(.ZROWSTORED)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements
label(.SROWSTORED)
cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT)
je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication.
@@ -586,7 +581,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE)
jmp(.ZDONE)
//CASE 2: beta is real = -1
@@ -616,7 +611,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vsubpd(ymm0, ymm13, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE)
jmp(.ZDONE)
//CASE 3: Default case with multiplication
@@ -651,9 +646,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE) // jump to end.
jmp(.ZDONE) // jump to end.
label(.SCOLSTORED)
label(.ZCOLSTORED)
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
@@ -663,11 +658,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|--------| |-------|
*/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt)
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
ZGEMM_INPUT_SCALE_CS_BETA_NZ
vaddpd(ymm4, ymm0, ymm4)
@@ -696,9 +686,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
/****3x4 tile going to save into 4x3 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt)
/******************Transpose top tile 4x3***************************/
vmovups(xmm4, mem(rcx))
@@ -729,34 +716,34 @@ void bli_zgemmsup_rv_zen_asm_3x4m
vmovups(xmm9, mem(rcx, 16))
vmovups(xmm13,mem(rcx,32))
jmp(.SDONE) // jump to end.
jmp(.ZDONE) // jump to end.
label(.SBETAZERO)
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
label(.ZBETAZERO)
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORBZ) // jump to column storage case
jz(.ZCOLSTORBZ) // jump to column storage case
label(.SROWSTORBZ)
label(.ZROWSTORBZ)
/* Store 3x4 elements to C matrix where is C row major order*/
// rsi = cs_c * sizeof((real +imag)dt) *numofElements
lea(mem(, rsi, 2), rsi)
vmovupd(ymm4, mem(rcx))
vmovupd(ymm5, mem(rcx, rsi, 8))
vmovupd(ymm5, mem(rcx, rsi, 1))
add(rdi, rcx)
vmovupd(ymm8, mem(rcx))
vmovupd(ymm9, mem(rcx, rsi, 8))
vmovupd(ymm9, mem(rcx, rsi, 1))
add(rdi, rcx)
vmovupd(ymm12, mem(rcx))
vmovupd(ymm13, mem(rcx, rsi, 8))
vmovupd(ymm13, mem(rcx, rsi, 1))
jmp(.SDONE) // jump to end.
jmp(.ZDONE) // jump to end.
label(.SCOLSTORBZ)
label(.ZCOLSTORBZ)
/****3x4 tile going to save into 4x3 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
/******************Transpose top tile 4x3***************************/
vmovups(xmm4, mem(rcx))
@@ -787,7 +774,7 @@ void bli_zgemmsup_rv_zen_asm_3x4m
vmovups(xmm9, mem(rcx, 16))
vmovups(xmm13,mem(rcx,32))
label(.SDONE)
label(.ZDONE)
lea(mem(r12, rdi, 2), r12)
lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c
@@ -796,9 +783,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m
lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a
dec(r11) // ii -= 1;
jne(.SLOOP3X8I) // iterate again if ii != 0.
jne(.ZLOOP3X4I) // iterate again if ii != 0.
label(.SRETURN)
label(.ZRETURN)
end_asm(
: // output operands (none)
@@ -813,7 +800,6 @@ void bli_zgemmsup_rv_zen_asm_3x4m
[cs_a] "m" (cs_a),
[b] "m" (b),
[rs_b] "m" (rs_b),
[cs_b] "m" (cs_b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
@@ -822,8 +808,8 @@ void bli_zgemmsup_rv_zen_asm_3x4m
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"rax", "rbx", "rcx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",
"xmm8", "xmm9", "xmm10", "xmm11",
@@ -896,7 +882,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
uint64_t rs_a = rs_a0;
uint64_t cs_a = cs_a0;
uint64_t rs_b = rs_b0;
uint64_t cs_b = cs_b0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
@@ -914,8 +899,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
lea(mem(, r9, 8), r9) // cs_a *= sizeof(real dt)
lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt)
// lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
mov(var(rs_b), r10) // load rs_b
lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt)
lea(mem(, r10, 2), r10) // rs_b *= sizeof((real + imag) dt)
@@ -939,41 +922,31 @@ void bli_zgemmsup_rv_zen_asm_3x2m
mov(var(m_iter), r11) // ii = m_iter;
label(.SLOOP3X8I) // LOOP OVER ii = [ m_iter ... 1 0 ]
label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ]
vzeroall() // zero all xmm/ymm registers.
mov(var(b), rbx) // load address of b.
//mov(r12, rcx) // reset rcx to current utile of c.
mov(r14, rax) // reset rax to current upanel of a.
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLPFETCH) // jump to column storage case
label(.SROWPFETCH) // row-stored pre-fetching on c // not used
jz(.ZCOLPFETCH) // jump to column storage case
label(.ZROWPFETCH) // row-stored pre-fetching on c // not used
lea(mem(r12, rdi, 2), rdx) //
lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c;
jmp(.SPOSTPFETCH) // jump to end of pre-fetching c
label(.SCOLPFETCH) // column-stored pre-fetching c
jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c
label(.ZCOLPFETCH) // column-stored pre-fetching c
mov(var(cs_c), rsi) // load cs_c to rsi (temporarily)
lea(mem(, rsi, 8), rsi) // rsi = cs_c * 8, i.e. cs_c * sizeof(real dt). NOTE(review): other strides in this kernel are scaled by 8 and then by 2 (16 bytes per complex element) -- confirm that 8 is intended here.
lea(mem(r12, rsi, 2), rdx) //
lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c;
label(.SPOSTPFETCH) // done prefetching c
lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a;
lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines
lea(mem(rdx, r8, 2), rdx) // from next upanel of a.
label(.ZPOSTPFETCH) // done prefetching c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.SCONSIDKLEFT) // if i == 0, jump to code that
je(.ZCONSIDKLEFT) // if i == 0, jump to code that
// contains the k_left loop.
label(.SLOOPKITER) // MAIN LOOP
label(.ZLOOPKITER) // MAIN LOOP
// ---------------------------------- iteration 0
@@ -1051,8 +1024,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
// ---------------------------------- iteration 3
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a;
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1077,16 +1048,16 @@ void bli_zgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKITER) // iterate again if i != 0.
jne(.ZLOOPKITER) // iterate again if i != 0.
label(.SCONSIDKLEFT)
label(.ZCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.SLOOPKLEFT) // EDGE LOOP
label(.ZLOOPKLEFT) // EDGE LOOP
vmovupd(mem(rbx, 0*32), ymm0)
add(r10, rbx) // b += rs_b;
@@ -1112,9 +1083,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m
add(r9, rax) // a += cs_a;
dec(rsi) // i -= 1;
jne(.SLOOPKLEFT) // iterate again if i != 0.
jne(.ZLOOPKLEFT) // iterate again if i != 0.
label(.SPOSTACCUM)
label(.ZPOSTACCUM)
mov(r12, rcx) // reset rcx to current utile of c.
@@ -1126,9 +1097,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
// subtract/add even/odd elements
vaddsubpd(ymm6, ymm4, ymm4)
vaddsubpd(ymm10, ymm8, ymm8)
vaddsubpd(ymm14, ymm12, ymm12)
/* (ar + ai) x AB */
@@ -1156,12 +1125,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
// now avoid loading C if beta == 0
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
@@ -1169,13 +1132,12 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
and(r13b, r15b) // AND the two flag bytes; ZF = 0 iff the result is nonzero (r13b & r15b == 1).
jne(.SBETAZERO) // if ZF == 0 (r13b & r15b != 0), jump to beta == 0 case
jne(.ZBETAZERO) // if ZF == 0 (r13b & r15b != 0), jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORED) // jump to column storage case
jz(.ZCOLSTORED) // jump to column storage case
label(.SROWSTORED)
label(.ZROWSTORED)
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
@@ -1193,9 +1155,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vaddpd(ymm12, ymm0, ymm0)
ZGEMM_OUTPUT_RS
jmp(.SDONE) // jump to end.
jmp(.ZDONE) // jump to end.
label(.SCOLSTORED)
label(.ZCOLSTORED)
/*|--------| |-------|
| | | |
| 3x2 | | 2x3 |
@@ -1206,8 +1168,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_a
ZGEMM_INPUT_SCALE_CS_BETA_NZ
vaddpd(ymm4, ymm0, ymm4)
@@ -1222,9 +1182,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
mov(r12, rcx) // reset rcx to current utile of c.
/****3x2 tile going to save into 2x3 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
/******************Transpose top tile 2x3***************************/
vmovups(xmm4, mem(rcx))
@@ -1241,14 +1198,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vmovups(xmm12, mem(rcx,32))
jmp(.SDONE) // jump to end.
jmp(.ZDONE) // jump to end.
label(.SBETAZERO)
label(.ZBETAZERO)
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORBZ) // jump to column storage case
jz(.ZCOLSTORBZ) // jump to column storage case
label(.SROWSTORBZ)
label(.ZROWSTORBZ)
vmovupd(ymm4, mem(rcx))
add(rdi, rcx)
@@ -1258,14 +1215,14 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vmovupd(ymm12, mem(rcx))
jmp(.SDONE) // jump to end.
jmp(.ZDONE) // jump to end.
label(.SCOLSTORBZ)
label(.ZCOLSTORBZ)
/****3x2 tile going to save into 2x3 tile in C*****/
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt)
lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt)
/******************Transpose tile 3x2***************************/
vmovups(xmm4, mem(rcx))
@@ -1281,7 +1238,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vmovups(xmm8, mem(rcx, 16))
vmovups(xmm12, mem(rcx,32))
label(.SDONE)
label(.ZDONE)
lea(mem(r12, rdi, 2), r12)
lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c
@@ -1290,9 +1247,9 @@ void bli_zgemmsup_rv_zen_asm_3x2m
lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a
dec(r11) // ii -= 1;
jne(.SLOOP3X8I) // iterate again if ii != 0.
jne(.ZLOOP3X2I) // iterate again if ii != 0.
label(.SRETURN)
label(.ZRETURN)
end_asm(
: // output operands (none)
@@ -1305,7 +1262,6 @@ void bli_zgemmsup_rv_zen_asm_3x2m
[cs_a] "m" (cs_a),
[b] "m" (b),
[rs_b] "m" (rs_b),
[cs_b] "m" (cs_b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
@@ -1314,7 +1270,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
[a_next] "m" (a_next),
[b_next] "m" (b_next)*/
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"rax", "rbx", "rcx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3",
"xmm4", "xmm5", "xmm6", "xmm7",

View File

@@ -767,7 +767,7 @@ void bli_zgemmsup_rv_zen_asm_2x4n
ymm2 = _mm256_loadu_pd((double const *)(tC));
ymm3 = _mm256_permute_pd(ymm2, 5);
ymm2 = _mm256_mul_pd(ymm0, ymm2);
ymm3 =_mm256_mul_pd(ymm1, ymm3);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3));
ymm2 = _mm256_loadu_pd((double const *)(tC+2));