mirror of
https://github.com/amd/blis.git
synced 2026-05-11 17:50:00 +00:00
Handling zgemm real(+/-1) alpha and beta
1.Improved performance when zgemm's alpha and beta are real and equal to +/-1. 2.change done in bli_zgemmsup_rv_zen_asm_3x4n. 3.change done in bli_zgemmsup_rv_zen_asm_3x4m. 4.change done in bli_zgemm_haswell_asm_3x4. Change-Id: Ic14d8507b264c24a8748febf6bc73eb60e476430 AMD-Internal: [CPUPL-1352]
This commit is contained in:
committed by
Madan Mohan Manokar
parent
1ff4981203
commit
3ab9104dae
@@ -5,6 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -35,3 +36,12 @@
|
||||
void bli_const_init( void );
|
||||
void bli_const_finalize( void );
|
||||
|
||||
/* constant used to check 1 and -1 when double converted uint64 */
|
||||
#define BLIS_DOUBLE_TO_UINT64_ONE_ABS 0x3ff0000000000000
|
||||
#define BLIS_DOUBLE_TO_UINT64_MINUS_ONE 0xbff0000000000000
|
||||
|
||||
/* enum used to clasify alpha and beta to one of the below category */
|
||||
enum mulfactor { BLIS_MUL_MINUS_ONE = -1,
|
||||
BLIS_MUL_ZERO,
|
||||
BLIS_MUL_ONE,
|
||||
BLIS_MUL_DEFAULT };
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -2197,10 +2197,11 @@ void bli_cgemm_haswell_asm_3x8
|
||||
vmulpd(ymm1, ymm0, ymm0) \
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
|
||||
#define ZGEMM_OUTPUT_RS \
|
||||
vmovupd(ymm0, mem(rcx)) \
|
||||
|
||||
|
||||
void bli_zgemm_haswell_asm_3x4
|
||||
(
|
||||
dim_t k0,
|
||||
@@ -2223,71 +2224,105 @@ void bli_zgemm_haswell_asm_3x4
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
//handling case when alpha and beta are real and +/-1.
|
||||
uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
|
||||
uint64_t beta_real_one = *((uint64_t*)(&beta->real));
|
||||
|
||||
uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
|
||||
uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
|
||||
|
||||
char alpha_mul_type = BLIS_MUL_DEFAULT;
|
||||
char beta_mul_type = BLIS_MUL_DEFAULT;
|
||||
|
||||
if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
|
||||
{
|
||||
alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
|
||||
if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
|
||||
{
|
||||
alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
|
||||
}
|
||||
}
|
||||
|
||||
if(beta->imag == 0)// beta is real
|
||||
{
|
||||
if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_ONE;
|
||||
if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_MINUS_ONE;
|
||||
}
|
||||
}
|
||||
else if(beta_real_one == 0)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
begin_asm()
|
||||
|
||||
|
||||
vzeroall() // zero all xmm/ymm registers.
|
||||
|
||||
|
||||
|
||||
mov(var(a), rax) // load address of a.
|
||||
mov(var(b), rbx) // load address of b.
|
||||
//mov(%9, r15) // load address of b_next.
|
||||
|
||||
|
||||
add(imm(32*4), rbx)
|
||||
// initialize loop by pre-loading
|
||||
vmovapd(mem(rbx, -4*32), ymm0)
|
||||
vmovapd(mem(rbx, -3*32), ymm1)
|
||||
|
||||
|
||||
mov(var(c), rcx) // load address of c
|
||||
mov(var(rs_c), rdi) // load rs_c
|
||||
lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dcomplex)
|
||||
lea(mem(, rdi, 2), rdi)
|
||||
|
||||
|
||||
lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*rs_c;
|
||||
lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*rs_c;
|
||||
|
||||
|
||||
prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c
|
||||
prefetch(0, mem(r11, 7*8)) // prefetch c + 1*rs_c
|
||||
prefetch(0, mem(r12, 7*8)) // prefetch c + 2*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mov(var(k_iter), rsi) // i = k_iter;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.ZCONSIDKLEFT) // if i == 0, jump to code that
|
||||
// contains the k_left loop.
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZLOOPKITER) // MAIN LOOP
|
||||
|
||||
|
||||
|
||||
|
||||
// iteration 0
|
||||
prefetch(0, mem(rax, 32*16))
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 0*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 1*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm4)
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 2*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 3*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 4*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 5*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, -2*32), ymm0)
|
||||
vmovapd(mem(rbx, -1*32), ymm1)
|
||||
|
||||
|
||||
// iteration 1
|
||||
vbroadcastsd(mem(rax, 6*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 7*8), ymm3)
|
||||
@@ -2295,51 +2330,51 @@ void bli_zgemm_haswell_asm_3x4
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 8*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 9*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 10*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 11*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, 0*32), ymm0)
|
||||
vmovapd(mem(rbx, 1*32), ymm1)
|
||||
|
||||
|
||||
// iteration 2
|
||||
prefetch(0, mem(rax, 38*16))
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 12*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 13*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm4)
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 14*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 15*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 16*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 17*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, 2*32), ymm0)
|
||||
vmovapd(mem(rbx, 3*32), ymm1)
|
||||
|
||||
|
||||
// iteration 3
|
||||
vbroadcastsd(mem(rax, 18*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 19*8), ymm3)
|
||||
@@ -2347,83 +2382,83 @@ void bli_zgemm_haswell_asm_3x4
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 20*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 21*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 22*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 23*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
add(imm(4*3*16), rax) // a += 4*3 (unroll x mr)
|
||||
add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, -4*32), ymm0)
|
||||
vmovapd(mem(rbx, -3*32), ymm1)
|
||||
|
||||
|
||||
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.ZLOOPKITER) // iterate again if i != 0.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZCONSIDKLEFT)
|
||||
|
||||
|
||||
mov(var(k_left), rsi) // i = k_left;
|
||||
test(rsi, rsi) // check i via logical AND.
|
||||
je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
|
||||
// else, we prepare to enter k_left loop.
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZLOOPKLEFT) // EDGE LOOP
|
||||
|
||||
|
||||
prefetch(0, mem(rax, 32*16))
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 0*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 1*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm4)
|
||||
vfmadd231pd(ymm1, ymm2, ymm5)
|
||||
vfmadd231pd(ymm0, ymm3, ymm6)
|
||||
vfmadd231pd(ymm1, ymm3, ymm7)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 2*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 3*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm8)
|
||||
vfmadd231pd(ymm1, ymm2, ymm9)
|
||||
vfmadd231pd(ymm0, ymm3, ymm10)
|
||||
vfmadd231pd(ymm1, ymm3, ymm11)
|
||||
|
||||
|
||||
vbroadcastsd(mem(rax, 4*8), ymm2)
|
||||
vbroadcastsd(mem(rax, 5*8), ymm3)
|
||||
vfmadd231pd(ymm0, ymm2, ymm12)
|
||||
vfmadd231pd(ymm1, ymm2, ymm13)
|
||||
vfmadd231pd(ymm0, ymm3, ymm14)
|
||||
vfmadd231pd(ymm1, ymm3, ymm15)
|
||||
|
||||
|
||||
add(imm(1*3*16), rax) // a += 1*3 (unroll x mr)
|
||||
add(imm(1*4*16), rbx) // b += 1*4 (unroll x nr)
|
||||
|
||||
|
||||
vmovapd(mem(rbx, -4*32), ymm0)
|
||||
vmovapd(mem(rbx, -3*32), ymm1)
|
||||
|
||||
|
||||
|
||||
|
||||
dec(rsi) // i -= 1;
|
||||
jne(.ZLOOPKLEFT) // iterate again if i != 0.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZPOSTACCUM)
|
||||
|
||||
|
||||
// permute even and odd elements
|
||||
// of ymm6/7, ymm10/11, ymm/14/15
|
||||
vpermilpd(imm(0x5), ymm6, ymm6)
|
||||
@@ -2432,251 +2467,306 @@ void bli_zgemm_haswell_asm_3x4
|
||||
vpermilpd(imm(0x5), ymm11, ymm11)
|
||||
vpermilpd(imm(0x5), ymm14, ymm14)
|
||||
vpermilpd(imm(0x5), ymm15, ymm15)
|
||||
|
||||
|
||||
|
||||
|
||||
// subtract/add even/odd elements
|
||||
vaddsubpd(ymm6, ymm4, ymm4)
|
||||
vaddsubpd(ymm7, ymm5, ymm5)
|
||||
|
||||
|
||||
vaddsubpd(ymm10, ymm8, ymm8)
|
||||
vaddsubpd(ymm11, ymm9, ymm9)
|
||||
|
||||
|
||||
vaddsubpd(ymm14, ymm12, ymm12)
|
||||
vaddsubpd(ymm15, ymm13, ymm13)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
//if(alpha_mul_type == BLIS_MUL_MINUS_ONE)
|
||||
mov(var(alpha_mul_type), al)
|
||||
cmp(imm(0xFF), al)
|
||||
jne(.ALPHA_NOT_MINUS1)
|
||||
|
||||
// when alpha = -1 and real.
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vsubpd(ymm4, ymm0, ymm4)
|
||||
vsubpd(ymm5, ymm0, ymm5)
|
||||
vsubpd(ymm8, ymm0, ymm8)
|
||||
vsubpd(ymm9, ymm0, ymm9)
|
||||
vsubpd(ymm12, ymm0, ymm12)
|
||||
vsubpd(ymm13, ymm0, ymm13)
|
||||
jmp(.ALPHA_REAL_ONE)
|
||||
|
||||
label(.ALPHA_NOT_MINUS1)
|
||||
//when alpha is real and +/-1, multiplication is skipped.
|
||||
cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication.
|
||||
jne(.ALPHA_REAL_ONE)
|
||||
|
||||
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate
|
||||
vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate
|
||||
|
||||
|
||||
|
||||
|
||||
vpermilpd(imm(0x5), ymm4, ymm3)
|
||||
vmulpd(ymm0, ymm4, ymm4)
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm4, ymm4)
|
||||
|
||||
|
||||
vpermilpd(imm(0x5), ymm5, ymm3)
|
||||
vmulpd(ymm0, ymm5, ymm5)
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm5, ymm5)
|
||||
|
||||
|
||||
|
||||
|
||||
vpermilpd(imm(0x5), ymm8, ymm3)
|
||||
vmulpd(ymm0, ymm8, ymm8)
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm8, ymm8)
|
||||
|
||||
|
||||
vpermilpd(imm(0x5), ymm9, ymm3)
|
||||
vmulpd(ymm0, ymm9, ymm9)
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm9, ymm9)
|
||||
|
||||
|
||||
|
||||
|
||||
vpermilpd(imm(0x5), ymm12, ymm3)
|
||||
vmulpd(ymm0, ymm12, ymm12)
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm12, ymm12)
|
||||
|
||||
|
||||
vpermilpd(imm(0x5), ymm13, ymm3)
|
||||
vmulpd(ymm0, ymm13, ymm13)
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm13, ymm13)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
|
||||
|
||||
|
||||
label(.ALPHA_REAL_ONE)
|
||||
// Beta multiplication
|
||||
/* (br + bi)x C + ((ar + ai) x AB) */
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dcomplex)
|
||||
lea(mem(, rsi, 2), rsi)
|
||||
lea(mem(, rsi, 2), rdx) // rdx = 2*cs_c;
|
||||
|
||||
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
|
||||
sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 );
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r8b, r9b) // set ZF if r8b & r9b == 1.
|
||||
jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
|
||||
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
mov(var(beta_mul_type), al)
|
||||
cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO)
|
||||
je(.ZBETAZERO) //jump to beta == 0 case
|
||||
|
||||
|
||||
cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16.
|
||||
jz(.ZROWSTORED) // jump to row storage case
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZGENSTORED)
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_GS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_GS_BETA_NZ
|
||||
vaddpd(ymm5, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
mov(r11, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_GS_BETA_NZ
|
||||
vaddpd(ymm8, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_GS_BETA_NZ
|
||||
vaddpd(ymm9, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
mov(r12, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_GS_BETA_NZ
|
||||
vaddpd(ymm12, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_GS_BETA_NZ
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
|
||||
|
||||
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/* Row stored of C */
|
||||
label(.ZROWSTORED)
|
||||
|
||||
|
||||
cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT)
|
||||
je(.GEN_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication.
|
||||
|
||||
cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE)
|
||||
je(.GEN_BETA_REAL_MINUS1) // jump to beta real = -1 section.
|
||||
|
||||
//CASE 1: beta is real = 1
|
||||
label(.GEN_BETA_REAL_ONE)
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vaddpd(ymm5, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
mov(r11, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vaddpd(ymm8, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vaddpd(ymm9, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
mov(r12, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vaddpd(ymm12, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
//CASE 2: beta is real = -1
|
||||
label(.GEN_BETA_REAL_MINUS1)
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vsubpd(ymm0, ymm4, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vsubpd(ymm0, ymm5, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
mov(r11, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vsubpd(ymm0, ymm8, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vsubpd(ymm0, ymm9, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
mov(r12, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vsubpd(ymm0, ymm12, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
vsubpd(ymm0, ymm13, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
//CASE 3: Default case with multiplication
|
||||
// beta not equal to (+/-1) or zero, do normal multiplication.
|
||||
label(.GEN_BETA_NOT_REAL_ONE)
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm5, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
mov(r11, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm8, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm9, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
mov(r12, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm12, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
|
||||
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZBETAZERO)
|
||||
|
||||
cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16.
|
||||
jz(.ZROWSTORBZ) // jump to row storage case
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZGENSTORBZ)
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm4, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm5, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
mov(r11, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm8, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm9, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
mov(r12, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm12, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
add(rdx, rcx) // c += 2*cs_c;
|
||||
|
||||
|
||||
|
||||
|
||||
vmovapd(ymm13, ymm0)
|
||||
ZGEMM_OUTPUT_GS
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
jmp(.ZDONE) // jump to end.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZROWSTORBZ)
|
||||
|
||||
|
||||
|
||||
vmovupd(ymm4, mem(rcx))
|
||||
vmovupd(ymm5, mem(rcx, rdx, 1))
|
||||
|
||||
|
||||
vmovupd(ymm8, mem(r11))
|
||||
vmovupd(ymm9, mem(r11, rdx, 1))
|
||||
|
||||
|
||||
vmovupd(ymm12, mem(r12))
|
||||
vmovupd(ymm13, mem(r12, rdx, 1))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
label(.ZDONE)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
: // input operands
|
||||
[alpha_mul_type] "m" (alpha_mul_type),
|
||||
[beta_mul_type] "m" (beta_mul_type),
|
||||
[k_iter] "m" (k_iter), // 0
|
||||
[k_left] "m" (k_left), // 1
|
||||
[a] "m" (a), // 2
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -56,6 +56,9 @@
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
#define ZGEMM_INPUT_RS_BETA_ONE \
|
||||
vmovupd(mem(rcx), ymm0)
|
||||
|
||||
#define ZGEMM_OUTPUT_RS \
|
||||
vmovupd(ymm0, mem(rcx)) \
|
||||
|
||||
@@ -66,8 +69,11 @@
|
||||
vmulpd(ymm2, ymm3, ymm3) \
|
||||
vaddsubpd(ymm3, ymm0, ymm0)
|
||||
|
||||
#define ZGEMM_INPUT_RS_BETA_ONE_NEXT \
|
||||
vmovupd(mem(rcx, rsi, 8), ymm0)
|
||||
|
||||
#define ZGEMM_OUTPUT_RS_NEXT \
|
||||
vmovupd(ymm0, mem(rcx, rsi, 8)) \
|
||||
vmovupd(ymm0, mem(rcx, rsi, 8))
|
||||
|
||||
/*
|
||||
rrr:
|
||||
@@ -174,6 +180,41 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
|
||||
if ( m_iter == 0 ) goto consider_edge_cases;
|
||||
|
||||
//handling case when alpha and beta are real and +/-1.
|
||||
uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
|
||||
uint64_t beta_real_one = *((uint64_t*)(&beta->real));
|
||||
|
||||
uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
|
||||
uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
|
||||
|
||||
char alpha_mul_type = BLIS_MUL_DEFAULT;
|
||||
char beta_mul_type = BLIS_MUL_DEFAULT;
|
||||
|
||||
if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
|
||||
{
|
||||
alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
|
||||
if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
|
||||
{
|
||||
alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
|
||||
}
|
||||
}
|
||||
|
||||
if(beta->imag == 0)// beta is real
|
||||
{
|
||||
if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_ONE;
|
||||
if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_MINUS_ONE;
|
||||
}
|
||||
}
|
||||
else if(beta_real_one == 0)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
begin_asm()
|
||||
@@ -442,10 +483,30 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
vaddsubpd(ymm14, ymm12, ymm12)
|
||||
vaddsubpd(ymm15, ymm13, ymm13)
|
||||
|
||||
//if(alpha_mul_type == BLIS_MUL_MINUS_ONE)
|
||||
mov(var(alpha_mul_type), al)
|
||||
cmp(imm(0xFF), al)
|
||||
jne(.ALPHA_NOT_MINUS1)
|
||||
|
||||
// when alpha = -1 and real.
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vsubpd(ymm4, ymm0, ymm4)
|
||||
vsubpd(ymm5, ymm0, ymm5)
|
||||
vsubpd(ymm8, ymm0, ymm8)
|
||||
vsubpd(ymm9, ymm0, ymm9)
|
||||
vsubpd(ymm12, ymm0, ymm12)
|
||||
vsubpd(ymm13, ymm0, ymm13)
|
||||
jmp(.ALPHA_REAL_ONE)
|
||||
|
||||
label(.ALPHA_NOT_MINUS1)
|
||||
//when alpha is real and +/-1, multiplication is skipped.
|
||||
cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication.
|
||||
jne(.ALPHA_REAL_ONE)
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate
|
||||
vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate
|
||||
mov(var(alpha), rax) // load address of alpha
|
||||
vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate
|
||||
vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate
|
||||
|
||||
vpermilpd(imm(0x5), ymm4, ymm3)
|
||||
vmulpd(ymm0, ymm4, ymm4)
|
||||
@@ -477,32 +538,93 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm13, ymm13)
|
||||
|
||||
/* (ßr + ßi)x C + ((ar + ai) x AB) */
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
label(.ALPHA_REAL_ONE)
|
||||
// Beta multiplication
|
||||
/* (br + bi)x C + ((ar + ai) x AB) */
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
mov(var(cs_c), rsi) // load cs_c
|
||||
lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt)
|
||||
|
||||
lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c;
|
||||
lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c;
|
||||
|
||||
// now avoid loading C if beta == 0
|
||||
vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
|
||||
vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
|
||||
sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 );
|
||||
vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
|
||||
sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 );
|
||||
and(r13b, r15b) // set ZF if r13b & r15b == 1.
|
||||
jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case
|
||||
mov(var(beta_mul_type), al)
|
||||
cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO)
|
||||
je(.SBETAZERO) //jump to beta == 0 case
|
||||
|
||||
lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a
|
||||
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
|
||||
cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16.
|
||||
jz(.SCOLSTORED) // jump to column storage case
|
||||
|
||||
label(.SROWSTORED)
|
||||
cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT)
|
||||
je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication.
|
||||
|
||||
cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE)
|
||||
je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section.
|
||||
|
||||
//CASE 1: beta is real = 1
|
||||
ZGEMM_INPUT_RS_BETA_ONE
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vaddpd(ymm5, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
add(rdi, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE
|
||||
vaddpd(ymm8, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vaddpd(ymm9, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
add(rdi, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE
|
||||
vaddpd(ymm12, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
jmp(.SDONE)
|
||||
|
||||
|
||||
//CASE 2: beta is real = -1
|
||||
label(.ROW_BETA_REAL_MINUS1)
|
||||
ZGEMM_INPUT_RS_BETA_ONE
|
||||
vsubpd(ymm0, ymm4, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vsubpd(ymm0, ymm5, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
add(rdi, rcx) // rcx = c + 1*rs_c
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE
|
||||
vsubpd(ymm0, ymm8, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vsubpd(ymm0, ymm9, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
add(rdi, rcx) // rcx = c + 2*rs_c
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE
|
||||
vsubpd(ymm0, ymm12, ymm0)
|
||||
ZGEMM_OUTPUT_RS
|
||||
|
||||
ZGEMM_INPUT_RS_BETA_ONE_NEXT
|
||||
vsubpd(ymm0, ymm13, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
jmp(.SDONE)
|
||||
|
||||
|
||||
//CASE 3: Default case with multiplication
|
||||
// beta not equal to (+/-1) or zero, do normal multiplication.
|
||||
label(.ROW_BETA_NOT_REAL_ONE)
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ
|
||||
vaddpd(ymm4, ymm0, ymm0)
|
||||
@@ -529,10 +651,12 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT
|
||||
vaddpd(ymm13, ymm0, ymm0)
|
||||
ZGEMM_OUTPUT_RS_NEXT
|
||||
|
||||
jmp(.SDONE) // jump to end.
|
||||
|
||||
label(.SCOLSTORED)
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
/*|--------| |-------|
|
||||
| | | |
|
||||
| 3x4 | | 4x3 |
|
||||
@@ -679,6 +803,8 @@ void bli_zgemmsup_rv_zen_asm_3x4m
|
||||
end_asm(
|
||||
: // output operands (none)
|
||||
: // input operands
|
||||
[alpha_mul_type] "m" (alpha_mul_type),
|
||||
[beta_mul_type] "m" (beta_mul_type),
|
||||
[m_iter] "m" (m_iter),
|
||||
[k_iter] "m" (k_iter),
|
||||
[k_left] "m" (k_left),
|
||||
@@ -1025,7 +1151,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
vmulpd(ymm1, ymm3, ymm3)
|
||||
vaddsubpd(ymm3, ymm12, ymm12)
|
||||
|
||||
/* (ßr + ßi)x C + ((ar + ai) x AB) */
|
||||
/* (br + bi)x C + ((ar + ai) x AB) */
|
||||
mov(var(beta), rbx) // load address of beta
|
||||
vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
|
||||
vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate
|
||||
@@ -1226,4 +1352,4 @@ void bli_zgemmsup_rv_zen_asm_3x2m
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -123,6 +123,41 @@ void bli_zgemmsup_rv_zen_asm_3x4n
|
||||
|
||||
if ( n_iter == 0 ) goto consider_edge_cases;
|
||||
|
||||
//handling case when alpha and beta are real and +/-1.
|
||||
uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
|
||||
uint64_t beta_real_one = *((uint64_t*)(&beta->real));
|
||||
|
||||
uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
|
||||
uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
|
||||
|
||||
char alpha_mul_type = BLIS_MUL_DEFAULT;
|
||||
char beta_mul_type = BLIS_MUL_DEFAULT;
|
||||
|
||||
if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
|
||||
{
|
||||
alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
|
||||
if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
|
||||
{
|
||||
alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
|
||||
}
|
||||
}
|
||||
|
||||
if(beta->imag == 0)// beta is real
|
||||
{
|
||||
if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_ONE;
|
||||
if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_MINUS_ONE;
|
||||
}
|
||||
}
|
||||
else if(beta_real_one == 0)
|
||||
{
|
||||
beta_mul_type = BLIS_MUL_ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
//scratch registers
|
||||
__m256d ymm0, ymm1, ymm2, ymm3;
|
||||
@@ -217,45 +252,60 @@ void bli_zgemmsup_rv_zen_asm_3x4n
|
||||
ymm12 = _mm256_addsub_pd( ymm12, ymm14);
|
||||
ymm13 = _mm256_addsub_pd( ymm13, ymm15);
|
||||
|
||||
// alpha, beta multiplication.
|
||||
//When alpha_real = -1.0, instead of multiplying with -1, sign is changed.
|
||||
if(alpha_mul_type == BLIS_MUL_MINUS_ONE)// equivalent to if(alpha->real == -1.0)
|
||||
{
|
||||
ymm0 = _mm256_setzero_pd();
|
||||
ymm4 = _mm256_sub_pd(ymm0,ymm4);
|
||||
ymm5 = _mm256_sub_pd(ymm0, ymm5);
|
||||
ymm8 = _mm256_sub_pd(ymm0, ymm8);
|
||||
ymm9 = _mm256_sub_pd(ymm0, ymm9);
|
||||
ymm12 = _mm256_sub_pd(ymm0, ymm12);
|
||||
ymm13 = _mm256_sub_pd(ymm0, ymm13);
|
||||
}
|
||||
|
||||
/* (ar + ai) x AB */
|
||||
ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate
|
||||
ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate
|
||||
//when alpha is real and +/-1, multiplication is skipped.
|
||||
if(alpha_mul_type == BLIS_MUL_DEFAULT)
|
||||
{
|
||||
// alpha, beta multiplication.
|
||||
/* (ar + ai) x AB */
|
||||
ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate
|
||||
ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate
|
||||
|
||||
ymm3 = _mm256_permute_pd(ymm4, 5);
|
||||
ymm4 = _mm256_mul_pd(ymm0, ymm4);
|
||||
ymm3 =_mm256_mul_pd(ymm1, ymm3);
|
||||
ymm4 = _mm256_addsub_pd(ymm4, ymm3);
|
||||
ymm3 = _mm256_permute_pd(ymm4, 5);
|
||||
ymm4 = _mm256_mul_pd(ymm0, ymm4);
|
||||
ymm3 =_mm256_mul_pd(ymm1, ymm3);
|
||||
ymm4 = _mm256_addsub_pd(ymm4, ymm3);
|
||||
|
||||
ymm3 = _mm256_permute_pd(ymm5, 5);
|
||||
ymm5 = _mm256_mul_pd(ymm0, ymm5);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm5 = _mm256_addsub_pd(ymm5, ymm3);
|
||||
ymm3 = _mm256_permute_pd(ymm5, 5);
|
||||
ymm5 = _mm256_mul_pd(ymm0, ymm5);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm5 = _mm256_addsub_pd(ymm5, ymm3);
|
||||
|
||||
ymm3 = _mm256_permute_pd(ymm8, 5);
|
||||
ymm8 = _mm256_mul_pd(ymm0, ymm8);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm8 = _mm256_addsub_pd(ymm8, ymm3);
|
||||
ymm3 = _mm256_permute_pd(ymm8, 5);
|
||||
ymm8 = _mm256_mul_pd(ymm0, ymm8);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm8 = _mm256_addsub_pd(ymm8, ymm3);
|
||||
|
||||
ymm3 = _mm256_permute_pd(ymm9, 5);
|
||||
ymm9 = _mm256_mul_pd(ymm0, ymm9);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm9 = _mm256_addsub_pd(ymm9, ymm3);
|
||||
ymm3 = _mm256_permute_pd(ymm9, 5);
|
||||
ymm9 = _mm256_mul_pd(ymm0, ymm9);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm9 = _mm256_addsub_pd(ymm9, ymm3);
|
||||
|
||||
ymm3 = _mm256_permute_pd(ymm12, 5);
|
||||
ymm12 = _mm256_mul_pd(ymm0, ymm12);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm12 = _mm256_addsub_pd(ymm12, ymm3);
|
||||
ymm3 = _mm256_permute_pd(ymm12, 5);
|
||||
ymm12 = _mm256_mul_pd(ymm0, ymm12);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm12 = _mm256_addsub_pd(ymm12, ymm3);
|
||||
|
||||
ymm3 = _mm256_permute_pd(ymm13, 5);
|
||||
ymm13 = _mm256_mul_pd(ymm0, ymm13);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm13 = _mm256_addsub_pd(ymm13, ymm3);
|
||||
ymm3 = _mm256_permute_pd(ymm13, 5);
|
||||
ymm13 = _mm256_mul_pd(ymm0, ymm13);
|
||||
ymm3 = _mm256_mul_pd(ymm1, ymm3);
|
||||
ymm13 = _mm256_addsub_pd(ymm13, ymm3);
|
||||
}
|
||||
|
||||
if(tc_inc_row == 1) //col stored
|
||||
{
|
||||
if(beta->real == 0.0 && beta->imag == 0.0)
|
||||
if(beta_mul_type == BLIS_MUL_ZERO)
|
||||
{
|
||||
//transpose left 3x2
|
||||
_mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4));
|
||||
@@ -364,7 +414,7 @@ void bli_zgemmsup_rv_zen_asm_3x4n
|
||||
}
|
||||
else
|
||||
{
|
||||
if(beta->real == 0.0 && beta->imag == 0.0)
|
||||
if(beta_mul_type == BLIS_MUL_ZERO)
|
||||
{
|
||||
_mm256_storeu_pd((double *)(tC), ymm4);
|
||||
_mm256_storeu_pd((double *)(tC + 2), ymm5);
|
||||
@@ -373,6 +423,28 @@ void bli_zgemmsup_rv_zen_asm_3x4n
|
||||
_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
|
||||
_mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13);
|
||||
}
|
||||
else if(beta_mul_type == BLIS_MUL_ONE)// equivalent to if(beta->real == 1.0)
|
||||
{
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC));
|
||||
ymm4 = _mm256_add_pd(ymm4,ymm2);
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+2));
|
||||
ymm5 = _mm256_add_pd(ymm5,ymm2);
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row));
|
||||
ymm8 = _mm256_add_pd(ymm8,ymm2);
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2));
|
||||
ymm9 = _mm256_add_pd(ymm9,ymm2);
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2));
|
||||
ymm12 = _mm256_add_pd(ymm12,ymm2);
|
||||
ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2));
|
||||
ymm13 = _mm256_add_pd(ymm13,ymm2);
|
||||
|
||||
_mm256_storeu_pd((double *)(tC), ymm4);
|
||||
_mm256_storeu_pd((double *)(tC + 2), ymm5);
|
||||
_mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8);
|
||||
_mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9);
|
||||
_mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12);
|
||||
_mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13);
|
||||
}
|
||||
else{
|
||||
/* (br + bi) C + (ar + ai) AB */
|
||||
ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate
|
||||
|
||||
Reference in New Issue
Block a user