Handling zgemm real(+/-1) alpha and beta

1. Improved performance when zgemm's alpha and beta are real and equal to +/-1.
2. Change made in bli_zgemmsup_rv_zen_asm_3x4n.
3. Change made in bli_zgemmsup_rv_zen_asm_3x4m.
4. Change made in bli_zgemm_haswell_asm_3x4.

Change-Id: Ic14d8507b264c24a8748febf6bc73eb60e476430
AMD-Internal: [CPUPL-1352]
Madan Mohan Manokar
2020-12-11 11:57:59 +05:30
committed by Madan Mohan Manokar
parent 1ff4981203
commit 3ab9104dae
4 changed files with 534 additions and 236 deletions

@@ -5,6 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -35,3 +36,12 @@
void bli_const_init( void );
void bli_const_finalize( void );
/* constants used to detect +1.0 and -1.0 when a double is reinterpreted as uint64_t */
#define BLIS_DOUBLE_TO_UINT64_ONE_ABS 0x3ff0000000000000
#define BLIS_DOUBLE_TO_UINT64_MINUS_ONE 0xbff0000000000000
/* enum used to classify alpha and beta into one of the categories below */
enum mulfactor { BLIS_MUL_MINUS_ONE = -1,
BLIS_MUL_ZERO,
BLIS_MUL_ONE,
BLIS_MUL_DEFAULT };
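
For illustration, a minimal C sketch (a hypothetical helper, not part of this patch) of how the constants and enum above classify a scalar; the same logic appears inline in each of the three kernels below, with the BLIS_MUL_ZERO case applied only to beta:

#include <stdint.h>
#include <string.h>

/* Hypothetical helper mirroring the inline classification used by the
   kernels in this commit. */
static char bli_classify_scalar( double real, double imag )
{
    uint64_t bits;
    memcpy( &bits, &real, sizeof( bits ) );  /* reinterpret double as uint64_t */

    uint64_t abs_bits = ( bits << 1 ) >> 1;  /* clear the sign bit */

    if ( imag == 0.0 )                       /* scalar is real */
    {
        if ( abs_bits == BLIS_DOUBLE_TO_UINT64_ONE_ABS )
            return ( bits == BLIS_DOUBLE_TO_UINT64_MINUS_ONE )
                   ? BLIS_MUL_MINUS_ONE : BLIS_MUL_ONE;
        if ( bits == 0 )                     /* +0.0; -0.0 falls through */
            return BLIS_MUL_ZERO;
    }
    return BLIS_MUL_DEFAULT;                 /* general scalar: full multiply */
}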

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -2197,10 +2197,11 @@ void bli_cgemm_haswell_asm_3x8
vmulpd(ymm1, ymm0, ymm0) \
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
#define ZGEMM_OUTPUT_RS \
vmovupd(ymm0, mem(rcx)) \
void bli_zgemm_haswell_asm_3x4
(
dim_t k0,
@@ -2223,71 +2224,105 @@ void bli_zgemm_haswell_asm_3x4
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
// Handle the case where alpha and beta are real and equal to +/-1.
uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
uint64_t beta_real_one = *((uint64_t*)(&beta->real));
uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
char alpha_mul_type = BLIS_MUL_DEFAULT;
char beta_mul_type = BLIS_MUL_DEFAULT;
if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
{
alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
{
alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
}
}
if(beta->imag == 0)// beta is real
{
if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
{
beta_mul_type = BLIS_MUL_ONE;
if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
{
beta_mul_type = BLIS_MUL_MINUS_ONE;
}
}
else if(beta_real_one == 0)
{
beta_mul_type = BLIS_MUL_ZERO;
}
}
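
In C terms, the branch structure the assembly below implements on these one-byte flags is roughly the following (a sketch, not the patch's code; note that char -1 appears as 0xFF in the cmp):

/* Sketch of the alpha dispatch the assembly performs on the one-byte
   flag (char -1 is the 0xFF seen in cmp(imm(0xFF), al)). */
static void apply_alpha_sketch( char alpha_mul_type )
{
    if ( alpha_mul_type == BLIS_MUL_MINUS_ONE )    /* cmp(imm(0xFF), al) */
    {
        /* negate the accumulated AB: 0 - x on every accumulator */
    }
    else if ( alpha_mul_type == BLIS_MUL_DEFAULT ) /* cmp(imm(2), al) */
    {
        /* full complex scaling of AB by alpha */
    }
    /* else alpha == +1 (BLIS_MUL_ONE): scaling is skipped entirely */
}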
begin_asm()
vzeroall() // zero all xmm/ymm registers.
mov(var(a), rax) // load address of a.
mov(var(b), rbx) // load address of b.
//mov(%9, r15) // load address of b_next.
add(imm(32*4), rbx)
// initialize loop by pre-loading
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
mov(var(c), rcx) // load address of c
mov(var(rs_c), rdi) // load rs_c
lea(mem(, rdi, 8), rdi) // rdi = rs_c * 8
lea(mem(, rdi, 2), rdi) // rdi = rs_c * 16 = rs_c * sizeof(dcomplex)
lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*rs_c;
lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*rs_c;
prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c
prefetch(0, mem(r11, 7*8)) // prefetch c + 1*rs_c
prefetch(0, mem(r12, 7*8)) // prefetch c + 2*rs_c
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.ZCONSIDKLEFT) // if i == 0, jump to code that
// contains the k_left loop.
label(.ZLOOPKITER) // MAIN LOOP
// iteration 0
prefetch(0, mem(rax, 32*16))
vbroadcastsd(mem(rax, 0*8), ymm2)
vbroadcastsd(mem(rax, 1*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 2*8), ymm2)
vbroadcastsd(mem(rax, 3*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 4*8), ymm2)
vbroadcastsd(mem(rax, 5*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, -2*32), ymm0)
vmovapd(mem(rbx, -1*32), ymm1)
// iteration 1
vbroadcastsd(mem(rax, 6*8), ymm2)
vbroadcastsd(mem(rax, 7*8), ymm3)
@@ -2295,51 +2330,51 @@ void bli_zgemm_haswell_asm_3x4
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 8*8), ymm2)
vbroadcastsd(mem(rax, 9*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 10*8), ymm2)
vbroadcastsd(mem(rax, 11*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, 0*32), ymm0)
vmovapd(mem(rbx, 1*32), ymm1)
// iteration 2
prefetch(0, mem(rax, 38*16))
vbroadcastsd(mem(rax, 12*8), ymm2)
vbroadcastsd(mem(rax, 13*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 14*8), ymm2)
vbroadcastsd(mem(rax, 15*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 16*8), ymm2)
vbroadcastsd(mem(rax, 17*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
vmovapd(mem(rbx, 2*32), ymm0)
vmovapd(mem(rbx, 3*32), ymm1)
// iteration 3
vbroadcastsd(mem(rax, 18*8), ymm2)
vbroadcastsd(mem(rax, 19*8), ymm3)
@@ -2347,83 +2382,83 @@ void bli_zgemm_haswell_asm_3x4
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 20*8), ymm2)
vbroadcastsd(mem(rax, 21*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 22*8), ymm2)
vbroadcastsd(mem(rax, 23*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
add(imm(4*3*16), rax) // a += 4*3 (unroll x mr)
add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr)
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.ZLOOPKITER) // iterate again if i != 0.
label(.ZCONSIDKLEFT)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.ZLOOPKLEFT) // EDGE LOOP
prefetch(0, mem(rax, 32*16))
vbroadcastsd(mem(rax, 0*8), ymm2)
vbroadcastsd(mem(rax, 1*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm4)
vfmadd231pd(ymm1, ymm2, ymm5)
vfmadd231pd(ymm0, ymm3, ymm6)
vfmadd231pd(ymm1, ymm3, ymm7)
vbroadcastsd(mem(rax, 2*8), ymm2)
vbroadcastsd(mem(rax, 3*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm8)
vfmadd231pd(ymm1, ymm2, ymm9)
vfmadd231pd(ymm0, ymm3, ymm10)
vfmadd231pd(ymm1, ymm3, ymm11)
vbroadcastsd(mem(rax, 4*8), ymm2)
vbroadcastsd(mem(rax, 5*8), ymm3)
vfmadd231pd(ymm0, ymm2, ymm12)
vfmadd231pd(ymm1, ymm2, ymm13)
vfmadd231pd(ymm0, ymm3, ymm14)
vfmadd231pd(ymm1, ymm3, ymm15)
add(imm(1*3*16), rax) // a += 1*3 (unroll x mr)
add(imm(1*4*16), rbx) // b += 1*4 (unroll x nr)
vmovapd(mem(rbx, -4*32), ymm0)
vmovapd(mem(rbx, -3*32), ymm1)
dec(rsi) // i -= 1;
jne(.ZLOOPKLEFT) // iterate again if i != 0.
label(.ZPOSTACCUM)
// permute even and odd elements
// of ymm6/7, ymm10/11, ymm14/15
vpermilpd(imm(0x5), ymm6, ymm6)
@@ -2432,251 +2467,306 @@ void bli_zgemm_haswell_asm_3x4
vpermilpd(imm(0x5), ymm11, ymm11)
vpermilpd(imm(0x5), ymm14, ymm14)
vpermilpd(imm(0x5), ymm15, ymm15)
// subtract/add even/odd elements
vaddsubpd(ymm6, ymm4, ymm4)
vaddsubpd(ymm7, ymm5, ymm5)
vaddsubpd(ymm10, ymm8, ymm8)
vaddsubpd(ymm11, ymm9, ymm9)
vaddsubpd(ymm14, ymm12, ymm12)
vaddsubpd(ymm15, ymm13, ymm13)
//if(alpha_mul_type == BLIS_MUL_MINUS_ONE)
mov(var(alpha_mul_type), al)
cmp(imm(0xFF), al)
jne(.ALPHA_NOT_MINUS1)
// when alpha = -1 and real.
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vsubpd(ymm4, ymm0, ymm4)
vsubpd(ymm5, ymm0, ymm5)
vsubpd(ymm8, ymm0, ymm8)
vsubpd(ymm9, ymm0, ymm9)
vsubpd(ymm12, ymm0, ymm12)
vsubpd(ymm13, ymm0, ymm13)
jmp(.ALPHA_REAL_ONE)
label(.ALPHA_NOT_MINUS1)
//when alpha is real and +/-1, multiplication is skipped.
cmp(imm(2), al) // if(alpha_mul_type != BLIS_MUL_DEFAULT), skip the multiplication below.
jne(.ALPHA_REAL_ONE)
mov(var(alpha), rax) // load address of alpha
vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate
vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate
vpermilpd(imm(0x5), ymm4, ymm3)
vmulpd(ymm0, ymm4, ymm4)
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm4, ymm4)
vpermilpd(imm(0x5), ymm5, ymm3)
vmulpd(ymm0, ymm5, ymm5)
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm5, ymm5)
vpermilpd(imm(0x5), ymm8, ymm3)
vmulpd(ymm0, ymm8, ymm8)
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm8, ymm8)
vpermilpd(imm(0x5), ymm9, ymm3)
vmulpd(ymm0, ymm9, ymm9)
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm9, ymm9)
vpermilpd(imm(0x5), ymm12, ymm3)
vmulpd(ymm0, ymm12, ymm12)
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm12, ymm12)
vpermilpd(imm(0x5), ymm13, ymm3)
vmulpd(ymm0, ymm13, ymm13)
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm13, ymm13)
label(.ALPHA_REAL_ONE)
// Beta multiplication
/* (br + bi) x C + ((ar + ai) x AB) */
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 8), rsi) // rsi = cs_c * 8
lea(mem(, rsi, 2), rsi) // rsi = cs_c * 16 = cs_c * sizeof(dcomplex)
lea(mem(, rsi, 2), rdx) // rdx = 2*cs_c;
// now avoid loading C if beta == 0
mov(var(beta_mul_type), al)
cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO)
je(.ZBETAZERO) //jump to beta == 0 case
cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16.
jz(.ZROWSTORED) // jump to row storage case
label(.ZGENSTORED)
ZGEMM_INPUT_SCALE_GS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
ZGEMM_OUTPUT_GS
add(rdx, rcx) // c += 2*cs_c;
ZGEMM_INPUT_SCALE_GS_BETA_NZ
vaddpd(ymm5, ymm0, ymm0)
ZGEMM_OUTPUT_GS
mov(r11, rcx) // rcx = c + 1*rs_c
ZGEMM_INPUT_SCALE_GS_BETA_NZ
vaddpd(ymm8, ymm0, ymm0)
ZGEMM_OUTPUT_GS
add(rdx, rcx) // c += 2*cs_c;
ZGEMM_INPUT_SCALE_GS_BETA_NZ
vaddpd(ymm9, ymm0, ymm0)
ZGEMM_OUTPUT_GS
mov(r12, rcx) // rcx = c + 2*rs_c
ZGEMM_INPUT_SCALE_GS_BETA_NZ
vaddpd(ymm12, ymm0, ymm0)
ZGEMM_OUTPUT_GS
add(rdx, rcx) // c += 2*cs_c;
ZGEMM_INPUT_SCALE_GS_BETA_NZ
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_GS
jmp(.ZDONE) // jump to end.
/* Row-stored C */
label(.ZROWSTORED)
cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT)
je(.GEN_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication.
cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE)
je(.GEN_BETA_REAL_MINUS1) // jump to beta real = -1 section.
//CASE 1: beta is real = 1
label(.GEN_BETA_REAL_ONE)
vmovupd(mem(rcx), ymm0)
vaddpd(ymm4, ymm0, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
vmovupd(mem(rcx), ymm0)
vaddpd(ymm5, ymm0, ymm0)
ZGEMM_OUTPUT_RS
mov(r11, rcx) // rcx = c + 1*rs_c
vmovupd(mem(rcx), ymm0)
vaddpd(ymm8, ymm0, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
vmovupd(mem(rcx), ymm0)
vaddpd(ymm9, ymm0, ymm0)
ZGEMM_OUTPUT_RS
mov(r12, rcx) // rcx = c + 2*rs_c
vmovupd(mem(rcx), ymm0)
vaddpd(ymm12, ymm0, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
vmovupd(mem(rcx), ymm0)
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_RS
jmp(.ZDONE) // jump to end.
//CASE 2: beta is real = -1
label(.GEN_BETA_REAL_MINUS1)
vmovupd(mem(rcx), ymm0)
vsubpd(ymm0, ymm4, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
vmovupd(mem(rcx), ymm0)
vsubpd(ymm0, ymm5, ymm0)
ZGEMM_OUTPUT_RS
mov(r11, rcx) // rcx = c + 1*rs_c
vmovupd(mem(rcx), ymm0)
vsubpd(ymm0, ymm8, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
vmovupd(mem(rcx), ymm0)
vsubpd(ymm0, ymm9, ymm0)
ZGEMM_OUTPUT_RS
mov(r12, rcx) // rcx = c + 2*rs_c
vmovupd(mem(rcx), ymm0)
vsubpd(ymm0, ymm12, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
vmovupd(mem(rcx), ymm0)
vsubpd(ymm0, ymm13, ymm0)
ZGEMM_OUTPUT_RS
jmp(.ZDONE) // jump to end.
//CASE 3: Default case with multiplication
// beta is neither +/-1 nor zero; do the full complex multiplication.
label(.GEN_BETA_NOT_REAL_ONE)
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm5, ymm0, ymm0)
ZGEMM_OUTPUT_RS
mov(r11, rcx) // rcx = c + 1*rs_c
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm8, ymm0, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm9, ymm0, ymm0)
ZGEMM_OUTPUT_RS
mov(r12, rcx) // rcx = c + 2*rs_c
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm12, ymm0, ymm0)
ZGEMM_OUTPUT_RS
add(rdx, rcx) // c += 2*cs_c;
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_RS
jmp(.ZDONE) // jump to end.
label(.ZBETAZERO)
cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16.
jz(.ZROWSTORBZ) // jump to row storage case
label(.ZGENSTORBZ)
vmovapd(ymm4, ymm0)
ZGEMM_OUTPUT_GS
add(rdx, rcx) // c += 2*cs_c;
vmovapd(ymm5, ymm0)
ZGEMM_OUTPUT_GS
mov(r11, rcx) // rcx = c + 1*rs_c
vmovapd(ymm8, ymm0)
ZGEMM_OUTPUT_GS
add(rdx, rcx) // c += 2*cs_c;
vmovapd(ymm9, ymm0)
ZGEMM_OUTPUT_GS
mov(r12, rcx) // rcx = c + 2*rs_c
vmovapd(ymm12, ymm0)
ZGEMM_OUTPUT_GS
add(rdx, rcx) // c += 2*cs_c;
vmovapd(ymm13, ymm0)
ZGEMM_OUTPUT_GS
jmp(.ZDONE) // jump to end.
label(.ZROWSTORBZ)
vmovupd(ymm4, mem(rcx))
vmovupd(ymm5, mem(rcx, rdx, 1))
vmovupd(ymm8, mem(r11))
vmovupd(ymm9, mem(r11, rdx, 1))
vmovupd(ymm12, mem(r12))
vmovupd(ymm13, mem(r12, rdx, 1))
label(.ZDONE)
end_asm(
: // output operands (none)
: // input operands
[alpha_mul_type] "m" (alpha_mul_type),
[beta_mul_type] "m" (beta_mul_type),
[k_iter] "m" (k_iter), // 0
[k_left] "m" (k_left), // 1
[a] "m" (a), // 2

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -56,6 +56,9 @@
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
#define ZGEMM_INPUT_RS_BETA_ONE \
vmovupd(mem(rcx), ymm0)
#define ZGEMM_OUTPUT_RS \
vmovupd(ymm0, mem(rcx)) \
@@ -66,8 +69,11 @@
vmulpd(ymm2, ymm3, ymm3) \
vaddsubpd(ymm3, ymm0, ymm0)
#define ZGEMM_INPUT_RS_BETA_ONE_NEXT \
vmovupd(mem(rcx, rsi, 8), ymm0)
#define ZGEMM_OUTPUT_RS_NEXT \
vmovupd(ymm0, mem(rcx, rsi, 8))
/*
rrr:
@@ -174,6 +180,41 @@ void bli_zgemmsup_rv_zen_asm_3x4m
if ( m_iter == 0 ) goto consider_edge_cases;
// Handle the case where alpha and beta are real and equal to +/-1.
uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
uint64_t beta_real_one = *((uint64_t*)(&beta->real));
uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
char alpha_mul_type = BLIS_MUL_DEFAULT;
char beta_mul_type = BLIS_MUL_DEFAULT;
if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
{
alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
{
alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
}
}
if(beta->imag == 0)// beta is real
{
if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
{
beta_mul_type = BLIS_MUL_ONE;
if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
{
beta_mul_type = BLIS_MUL_MINUS_ONE;
}
}
else if(beta_real_one == 0)
{
beta_mul_type = BLIS_MUL_ZERO;
}
}
// -------------------------------------------------------------------------
begin_asm()
@@ -442,10 +483,30 @@ void bli_zgemmsup_rv_zen_asm_3x4m
vaddsubpd(ymm14, ymm12, ymm12)
vaddsubpd(ymm15, ymm13, ymm13)
//if(alpha_mul_type == BLIS_MUL_MINUS_ONE)
mov(var(alpha_mul_type), al)
cmp(imm(0xFF), al)
jne(.ALPHA_NOT_MINUS1)
// when alpha = -1 and real.
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
vsubpd(ymm4, ymm0, ymm4)
vsubpd(ymm5, ymm0, ymm5)
vsubpd(ymm8, ymm0, ymm8)
vsubpd(ymm9, ymm0, ymm9)
vsubpd(ymm12, ymm0, ymm12)
vsubpd(ymm13, ymm0, ymm13)
jmp(.ALPHA_REAL_ONE)
label(.ALPHA_NOT_MINUS1)
//when alpha is real and +/-1, multiplication is skipped.
cmp(imm(2), al) // if(alpha_mul_type != BLIS_MUL_DEFAULT), skip the multiplication below.
jne(.ALPHA_REAL_ONE)
/* (ar + ai) x AB */
mov(var(alpha), rax) // load address of alpha
vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate
vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate
vpermilpd(imm(0x5), ymm4, ymm3)
vmulpd(ymm0, ymm4, ymm4)
@@ -477,32 +538,93 @@ void bli_zgemmsup_rv_zen_asm_3x4m
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm13, ymm13)
label(.ALPHA_REAL_ONE)
// Beta multiplication
/* (br + bi) x C + ((ar + ai) x AB) */
mov(var(cs_c), rsi) // load cs_c
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
// now avoid loading C if beta == 0
mov(var(beta_mul_type), al)
cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO)
je(.SBETAZERO) //jump to beta == 0 case
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16.
jz(.SCOLSTORED) // jump to column storage case
label(.SROWSTORED)
cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT)
je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication.
cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE)
je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section.
//CASE 1: beta is real = 1
ZGEMM_INPUT_RS_BETA_ONE
vaddpd(ymm4, ymm0, ymm0)
ZGEMM_OUTPUT_RS
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vaddpd(ymm5, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
add(rdi, rcx) // rcx = c + 1*rs_c
ZGEMM_INPUT_RS_BETA_ONE
vaddpd(ymm8, ymm0, ymm0)
ZGEMM_OUTPUT_RS
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vaddpd(ymm9, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
add(rdi, rcx) // rcx = c + 2*rs_c
ZGEMM_INPUT_RS_BETA_ONE
vaddpd(ymm12, ymm0, ymm0)
ZGEMM_OUTPUT_RS
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE)
//CASE 2: beta is real = -1
label(.ROW_BETA_REAL_MINUS1)
ZGEMM_INPUT_RS_BETA_ONE
vsubpd(ymm0, ymm4, ymm0)
ZGEMM_OUTPUT_RS
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vsubpd(ymm0, ymm5, ymm0)
ZGEMM_OUTPUT_RS_NEXT
add(rdi, rcx) // rcx = c + 1*rs_c
ZGEMM_INPUT_RS_BETA_ONE
vsubpd(ymm0, ymm8, ymm0)
ZGEMM_OUTPUT_RS
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vsubpd(ymm0, ymm9, ymm0)
ZGEMM_OUTPUT_RS_NEXT
add(rdi, rcx) // rcx = c + 2*rs_c
ZGEMM_INPUT_RS_BETA_ONE
vsubpd(ymm0, ymm12, ymm0)
ZGEMM_OUTPUT_RS
ZGEMM_INPUT_RS_BETA_ONE_NEXT
vsubpd(ymm0, ymm13, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE)
//CASE 3: Default case with multiplication
// beta is neither +/-1 nor zero; do the full complex multiplication.
label(.ROW_BETA_NOT_REAL_ONE)
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
ZGEMM_INPUT_SCALE_RS_BETA_NZ
vaddpd(ymm4, ymm0, ymm0)
@@ -529,10 +651,12 @@ void bli_zgemmsup_rv_zen_asm_3x4m
ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT
vaddpd(ymm13, ymm0, ymm0)
ZGEMM_OUTPUT_RS_NEXT
jmp(.SDONE) // jump to end.
label(.SCOLSTORED)
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
/*|--------| |-------|
| | | |
| 3x4 | | 4x3 |
@@ -679,6 +803,8 @@ void bli_zgemmsup_rv_zen_asm_3x4m
end_asm(
: // output operands (none)
: // input operands
[alpha_mul_type] "m" (alpha_mul_type),
[beta_mul_type] "m" (beta_mul_type),
[m_iter] "m" (m_iter),
[k_iter] "m" (k_iter),
[k_left] "m" (k_left),
@@ -1025,7 +1151,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
vmulpd(ymm1, ymm3, ymm3)
vaddsubpd(ymm3, ymm12, ymm12)
/* (br + bi) x C + ((ar + ai) x AB) */
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
@@ -1226,4 +1352,4 @@ void bli_zgemmsup_rv_zen_asm_3x2m
);
return;
}
}
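
The alpha == -1 fast path in these kernels relies on 0 - x being an exact sign flip of both the real and imaginary parts; a minimal intrinsics sketch of the trick behind the vxorpd/vsubpd sequence (the helper name is mine):

#include <immintrin.h>

/* Negate two dcomplex values at once: 0 - x flips the sign of every
   lane exactly, which equals multiplication by (-1 + 0i) with no
   rounding. */
static inline __m256d negate_dcomplex_pair( __m256d x )
{
    return _mm256_sub_pd( _mm256_setzero_pd(), x );
}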

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -123,6 +123,41 @@ void bli_zgemmsup_rv_zen_asm_3x4n
if ( n_iter == 0 ) goto consider_edge_cases;
// Handle the case where alpha and beta are real and equal to +/-1.
uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
uint64_t beta_real_one = *((uint64_t*)(&beta->real));
uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
char alpha_mul_type = BLIS_MUL_DEFAULT;
char beta_mul_type = BLIS_MUL_DEFAULT;
if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
{
alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
{
alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
}
}
if(beta->imag == 0)// beta is real
{
if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
{
beta_mul_type = BLIS_MUL_ONE;
if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
{
beta_mul_type = BLIS_MUL_MINUS_ONE;
}
}
else if(beta_real_one == 0)
{
beta_mul_type = BLIS_MUL_ZERO;
}
}
// -------------------------------------------------------------------------
//scratch registers
__m256d ymm0, ymm1, ymm2, ymm3;
@@ -217,45 +252,60 @@ void bli_zgemmsup_rv_zen_asm_3x4n
ymm12 = _mm256_addsub_pd( ymm12, ymm14);
ymm13 = _mm256_addsub_pd( ymm13, ymm15);
// When alpha_real == -1.0, the sign is flipped instead of multiplying by -1.
if(alpha_mul_type == BLIS_MUL_MINUS_ONE) // equivalent to if(alpha->real == -1.0)
{
ymm0 = _mm256_setzero_pd();
ymm4 = _mm256_sub_pd(ymm0, ymm4);
ymm5 = _mm256_sub_pd(ymm0, ymm5);
ymm8 = _mm256_sub_pd(ymm0, ymm8);
ymm9 = _mm256_sub_pd(ymm0, ymm9);
ymm12 = _mm256_sub_pd(ymm0, ymm12);
ymm13 = _mm256_sub_pd(ymm0, ymm13);
}
// When alpha is real and +/-1, the multiplication is skipped.
if(alpha_mul_type == BLIS_MUL_DEFAULT)
{
// alpha, beta multiplication.
/* (ar + ai) x AB */
ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate
ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate
ymm3 = _mm256_permute_pd(ymm4, 5);
ymm4 = _mm256_mul_pd(ymm0, ymm4);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm4 = _mm256_addsub_pd(ymm4, ymm3);
ymm3 = _mm256_permute_pd(ymm5, 5);
ymm5 = _mm256_mul_pd(ymm0, ymm5);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm5 = _mm256_addsub_pd(ymm5, ymm3);
ymm3 = _mm256_permute_pd(ymm8, 5);
ymm8 = _mm256_mul_pd(ymm0, ymm8);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm8 = _mm256_addsub_pd(ymm8, ymm3);
ymm3 = _mm256_permute_pd(ymm9, 5);
ymm9 = _mm256_mul_pd(ymm0, ymm9);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm9 = _mm256_addsub_pd(ymm9, ymm3);
ymm3 = _mm256_permute_pd(ymm12, 5);
ymm12 = _mm256_mul_pd(ymm0, ymm12);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm12 = _mm256_addsub_pd(ymm12, ymm3);
ymm3 = _mm256_permute_pd(ymm13, 5);
ymm13 = _mm256_mul_pd(ymm0, ymm13);
ymm3 = _mm256_mul_pd(ymm1, ymm3);
ymm13 = _mm256_addsub_pd(ymm13, ymm3);
}
if(tc_inc_row == 1) //col stored
{
if(beta_mul_type == BLIS_MUL_ZERO)
{
//transpose left 3x2
_mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4));
@@ -364,7 +414,7 @@ void bli_zgemmsup_rv_zen_asm_3x4n
}
else
{
if(beta_mul_type == BLIS_MUL_ZERO)
{
_mm256_storeu_pd((double *)(tC), ymm4);
_mm256_storeu_pd((double *)(tC + 2), ymm5);
@@ -373,6 +423,28 @@ void bli_zgemmsup_rv_zen_asm_3x4n
_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
_mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13);
}
else if(beta_mul_type == BLIS_MUL_ONE) // equivalent to if(beta->real == 1.0)
{
ymm2 = _mm256_loadu_pd((double const *)(tC));
ymm4 = _mm256_add_pd(ymm4,ymm2);
ymm2 = _mm256_loadu_pd((double const *)(tC+2));
ymm5 = _mm256_add_pd(ymm5,ymm2);
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row));
ymm8 = _mm256_add_pd(ymm8,ymm2);
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2));
ymm9 = _mm256_add_pd(ymm9,ymm2);
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2));
ymm12 = _mm256_add_pd(ymm12,ymm2);
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2));
ymm13 = _mm256_add_pd(ymm13,ymm2);
_mm256_storeu_pd((double *)(tC), ymm4);
_mm256_storeu_pd((double *)(tC + 2), ymm5);
_mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8);
_mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9);
_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
_mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13);
}
else{
/* (br + bi) C + (ar + ai) AB */
ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate