From f5c03e9fe808f9bd8a3e0c62786334e13c46b0fc Mon Sep 17 00:00:00 2001 From: RuQing Xu Date: Sun, 3 Oct 2021 16:51:51 +0900 Subject: [PATCH] Armv8 Handle *beta == 0 for GEMMSUP ?rc Case. --- .../3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c | 40 ++++++----- .../3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c | 38 +++++++---- .../sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c | 27 +++++--- .../sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c | 31 +++++---- .../sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c | 67 ++++++++++++------- .../sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c | 40 +++++++---- 6 files changed, 154 insertions(+), 89 deletions(-) diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c index 7046c33a4..e0ab95d82 100644 --- a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c +++ b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8m.c @@ -372,6 +372,12 @@ LABEL(WRITE_MEM_PREP) " ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). " ld1r {v31.2d}, [x8] \n\t" " \n\t" +" fmov d28, #1.0 \n\t" // Don't scale for unit alpha. +" fcmp d30, d28 \n\t" +BEQ(UNIT_ALPHA) +DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,30,0) +LABEL(UNIT_ALPHA) +" \n\t" " mov x1, x5 \n\t" // C address for loading. " \n\t" // C address for storing is x5 itself. " cmp x7, #8 \n\t" // Check for column-storage. @@ -379,11 +385,13 @@ BNE(WRITE_MEM_C) // // C storage in rows. LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_R) DLOADC_4V_R_FWD(12,13,14,15,x1,0,x6) DLOADC_4V_R_FWD(16,17,18,19,x1,0,x6) DLOADC_4V_R_FWD(20,21,22,23,x1,0,x6) -DSCALE12V(12,13,14,15,16,17,18,19,20,21,22,23,31,0) -DSCALEA12V(12,13,14,15,16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,8,9,10,11,30,0) +DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,31,0) +LABEL(ZERO_BETA_R) #ifndef __clang__ " cmp x12, #1 \n\t" BRANCH(PRFM_END_R) @@ -393,9 +401,9 @@ BRANCH(PRFM_END_R) " prfm PLDL1STRM, [%[b_next], #16*1] \n\t" LABEL(PRFM_END_R) #endif -DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) -DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) -DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) +DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) +DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) BRANCH(END_WRITE_MEM) // // C storage in columns. @@ -408,6 +416,8 @@ LABEL(WRITE_MEM_C) " trn2 v17.2d, v2.2d, v6.2d \n\t" " trn1 v18.2d, v3.2d, v7.2d \n\t" " trn2 v19.2d, v3.2d, v7.2d \n\t" +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_C) DLOADC_1V_1ELM_C_FWD(0,20,0,x1,0,x7) DLOADC_1V_1ELM_C_FWD(1,20,1,x1,0,x7) DLOADC_1V_1ELM_C_FWD(2,21,0,x1,0,x7) @@ -416,8 +426,8 @@ DLOADC_1V_1ELM_C_FWD(4,22,0,x1,0,x7) DLOADC_1V_1ELM_C_FWD(5,22,1,x1,0,x7) DLOADC_1V_1ELM_C_FWD(6,23,0,x1,0,x7) DLOADC_1V_1ELM_C_FWD(7,23,1,x1,0,x7) -DSCALE12V(0,1,2,3,4,5,6,7,20,21,22,23,31,0) -DSCALEA12V(0,1,2,3,4,5,6,7,20,21,22,23,12,13,14,15,16,17,18,19,8,9,10,11,30,0) +DSCALEA12V(12,13,14,15,16,17,18,19,8,9,10,11,0,1,2,3,4,5,6,7,20,21,22,23,31,0) +LABEL(ZERO_BETA_C) #ifndef __clang__ " cmp x12, #1 \n\t" BRANCH(PRFM_END_C) @@ -427,14 +437,14 @@ BRANCH(PRFM_END_C) " prfm PLDL1STRM, [%[b_next], #16*1] \n\t" LABEL(PRFM_END_C) #endif -DSTOREC_1V_1ELM_C_FWD(0,20,0,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(1,20,1,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(2,21,0,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(3,21,1,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(4,22,0,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(5,22,1,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(6,23,0,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(7,23,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(12,8,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(13,8,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(14,9,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(15,9,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(16,10,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(17,10,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(18,11,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(19,11,1,x5,0,x7) // // End of this microkernel. LABEL(END_WRITE_MEM) diff --git a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c index 2703f75b3..53bedd773 100644 --- a/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c +++ b/kernels/armv8a/3/sup/bli_gemmsup_rd_armv8a_asm_d6x8n.c @@ -426,6 +426,12 @@ LABEL(WRITE_MEM_PREP) " ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). " ld1r {v31.2d}, [x8] \n\t" " \n\t" +" fmov d28, #1.0 \n\t" // Don't scale for unit alpha. +" fcmp d30, d28 \n\t" +BEQ(UNIT_ALPHA) +DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,30,0) +LABEL(UNIT_ALPHA) +" \n\t" " mov x1, x5 \n\t" // C address for loading. " \n\t" // C address for storing is x5 itself. " cmp x7, #8 \n\t" // Check for column-storage. @@ -433,14 +439,16 @@ BNE(WRITE_MEM_C) // // C storage in rows. LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_R) DLOADC_2V_R_FWD(12,13,x1,0,x6) DLOADC_2V_R_FWD(14,15,x1,0,x6) DLOADC_2V_R_FWD(16,17,x1,0,x6) DLOADC_2V_R_FWD(18,19,x1,0,x6) DLOADC_2V_R_FWD(20,21,x1,0,x6) DLOADC_2V_R_FWD(22,23,x1,0,x6) -DSCALE12V(12,13,14,15,16,17,18,19,20,21,22,23,31,0) -DSCALEA12V(12,13,14,15,16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,8,9,10,11,30,0) +DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,31,0) +LABEL(ZERO_BETA_R) #ifndef __clang__ " cmp x12, #1 \n\t" BRANCH(PRFM_END_R) @@ -450,12 +458,12 @@ BRANCH(PRFM_END_R) " prfm PLDL1STRM, [%[b_next], #16*1] \n\t" LABEL(PRFM_END_R) #endif -DSTOREC_2V_R_FWD(12,13,x5,0,x6) -DSTOREC_2V_R_FWD(14,15,x5,0,x6) -DSTOREC_2V_R_FWD(16,17,x5,0,x6) -DSTOREC_2V_R_FWD(18,19,x5,0,x6) -DSTOREC_2V_R_FWD(20,21,x5,0,x6) -DSTOREC_2V_R_FWD(22,23,x5,0,x6) +DSTOREC_2V_R_FWD(0,1,x5,0,x6) +DSTOREC_2V_R_FWD(2,3,x5,0,x6) +DSTOREC_2V_R_FWD(4,5,x5,0,x6) +DSTOREC_2V_R_FWD(6,7,x5,0,x6) +DSTOREC_2V_R_FWD(8,9,x5,0,x6) +DSTOREC_2V_R_FWD(10,11,x5,0,x6) BRANCH(END_WRITE_MEM) // // C storage in columns. @@ -472,12 +480,14 @@ LABEL(WRITE_MEM_C) " trn2 v21.2d, v1.2d, v3.2d \n\t" " trn2 v22.2d, v5.2d, v7.2d \n\t" " trn2 v23.2d, v9.2d, v11.2d \n\t" +" fcmp d31, #0.0 \n\t" // Don't load for zero beta. +BEQ(ZERO_BETA_C) DLOADC_3V_C_FWD(0,1,2,x1,0,x7) DLOADC_3V_C_FWD(3,4,5,x1,0,x7) DLOADC_3V_C_FWD(6,7,8,x1,0,x7) DLOADC_3V_C_FWD(9,10,11,x1,0,x7) -DSCALE12V(0,1,2,3,4,5,6,7,8,9,10,11,31,0) -DSCALEA12V(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,30,0) +DSCALEA12V(12,13,14,15,16,17,18,19,20,21,22,23,0,1,2,3,4,5,6,7,8,9,10,11,31,0) +LABEL(ZERO_BETA_C) #ifndef __clang__ " cmp x12, #1 \n\t" BRANCH(PRFM_END_C) @@ -487,10 +497,10 @@ BRANCH(PRFM_END_C) " prfm PLDL1STRM, [%[b_next], #16*1] \n\t" LABEL(PRFM_END_C) #endif -DSTOREC_3V_C_FWD(0,1,2,x5,0,x7) -DSTOREC_3V_C_FWD(3,4,5,x5,0,x7) -DSTOREC_3V_C_FWD(6,7,8,x5,0,x7) -DSTOREC_3V_C_FWD(9,10,11,x5,0,x7) +DSTOREC_3V_C_FWD(12,13,14,x5,0,x7) +DSTOREC_3V_C_FWD(15,16,17,x5,0,x7) +DSTOREC_3V_C_FWD(18,19,20,x5,0,x7) +DSTOREC_3V_C_FWD(21,22,23,x5,0,x7) // // End of this microkernel. LABEL(END_WRITE_MEM) diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c index 44a9915e0..84c7c4a7d 100644 --- a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d3x4.c @@ -240,6 +240,7 @@ LABEL(WRITE_MEM_PREP) " ldr x8, %[beta] \n\t" " ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). " ld1r {v31.2d}, [x8] \n\t" +DSCALE6V(0,1,2,3,4,5,30,0) " \n\t" " mov x9, x5 \n\t" // C address for loading. " \n\t" // C address for storing is x5 itself. @@ -248,14 +249,16 @@ BNE(WRITE_MEM_C) // // C storage in rows. LABEL(WRITE_MEM_R) +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_R) DLOADC_2V_R_FWD(12,13,x9,0,x6) DLOADC_2V_R_FWD(14,15,x9,0,x6) DLOADC_2V_R_FWD(16,17,x9,0,x6) -DSCALE6V(12,13,14,15,16,17,31,0) -DSCALEA6V(12,13,14,15,16,17,0,1,2,3,4,5,30,0) -DSTOREC_2V_R_FWD(12,13,x5,0,x6) -DSTOREC_2V_R_FWD(14,15,x5,0,x6) -DSTOREC_2V_R_FWD(16,17,x5,0,x6) +DSCALEA6V(0,1,2,3,4,5,12,13,14,15,16,17,31,0) +LABEL(ZERO_BETA_R) +DSTOREC_2V_R_FWD(0,1,x5,0,x6) +DSTOREC_2V_R_FWD(2,3,x5,0,x6) +DSTOREC_2V_R_FWD(4,5,x5,0,x6) BRANCH(END_WRITE_MEM) // // C storage in columns. @@ -264,16 +267,18 @@ LABEL(WRITE_MEM_C) " trn2 v7.2d, v0.2d, v2.2d \n\t" " trn1 v8.2d, v1.2d, v3.2d \n\t" " trn2 v9.2d, v1.2d, v3.2d \n\t" +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_C) DLOADC_1V_1ELM_C_FWD(12,20,0,x9,0,x7) DLOADC_1V_1ELM_C_FWD(13,20,1,x9,0,x7) DLOADC_1V_1ELM_C_FWD(14,21,0,x9,0,x7) DLOADC_1V_1ELM_C_FWD(15,21,1,x9,0,x7) -DSCALE6V(12,13,14,15,20,21,31,0) -DSCALEA6V(12,13,14,15,20,21,6,7,8,9,4,5,30,0) -DSTOREC_1V_1ELM_C_FWD(12,20,0,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(13,20,1,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(14,21,0,x5,0,x7) -DSTOREC_1V_1ELM_C_FWD(15,21,1,x5,0,x7) +DSCALEA6V(6,7,8,9,4,5,12,13,14,15,20,21,31,0) +LABEL(ZERO_BETA_C) +DSTOREC_1V_1ELM_C_FWD(6,4,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(7,4,1,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(8,5,0,x5,0,x7) +DSTOREC_1V_1ELM_C_FWD(9,5,1,x5,0,x7) // // End of this microkernel. LABEL(END_WRITE_MEM) diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c index 410d51283..abbb6fb4d 100644 --- a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_asm_d6x3.c @@ -283,6 +283,7 @@ LABEL(WRITE_MEM_PREP) " ldr x8, %[beta] \n\t" " ld1r {v30.2d}, [x4] \n\t" // Load alpha & beta (value). " ld1r {v31.2d}, [x8] \n\t" +DSCALE9V(0,1,2,3,4,5,6,7,8,30,0) " \n\t" " mov x9, x5 \n\t" // C address for loading. " \n\t" // C address for storing is x5 itself. @@ -297,32 +298,36 @@ LABEL(WRITE_MEM_R) " trn2 v23.2d, v3.2d, v4.2d \n\t" " trn1 v24.2d, v6.2d, v7.2d \n\t" " trn2 v25.2d, v6.2d, v7.2d \n\t" +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_R) DLOADC_1V_1ELM_R_FWD(10,26,0,x9,0,x6) DLOADC_1V_1ELM_R_FWD(11,26,1,x9,0,x6) DLOADC_1V_1ELM_R_FWD(12,27,0,x9,0,x6) DLOADC_1V_1ELM_R_FWD(13,27,1,x9,0,x6) DLOADC_1V_1ELM_R_FWD(14,28,0,x9,0,x6) DLOADC_1V_1ELM_R_FWD(15,28,1,x9,0,x6) -DSCALE9V(10,11,12,13,14,15,26,27,28,31,0) -DSCALEA9V(10,11,12,13,14,15,26,27,28,20,21,22,23,24,25,2,5,8,30,0) -DSTOREC_1V_1ELM_R_FWD(10,26,0,x5,0,x6) -DSTOREC_1V_1ELM_R_FWD(11,26,1,x5,0,x6) -DSTOREC_1V_1ELM_R_FWD(12,27,0,x5,0,x6) -DSTOREC_1V_1ELM_R_FWD(13,27,1,x5,0,x6) -DSTOREC_1V_1ELM_R_FWD(14,28,0,x5,0,x6) -DSTOREC_1V_1ELM_R_FWD(15,28,1,x5,0,x6) +DSCALEA9V(20,21,22,23,24,25,2,5,8,10,11,12,13,14,15,26,27,28,31,0) +LABEL(ZERO_BETA_R) +DSTOREC_1V_1ELM_R_FWD(20,2,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(21,2,1,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(22,5,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(23,5,1,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(24,8,0,x5,0,x6) +DSTOREC_1V_1ELM_R_FWD(25,8,1,x5,0,x6) BRANCH(END_WRITE_MEM) // // C storage in columns. LABEL(WRITE_MEM_C) +" fcmp d31, #0.0 \n\t" +BEQ(ZERO_BETA_C) DLOADC_3V_C_FWD(12,15,18,x9,0,x7) DLOADC_3V_C_FWD(13,16,19,x9,0,x7) DLOADC_3V_C_FWD(14,17,20,x9,0,x7) -DSCALE9V(12,13,14,15,16,17,18,19,20,31,0) -DSCALEA9V(12,13,14,15,16,17,18,19,20,0,1,2,3,4,5,6,7,8,30,0) -DSTOREC_3V_C_FWD(12,15,18,x5,0,x7) -DSTOREC_3V_C_FWD(13,16,19,x5,0,x7) -DSTOREC_3V_C_FWD(14,17,20,x5,0,x7) +DSCALEA9V(0,1,2,3,4,5,6,7,8,12,13,14,15,16,17,18,19,20,31,0) +LABEL(ZERO_BETA_C) +DSTOREC_3V_C_FWD(0,3,6,x5,0,x7) +DSTOREC_3V_C_FWD(1,4,7,x5,0,x7) +DSTOREC_3V_C_FWD(2,5,8,x5,0,x7) // // End of this microkernel. LABEL(END_WRITE_MEM) diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c index e96069f87..43880063e 100644 --- a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d2x8.c @@ -82,6 +82,7 @@ void bli_dgemmsup_rd_armv8a_int_2x8 uint64_t k_mker = k0 / 2; uint64_t k_left = k0 % 2; + uint64_t b_iszr = ( *beta == 0.0 ); assert( cs_a == 1 ); assert( rs_b == 1 ); @@ -252,10 +253,13 @@ void bli_dgemmsup_rd_armv8a_int_2x8 else if ( n0 > 4 ) vb_2 = vld1q_lane_f64( c_loc + 0 * rs_c + 4, vb_2, 0 ); if ( n0 > 7 ) vb_3 = vld1q_f64 ( c_loc + 0 * rs_c + 6 ); else if ( n0 > 6 ) vb_3 = vld1q_lane_f64( c_loc + 0 * rs_c + 6, vb_3, 0 ); - vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); - vc_02 = vfmaq_f64( vc_02, va_0, vb_1 ); - vc_04 = vfmaq_f64( vc_04, va_0, vb_2 ); - vc_06 = vfmaq_f64( vc_06, va_0, vb_3 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_1 ); + vc_04 = vfmaq_f64( vc_04, va_0, vb_2 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_3 ); + } if ( n0 > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); if ( n0 > 3 ) vst1q_f64 ( c_loc + 0 * rs_c + 2, vc_02 ); @@ -275,10 +279,13 @@ void bli_dgemmsup_rd_armv8a_int_2x8 else if ( n0 > 4 ) vb_2 = vld1q_lane_f64( c_loc + 1 * rs_c + 4, vb_2, 0 ); if ( n0 > 7 ) vb_3 = vld1q_f64 ( c_loc + 1 * rs_c + 6 ); else if ( n0 > 6 ) vb_3 = vld1q_lane_f64( c_loc + 1 * rs_c + 6, vb_3, 0 ); - vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); - vc_12 = vfmaq_f64( vc_12, va_0, vb_1 ); - vc_14 = vfmaq_f64( vc_14, va_0, vb_2 ); - vc_16 = vfmaq_f64( vc_16, va_0, vb_3 ); + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_0, vb_1 ); + vc_14 = vfmaq_f64( vc_14, va_0, vb_2 ); + vc_16 = vfmaq_f64( vc_16, va_0, vb_3 ); + } if ( n0 > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); if ( n0 > 3 ) vst1q_f64 ( c_loc + 1 * rs_c + 2, vc_12 ); @@ -308,10 +315,13 @@ void bli_dgemmsup_rd_armv8a_int_2x8 if ( n0 > 1 ) vb_1 = vld1q_f64( c_loc + 0 + 1 * cs_c ); if ( n0 > 2 ) vb_2 = vld1q_f64( c_loc + 0 + 2 * cs_c ); if ( n0 > 3 ) vb_3 = vld1q_f64( c_loc + 0 + 3 * cs_c ); - vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); - vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); - vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); - vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + } vst1q_f64( c_loc + 0 + 0 * cs_c, vc_00 ); if ( n0 > 1 ) vst1q_f64( c_loc + 0 + 1 * cs_c, vc_01 ); if ( n0 > 2 ) vst1q_f64( c_loc + 0 + 2 * cs_c, vc_02 ); @@ -321,10 +331,13 @@ void bli_dgemmsup_rd_armv8a_int_2x8 if ( n0 > 5 ) vb_1 = vld1q_f64( c_loc + 0 + 5 * cs_c ); if ( n0 > 6 ) vb_2 = vld1q_f64( c_loc + 0 + 6 * cs_c ); if ( n0 > 7 ) vb_3 = vld1q_f64( c_loc + 0 + 7 * cs_c ); - vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); - vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); - vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); - vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + if ( !b_iszr ) + { + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } if ( n0 > 4 ) vst1q_f64( c_loc + 0 + 4 * cs_c, vc_04 ); if ( n0 > 5 ) vst1q_f64( c_loc + 0 + 5 * cs_c, vc_05 ); if ( n0 > 6 ) vst1q_f64( c_loc + 0 + 6 * cs_c, vc_06 ); @@ -337,10 +350,13 @@ void bli_dgemmsup_rd_armv8a_int_2x8 if ( n0 > 1 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 1 * cs_c, vb_1, 0 ); if ( n0 > 2 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 2 * cs_c, vb_2, 0 ); if ( n0 > 3 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 3 * cs_c, vb_3, 0 ); - vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); - vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); - vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); - vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_01 = vfmaq_f64( vc_01, va_0, vb_1 ); + vc_02 = vfmaq_f64( vc_02, va_0, vb_2 ); + vc_03 = vfmaq_f64( vc_03, va_0, vb_3 ); + } vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); if ( n0 > 1 ) vst1q_lane_f64( c_loc + 0 + 1 * cs_c, vc_01, 0 ); if ( n0 > 2 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_02, 0 ); @@ -350,10 +366,13 @@ void bli_dgemmsup_rd_armv8a_int_2x8 if ( n0 > 5 ) vb_1 = vld1q_lane_f64( c_loc + 0 + 5 * cs_c, vb_1, 0 ); if ( n0 > 6 ) vb_2 = vld1q_lane_f64( c_loc + 0 + 6 * cs_c, vb_2, 0 ); if ( n0 > 7 ) vb_3 = vld1q_lane_f64( c_loc + 0 + 7 * cs_c, vb_3, 0 ); - vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); - vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); - vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); - vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + if ( !b_iszr ) + { + vc_04 = vfmaq_f64( vc_04, va_0, vb_0 ); + vc_05 = vfmaq_f64( vc_05, va_0, vb_1 ); + vc_06 = vfmaq_f64( vc_06, va_0, vb_2 ); + vc_07 = vfmaq_f64( vc_07, va_0, vb_3 ); + } if ( n0 > 4 ) vst1q_lane_f64( c_loc + 0 + 4 * cs_c, vc_04, 0 ); if ( n0 > 5 ) vst1q_lane_f64( c_loc + 0 + 5 * cs_c, vc_05, 0 ); if ( n0 > 6 ) vst1q_lane_f64( c_loc + 0 + 6 * cs_c, vc_06, 0 ); diff --git a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c index 7ab06d1ca..73e5f20fb 100644 --- a/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c +++ b/kernels/armv8a/3/sup/d3x4/bli_gemmsup_rd_armv8a_int_d3x4.c @@ -94,6 +94,7 @@ void bli_dgemmsup_rd_armv8a_int_3x4 uint64_t k_mker = k0 / 2; uint64_t k_left = k0 % 2; + uint64_t b_iszr = ( *beta == 0.0 ); assert( cs_a == 1 ); assert( rs_b == 1 ); @@ -228,8 +229,11 @@ void bli_dgemmsup_rd_armv8a_int_3x4 if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 0 * rs_c + 2 ); else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 0 * rs_c + 2, va_1, 0 ); - vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); - vc_02 = vfmaq_f64( vc_02, va_1, vb_0 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_02 = vfmaq_f64( vc_02, va_1, vb_0 ); + } if ( n0 > 1 ) vst1q_f64 ( c_loc + 0 * rs_c + 0, vc_00 ); else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 0 * rs_c + 0, vc_00, 0 ); @@ -243,8 +247,11 @@ void bli_dgemmsup_rd_armv8a_int_3x4 if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 1 * rs_c + 2 ); else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 1 * rs_c + 2, va_1, 0 ); - vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); - vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + if ( !b_iszr ) + { + vc_10 = vfmaq_f64( vc_10, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + } if ( n0 > 1 ) vst1q_f64 ( c_loc + 1 * rs_c + 0, vc_10 ); else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 1 * rs_c + 0, vc_10, 0 ); @@ -258,8 +265,11 @@ void bli_dgemmsup_rd_armv8a_int_3x4 if ( n0 > 3 ) va_1 = vld1q_f64 ( c_loc + 2 * rs_c + 2 ); else if ( n0 > 2 ) va_1 = vld1q_lane_f64( c_loc + 2 * rs_c + 2, va_1, 0 ); - vc_20 = vfmaq_f64( vc_20, va_0, vb_0 ); - vc_22 = vfmaq_f64( vc_22, va_1, vb_0 ); + if ( !b_iszr ) + { + vc_20 = vfmaq_f64( vc_20, va_0, vb_0 ); + vc_22 = vfmaq_f64( vc_22, va_1, vb_0 ); + } if ( n0 > 1 ) vst1q_f64 ( c_loc + 2 * rs_c + 0, vc_20 ); else if ( n0 > 0 ) vst1q_lane_f64( c_loc + 2 * rs_c + 0, vc_20, 0 ); @@ -279,9 +289,12 @@ void bli_dgemmsup_rd_armv8a_int_3x4 if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 1 * cs_c, va_1, 1 ); if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 1 * cs_c, va_2, 1 ); } - vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); - vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); - vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + if ( !b_iszr ) + { + vc_00 = vfmaq_f64( vc_00, va_0, vb_0 ); + vc_10 = vfmaq_f64( vc_10, va_1, vb_0 ); + vc_20 = vfmaq_f64( vc_20, va_2, vb_0 ); + } if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 0 * cs_c, vc_00, 0 ); if ( m0 > 1 ) vst1q_lane_f64( c_loc + 1 + 0 * cs_c, vc_10, 0 ); if ( m0 > 2 ) vst1q_lane_f64( c_loc + 2 + 0 * cs_c, vc_20, 0 ); @@ -304,9 +317,12 @@ void bli_dgemmsup_rd_armv8a_int_3x4 if ( m0 > 1 ) va_1 = vld1q_lane_f64( c_loc + 1 + 3 * cs_c, va_1, 1 ); if ( m0 > 2 ) va_2 = vld1q_lane_f64( c_loc + 2 + 3 * cs_c, va_2, 1 ); } - vc_02 = vfmaq_f64( vc_02, va_0, vb_0 ); - vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); - vc_22 = vfmaq_f64( vc_22, va_2, vb_0 ); + if ( !b_iszr ) + { + vc_02 = vfmaq_f64( vc_02, va_0, vb_0 ); + vc_12 = vfmaq_f64( vc_12, va_1, vb_0 ); + vc_22 = vfmaq_f64( vc_22, va_2, vb_0 ); + } if ( n0 > 2 ) { if ( m0 > 0 ) vst1q_lane_f64( c_loc + 0 + 2 * cs_c, vc_02, 0 );