From 3ab9104daebc5c7d7ec2e855c80b718a06825b4c Mon Sep 17 00:00:00 2001
From: Madan mohan Manokar
Date: Fri, 11 Dec 2020 11:57:59 +0530
Subject: [PATCH] Handling zgemm real (+/-1) alpha and beta

1. Improved performance when zgemm's alpha and beta are real and equal
   to +/-1.
2. Change made in bli_zgemmsup_rv_zen_asm_3x4n.
3. Change made in bli_zgemmsup_rv_zen_asm_3x4m.
4. Change made in bli_zgemm_haswell_asm_3x4.

Change-Id: Ic14d8507b264c24a8748febf6bc73eb60e476430
AMD-Internal: [CPUPL-1352]
---
 frame/base/bli_const.h                        |  10 +
 kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 446 +++++++++++-------
 .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c  | 180 +++++--
 .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c  | 134 ++++--
 4 files changed, 534 insertions(+), 236 deletions(-)

diff --git a/frame/base/bli_const.h b/frame/base/bli_const.h
index 1b9799482..781b56cb8 100644
--- a/frame/base/bli_const.h
+++ b/frame/base/bli_const.h
@@ -5,6 +5,7 @@
    libraries.
 
    Copyright (C) 2014, The University of Texas at Austin
+   Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
 
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
@@ -35,3 +36,12 @@
 void bli_const_init( void );
 void bli_const_finalize( void );
 
+/* Constants used to check for 1.0 and -1.0 when a double is reinterpreted as a uint64. */
+#define BLIS_DOUBLE_TO_UINT64_ONE_ABS   0x3ff0000000000000
+#define BLIS_DOUBLE_TO_UINT64_MINUS_ONE 0xbff0000000000000
+
+/* Enum used to classify alpha and beta into one of the categories below. */
+enum mulfactor { BLIS_MUL_MINUS_ONE = -1,
+                 BLIS_MUL_ZERO,
+                 BLIS_MUL_ONE,
+                 BLIS_MUL_DEFAULT };
\ No newline at end of file
diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c
index 4a21b2f76..315894b17 100644
--- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c
+++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c
@@ -5,7 +5,7 @@
    libraries.
 
    Copyright (C) 2014, The University of Texas at Austin
-   Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc.
+   Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.
 
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
@@ -2197,10 +2197,11 @@ void bli_cgemm_haswell_asm_3x8
 	vmulpd(ymm1, ymm0, ymm0) \
 	vmulpd(ymm2, ymm3, ymm3) \
 	vaddsubpd(ymm3, ymm0, ymm0)
-
+
 #define ZGEMM_OUTPUT_RS \
 	vmovupd(ymm0, mem(rcx)) \
+
 void bli_zgemm_haswell_asm_3x4
      (
       dim_t            k0,
@@ -2223,71 +2224,105 @@ void bli_zgemm_haswell_asm_3x4
 	uint64_t rs_c = rs_c0;
 	uint64_t cs_c = cs_c0;
 
+	//handling case when alpha and beta are real and +/-1.
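+	// The doubles are viewed as uint64_t bit patterns: (x << 1) >> 1 clears
+	// the IEEE-754 sign bit, so a single integer compare against
+	// BLIS_DOUBLE_TO_UINT64_ONE_ABS (0x3ff0000000000000, the bit pattern of
+	// 1.0) detects both +1.0 and -1.0 without any floating-point compares.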
+ uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); + uint64_t beta_real_one = *((uint64_t*)(&beta->real)); + + uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); + uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); + + char alpha_mul_type = BLIS_MUL_DEFAULT; + char beta_mul_type = BLIS_MUL_DEFAULT; + + if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) + { + alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 + if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) + { + alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 + } + } + + if(beta->imag == 0)// beta is real + { + if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) + { + beta_mul_type = BLIS_MUL_ONE; + if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) + { + beta_mul_type = BLIS_MUL_MINUS_ONE; + } + } + else if(beta_real_one == 0) + { + beta_mul_type = BLIS_MUL_ZERO; + } + } + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*rs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*rs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) @@ -2295,51 +2330,51 @@ void bli_zgemm_haswell_asm_3x4 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*16)) - + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -2347,83 +2382,83 @@ void bli_zgemm_haswell_asm_3x4 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*3*16), rax) // a += 4*3 (unroll x mr) add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.ZLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*3*16), rax) // a += 1*3 (unroll x mr) add(imm(1*4*16), rbx) // b += 1*4 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilpd(imm(0x5), ymm6, ymm6) @@ -2432,251 +2467,306 @@ void bli_zgemm_haswell_asm_3x4 vpermilpd(imm(0x5), ymm11, ymm11) vpermilpd(imm(0x5), ymm14, ymm14) vpermilpd(imm(0x5), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) vaddsubpd(ymm7, ymm5, ymm5) - + vaddsubpd(ymm10, ymm8, ymm8) vaddsubpd(ymm11, ymm9, ymm9) - + vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) - - - - + + //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) + mov(var(alpha_mul_type), al) + cmp(imm(0xFF), al) + jne(.ALPHA_NOT_MINUS1) + + // when alpha = -1 and real. + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vsubpd(ymm4, ymm0, ymm4) + vsubpd(ymm5, ymm0, ymm5) + vsubpd(ymm8, ymm0, ymm8) + vsubpd(ymm9, ymm0, ymm9) + vsubpd(ymm12, ymm0, ymm12) + vsubpd(ymm13, ymm0, ymm13) + jmp(.ALPHA_REAL_ONE) + + label(.ALPHA_NOT_MINUS1) + //when alpha is real and +/-1, multiplication is skipped. + cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication. + jne(.ALPHA_REAL_ONE) + + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - - + + vpermilpd(imm(0x5), ymm4, ymm3) vmulpd(ymm0, ymm4, ymm4) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm4, ymm4) - + vpermilpd(imm(0x5), ymm5, ymm3) vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm5, ymm5) - - + + vpermilpd(imm(0x5), ymm8, ymm3) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm9, ymm3) vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm9, ymm9) - - + + vpermilpd(imm(0x5), ymm12, ymm3) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm13, ymm3) vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm13, ymm13) - - - - - - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + + + + label(.ALPHA_REAL_ONE) + // Beta multiplication + /* (br + bi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - - - + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dcomplex) lea(mem(, rsi, 2), rsi) lea(mem(, rsi, 2), rdx) // rdx = 2*cs_c; - - - - // now avoid loading C if beta == 0 - vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
- vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. - sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); - vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. - sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); - and(r8b, r9b) // set ZF if r8b & r9b == 1. - jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + // now avoid loading C if beta == 0 + mov(var(beta_mul_type), al) + cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) + je(.ZBETAZERO) //jump to beta == 0 case + + cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. jz(.ZROWSTORED) // jump to row storage case - - - + + label(.ZGENSTORED) - - ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm5, ymm0, ymm0) ZGEMM_OUTPUT_GS mov(r11, rcx) // rcx = c + 1*rs_c - - - + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm8, ymm0, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm9, ymm0, ymm0) ZGEMM_OUTPUT_GS mov(r12, rcx) // rcx = c + 2*rs_c - - - + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm12, ymm0, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_GS - - - jmp(.ZDONE) // jump to end. - - - + + + + /* Row stored of C */ label(.ZROWSTORED) - - + cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) + je(.GEN_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. + + cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE) + je(.GEN_BETA_REAL_MINUS1) // jump to beta real = -1 section. + + //CASE 1: beta is real = 1 + label(.GEN_BETA_REAL_ONE) + vmovupd(mem(rcx), ymm0) + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + add(rdx, rcx) // c += 2*cs_c; + + vmovupd(mem(rcx), ymm0) + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS + mov(r11, rcx) // rcx = c + 1*rs_c + + vmovupd(mem(rcx), ymm0) + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + add(rdx, rcx) // c += 2*cs_c; + + vmovupd(mem(rcx), ymm0) + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS + mov(r12, rcx) // rcx = c + 2*rs_c + + vmovupd(mem(rcx), ymm0) + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + add(rdx, rcx) // c += 2*cs_c; + + vmovupd(mem(rcx), ymm0) + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS + jmp(.ZDONE) // jump to end. + + //CASE 2: beta is real = -1 + label(.GEN_BETA_REAL_MINUS1) + vmovupd(mem(rcx), ymm0) + vsubpd(ymm0, ymm4, ymm0) + ZGEMM_OUTPUT_RS + add(rdx, rcx) // c += 2*cs_c; + + vmovupd(mem(rcx), ymm0) + vsubpd(ymm0, ymm5, ymm0) + ZGEMM_OUTPUT_RS + mov(r11, rcx) // rcx = c + 1*rs_c + + vmovupd(mem(rcx), ymm0) + vsubpd(ymm0, ymm8, ymm0) + ZGEMM_OUTPUT_RS + add(rdx, rcx) // c += 2*cs_c; + + vmovupd(mem(rcx), ymm0) + vsubpd(ymm0, ymm9, ymm0) + ZGEMM_OUTPUT_RS + mov(r12, rcx) // rcx = c + 2*rs_c + + vmovupd(mem(rcx), ymm0) + vsubpd(ymm0, ymm12, ymm0) + ZGEMM_OUTPUT_RS + add(rdx, rcx) // c += 2*cs_c; + + vmovupd(mem(rcx), ymm0) + vsubpd(ymm0, ymm13, ymm0) + ZGEMM_OUTPUT_RS + jmp(.ZDONE) // jump to end. + + //CASE 3: Default case with multiplication + // beta not equal to (+/-1) or zero, do normal multiplication. 
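+	// Unlike CASE 1 and CASE 2 above, which need only a load plus one
+	// vaddpd/vsubpd per vector of C, this path pays for the full complex
+	// multiply by beta: a permute, two vmulpd's, and a vaddsubpd per vector.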
+ label(.GEN_BETA_NOT_REAL_ONE) ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) ZGEMM_OUTPUT_RS add(rdx, rcx) // c += 2*cs_c; - - + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm5, ymm0, ymm0) ZGEMM_OUTPUT_RS mov(r11, rcx) // rcx = c + 1*rs_c - - - + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm8, ymm0, ymm0) ZGEMM_OUTPUT_RS add(rdx, rcx) // c += 2*cs_c; - - + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm9, ymm0, ymm0) ZGEMM_OUTPUT_RS mov(r12, rcx) // rcx = c + 2*rs_c - - - + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm12, ymm0, ymm0) ZGEMM_OUTPUT_RS add(rdx, rcx) // c += 2*cs_c; - - + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_RS - - - jmp(.ZDONE) // jump to end. - - - + + + label(.ZBETAZERO) - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. jz(.ZROWSTORBZ) // jump to row storage case - - - + + + label(.ZGENSTORBZ) - - + + vmovapd(ymm4, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + + vmovapd(ymm5, ymm0) ZGEMM_OUTPUT_GS mov(r11, rcx) // rcx = c + 1*rs_c - - - + + + vmovapd(ymm8, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + + vmovapd(ymm9, ymm0) ZGEMM_OUTPUT_GS mov(r12, rcx) // rcx = c + 2*rs_c - - - + + + vmovapd(ymm12, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + + vmovapd(ymm13, ymm0) ZGEMM_OUTPUT_GS - - - + + + jmp(.ZDONE) // jump to end. - - - + + + label(.ZROWSTORBZ) - - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm5, mem(rcx, rdx, 1)) - + vmovupd(ymm8, mem(r11)) vmovupd(ymm9, mem(r11, rdx, 1)) - + vmovupd(ymm12, mem(r12)) vmovupd(ymm13, mem(r12, rdx, 1)) - - - - - - + label(.ZDONE) - - + + end_asm( : // output operands (none) : // input operands + [alpha_mul_type] "m" (alpha_mul_type), + [beta_mul_type] "m" (beta_mul_type), [k_iter] "m" (k_iter), // 0 [k_left] "m" (k_left), // 1 [a] "m" (a), // 2 diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c index 05e05dfec..f5b42c623 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,6 +56,9 @@ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) +#define ZGEMM_INPUT_RS_BETA_ONE \ + vmovupd(mem(rcx), ymm0) + #define ZGEMM_OUTPUT_RS \ vmovupd(ymm0, mem(rcx)) \ @@ -66,8 +69,11 @@ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) +#define ZGEMM_INPUT_RS_BETA_ONE_NEXT \ + vmovupd(mem(rcx, rsi, 8), ymm0) + #define ZGEMM_OUTPUT_RS_NEXT \ - vmovupd(ymm0, mem(rcx, rsi, 8)) \ + vmovupd(ymm0, mem(rcx, rsi, 8)) /* rrr: @@ -174,6 +180,41 @@ void bli_zgemmsup_rv_zen_asm_3x4m if ( m_iter == 0 ) goto consider_edge_cases; + //handling case when alpha and beta are real and +/-1. 
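+	// Same bit-pattern classification as in bli_zgemm_haswell_asm_3x4:
+	// clearing the sign bit of the uint64 view of each double detects
+	// +/-1.0 exactly.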
+ uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); + uint64_t beta_real_one = *((uint64_t*)(&beta->real)); + + uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); + uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); + + char alpha_mul_type = BLIS_MUL_DEFAULT; + char beta_mul_type = BLIS_MUL_DEFAULT; + + if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) + { + alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 + if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) + { + alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 + } + } + + if(beta->imag == 0)// beta is real + { + if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) + { + beta_mul_type = BLIS_MUL_ONE; + if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) + { + beta_mul_type = BLIS_MUL_MINUS_ONE; + } + } + else if(beta_real_one == 0) + { + beta_mul_type = BLIS_MUL_ZERO; + } + } + // ------------------------------------------------------------------------- begin_asm() @@ -442,10 +483,30 @@ void bli_zgemmsup_rv_zen_asm_3x4m vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) + //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) + mov(var(alpha_mul_type), al) + cmp(imm(0xFF), al) + jne(.ALPHA_NOT_MINUS1) + + // when alpha = -1 and real. + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vsubpd(ymm4, ymm0, ymm4) + vsubpd(ymm5, ymm0, ymm5) + vsubpd(ymm8, ymm0, ymm8) + vsubpd(ymm9, ymm0, ymm9) + vsubpd(ymm12, ymm0, ymm12) + vsubpd(ymm13, ymm0, ymm13) + jmp(.ALPHA_REAL_ONE) + + label(.ALPHA_NOT_MINUS1) + //when alpha is real and +/-1, multiplication is skipped. + cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication. + jne(.ALPHA_REAL_ONE) + /* (ar + ai) x AB */ - mov(var(alpha), rax) // load address of alpha - vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate - vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate vpermilpd(imm(0x5), ymm4, ymm3) vmulpd(ymm0, ymm4, ymm4) @@ -477,32 +538,93 @@ void bli_zgemmsup_rv_zen_asm_3x4m vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm13, ymm13) - /* (ßr + ßi)x C + ((ar + ai) x AB) */ - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + label(.ALPHA_REAL_ONE) + // Beta multiplication + /* (br + bi)x C + ((ar + ai) x AB) */ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - // now avoid loading C if beta == 0 - vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. - vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. - sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); - vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. - sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); - and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + mov(var(beta_mul_type), al) + cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) + je(.SBETAZERO) //jump to beta == 0 case lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. + cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. 
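+	// rdi holds rs_c in bytes (16 per dcomplex), so ZF set here means
+	// rs_c == 1, i.e. C is column-stored and the .SCOLSTORED path is taken.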
jz(.SCOLSTORED) // jump to column storage case label(.SROWSTORED) + cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) + je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. + + cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE) + je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section. + + //CASE 1: beta is real = 1 + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.SDONE) + + + //CASE 2: beta is real = -1 + label(.ROW_BETA_REAL_MINUS1) + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm4, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm5, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm8, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm9, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm12, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm13, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.SDONE) + + + //CASE 3: Default case with multiplication + // beta not equal to (+/-1) or zero, do normal multiplication. + label(.ROW_BETA_NOT_REAL_ONE) + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) @@ -529,10 +651,12 @@ void bli_zgemmsup_rv_zen_asm_3x4m ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_RS_NEXT - jmp(.SDONE) // jump to end. label(.SCOLSTORED) + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate /*|--------| |-------| | | | | | 3x4 | | 4x3 | @@ -679,6 +803,8 @@ void bli_zgemmsup_rv_zen_asm_3x4m end_asm( : // output operands (none) : // input operands + [alpha_mul_type] "m" (alpha_mul_type), + [beta_mul_type] "m" (beta_mul_type), [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), @@ -1025,7 +1151,7 @@ void bli_zgemmsup_rv_zen_asm_3x2m vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm12, ymm12) - /* (ßr + ßi)x C + ((ar + ai) x AB) */ + /* (br + bi)x C + ((ar + ai) x AB) */ mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate @@ -1226,4 +1352,4 @@ void bli_zgemmsup_rv_zen_asm_3x2m ); return; } -} +} \ No newline at end of file diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c index 872d04868..072f5262c 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. 
   Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
@@ -123,6 +123,41 @@ void bli_zgemmsup_rv_zen_asm_3x4n
 
 	if ( n_iter == 0 ) goto consider_edge_cases;
 
+	//handling case when alpha and beta are real and +/-1.
+	uint64_t alpha_real_one = *((uint64_t*)(&alpha->real));
+	uint64_t beta_real_one = *((uint64_t*)(&beta->real));
+
+	uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1);
+	uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1);
+
+	char alpha_mul_type = BLIS_MUL_DEFAULT;
+	char beta_mul_type = BLIS_MUL_DEFAULT;
+
+	if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1)
+	{
+		alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1
+		if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
+		{
+			alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1
+		}
+	}
+
+	if(beta->imag == 0)// beta is real
+	{
+		if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1)
+		{
+			beta_mul_type = BLIS_MUL_ONE;
+			if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE)
+			{
+				beta_mul_type = BLIS_MUL_MINUS_ONE;
+			}
+		}
+		else if(beta_real_one == 0)
+		{
+			beta_mul_type = BLIS_MUL_ZERO;
+		}
+	}
+
 	// -------------------------------------------------------------------------
 
 	//scratch registers
 	__m256d ymm0, ymm1, ymm2, ymm3;
@@ -217,45 +252,60 @@ void bli_zgemmsup_rv_zen_asm_3x4n
 	ymm12 = _mm256_addsub_pd( ymm12, ymm14);
 	ymm13 = _mm256_addsub_pd( ymm13, ymm15);
 
-	// alpha, beta multiplication.
+	// When alpha->real == -1.0, negate the accumulators instead of multiplying by -1.
+	if(alpha_mul_type == BLIS_MUL_MINUS_ONE)// equivalent to if(alpha->real == -1.0)
+	{
+		ymm0 = _mm256_setzero_pd();
+		ymm4 = _mm256_sub_pd(ymm0, ymm4);
+		ymm5 = _mm256_sub_pd(ymm0, ymm5);
+		ymm8 = _mm256_sub_pd(ymm0, ymm8);
+		ymm9 = _mm256_sub_pd(ymm0, ymm9);
+		ymm12 = _mm256_sub_pd(ymm0, ymm12);
+		ymm13 = _mm256_sub_pd(ymm0, ymm13);
+	}
 
-	/* (ar + ai) x AB */
-	ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate
-	ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate
+	//when alpha is real and +/-1, multiplication is skipped.
+	if(alpha_mul_type == BLIS_MUL_DEFAULT)
+	{
+		// alpha, beta multiplication.
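+		// Each stanza below is one complex multiply: the permute swaps the
+		// real/imag lanes, the two multiplies form the ar*x and ai*x terms,
+		// and addsub combines them into (ar*xr - ai*xi, ar*xi + ai*xr).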
+ /* (ar + ai) x AB */ + ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); - ymm3 = _mm256_permute_pd(ymm8, 5); - ymm8 = _mm256_mul_pd(ymm0, ymm8); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_addsub_pd(ymm8, ymm3); + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); - ymm3 = _mm256_permute_pd(ymm9, 5); - ymm9 = _mm256_mul_pd(ymm0, ymm9); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_addsub_pd(ymm9, ymm3); + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); - ymm3 = _mm256_permute_pd(ymm12, 5); - ymm12 = _mm256_mul_pd(ymm0, ymm12); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_addsub_pd(ymm12, ymm3); + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); - ymm3 = _mm256_permute_pd(ymm13, 5); - ymm13 = _mm256_mul_pd(ymm0, ymm13); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm13 = _mm256_addsub_pd(ymm13, ymm3); + ymm3 = _mm256_permute_pd(ymm13, 5); + ymm13 = _mm256_mul_pd(ymm0, ymm13); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_addsub_pd(ymm13, ymm3); + } if(tc_inc_row == 1) //col stored { - if(beta->real == 0.0 && beta->imag == 0.0) + if(beta_mul_type == BLIS_MUL_ZERO) { //transpose left 3x2 _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); @@ -364,7 +414,7 @@ void bli_zgemmsup_rv_zen_asm_3x4n } else { - if(beta->real == 0.0 && beta->imag == 0.0) + if(beta_mul_type == BLIS_MUL_ZERO) { _mm256_storeu_pd((double *)(tC), ymm4); _mm256_storeu_pd((double *)(tC + 2), ymm5); @@ -373,6 +423,28 @@ void bli_zgemmsup_rv_zen_asm_3x4n _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); } + else if(beta_mul_type == BLIS_MUL_ONE)// equivalent to if(beta->real == 1.0) + { + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm4 = _mm256_add_pd(ymm4,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm5 = _mm256_add_pd(ymm5,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm8 = _mm256_add_pd(ymm8,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm9 = _mm256_add_pd(ymm9,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm12 = _mm256_add_pd(ymm12,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm13 = _mm256_add_pd(ymm13,ymm2); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + 
tc_inc_row *2+ 2), ymm13); + } else{ /* (br + bi) C + (ar + ai) AB */ ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate