diff --git a/config/firestorm/bli_cntx_init_firestorm.c b/config/firestorm/bli_cntx_init_firestorm.c index 05e946ffd..3ea35c690 100644 --- a/config/firestorm/bli_cntx_init_firestorm.c +++ b/config/firestorm/bli_cntx_init_firestorm.c @@ -50,7 +50,7 @@ void bli_cntx_init_firestorm( cntx_t* cntx ) ( 2, BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armv8a_asm_8x12, FALSE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8r, TRUE, cntx ); diff --git a/kernels/armv8a/3/armv8a_asm_utils.h b/kernels/armv8a/3/armv8a_asm_utils.h index 86dcaa7a6..5cb0bad69 100644 --- a/kernels/armv8a/3/armv8a_asm_utils.h +++ b/kernels/armv8a/3/armv8a_asm_utils.h @@ -96,6 +96,11 @@ DLOAD2V(V0,V1,ADDR,SHIFT) \ DLOAD2V(V2,V3,ADDR,SHIFT+32) +// Generic stride: gather one vector (two 64-bit lanes); ADDR is post-incremented by INC after each element. +#define DLOAD1V_GATHER_ELMFWD(V,ADDR,INC) \ +" ld1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \ +" ld1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t" + // Store one line. #define DSTORE1V(V,ADDR,SHIFT) \ " str q"#V", ["#ADDR", #"#SHIFT"] \n\t" @@ -106,4 +111,9 @@ DSTORE2V(V0,V1,ADDR,SHIFT) \ DSTORE2V(V2,V3,ADDR,SHIFT+32) +// Generic stride: scatter one vector (two 64-bit lanes); ADDR is post-incremented by INC after each element. +#define DSTORE1V_SCATTER_ELMFWD(V,ADDR,INC) \ +" st1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \ +" st1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t" + diff --git a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8r.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8r.c index 2fe18e004..2fe83438f 100644 --- a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8r.c +++ b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8r.c @@ -35,7 +35,6 @@ */ #include "blis.h" -#include "assert.h" // Label locality & misc. #include "armv8a_asm_utils.h" @@ -94,6 +93,24 @@ " prfm PLDL1KEEP, ["#CADDR"] \n\t" \ " add "#CADDR", "#CADDR", "#RSC" \n\t" +// For generic-strided storage of C: move one row (4 vectors) through scratch CELEM, then advance CADDR by RSC. 
+#define DLOADC_GATHER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) /* Gather one row into C0-C3 starting at CADDR, stepping by CSC through scratch CELEM; then CADDR += RSC. */ \ +" mov "#CELEM", "#CADDR" \n\t" \ + DLOAD1V_GATHER_ELMFWD(C0,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C1,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C2,CELEM,CSC) \ + DLOAD1V_GATHER_ELMFWD(C3,CELEM,CSC) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + +#define DSTOREC_SCATTER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) /* Scatter C0-C3 to the row starting at CADDR, stepping by CSC through scratch CELEM; then CADDR += RSC. */ \ +" mov "#CELEM", "#CADDR" \n\t" \ + DSTORE1V_SCATTER_ELMFWD(C0,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C1,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C2,CELEM,CSC) \ + DSTORE1V_SCATTER_ELMFWD(C3,CELEM,CSC) \ +" add "#CADDR", "#CADDR", "#RSC" \n\t" + + void bli_dgemm_armv8a_asm_6x8r ( dim_t k0, @@ -109,11 +126,6 @@ void bli_dgemm_armv8a_asm_6x8r void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); - // This kernel is a WIP. - // I have no generic stride support at this moment. - assert( cs_c0 == 1 ); - // if ( cs_c0 != 1 ) return ; - // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. uint64_t k_mker = k0 / 4; @@ -245,6 +257,14 @@ LABEL(PREFETCH_ABNEXT) " prfm PLDL1STRM, [x1, 64*1] \n\t" " prfm PLDL1STRM, [x1, 64*3] \n\t" " \n\t" +" fmov d26, #1.0 \n\t" +" fcmp d24, d26 \n\t" // Skip the scaling below when alpha (d24) equals 1.0. +BEQ(UNIT_ALPHA) +DSCALE8V(0,1,2,3,4,5,6,7,24,0) +DSCALE8V(8,9,10,11,12,13,14,15,24,0) +DSCALE8V(16,17,18,19,20,21,22,23,24,0) +LABEL(UNIT_ALPHA) +" \n\t" " mov x9, x5 \n\t" // C address for loading. " \n\t" // C address for storing is x5 itself. " cmp x7, #8 \n\t" // Check for generic storage. @@ -252,31 +272,65 @@ BNE(WRITE_MEM_G) // // Contiguous C-storage. LABEL(WRITE_MEM_R) +" fcmp d25, #0.0 \n\t" // Set the condition flags from *beta == 0. +" \n\t" // Each BEQ(ZERO_BETA_R_*) below reuses them to +" \n\t" // skip loading C, so nothing between may alter flags. 
+// Row 0: +BEQ(ZERO_BETA_R_0) DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) -DSCALE4V(26,27,28,29,25,0) -DSCALEA4V(26,27,28,29,0,1,2,3,24,0) -DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) -DSCALE4V(0,1,2,3,25,0) -DSCALEA4V(0,1,2,3,4,5,6,7,24,0) -DSTOREC_4V_R_FWD(26,27,28,29,x5,0,x6) -DLOADC_4V_R_FWD(4,5,6,7,x9,0,x6) -DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) -DSCALE8V(4,5,6,7,26,27,28,29,25,0) -DSCALEA8V(4,5,6,7,26,27,28,29,8,9,10,11,12,13,14,15,24,0) -DLOADC_4V_R_FWD(8,9,10,11,x9,0,x6) -DLOADC_4V_R_FWD(12,13,14,15,x9,0,x6) -DSCALE8V(8,9,10,11,12,13,14,15,25,0) -DSCALEA8V(8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,0) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +LABEL(ZERO_BETA_R_0) DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6) +// Row 1 & 2: +BEQ(ZERO_BETA_R_1_2) +DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6) +DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) +DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_R_1_2) DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6) -DSTOREC_4V_R_FWD(26,27,28,29,x5,0,x6) DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6) +// Row 3 & 4 & 5: +BEQ(ZERO_BETA_R_3_4_5) +DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6) +DLOADC_4V_R_FWD(4,5,6,7,x9,0,x6) +DLOADC_4V_R_FWD(8,9,10,11,x9,0,x6) +DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0) +DSCALEA4V(20,21,22,23,8,9,10,11,25,0) +LABEL(ZERO_BETA_R_3_4_5) DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6) +DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6) +DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6) BRANCH(END_WRITE_MEM) // // Generic-strided C-storage. LABEL(WRITE_MEM_G) -// TODO: Implement. +" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0. 
+" \n\t" +// Row 0: +BEQ(ZERO_BETA_G_0) +DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6) +DSCALEA4V(0,1,2,3,26,27,28,29,25,0) +LABEL(ZERO_BETA_G_0) +DSTOREC_SCATTER_4V_R_FWD(0,1,2,3,x5,x1,x7,x6) +// Row 1 & 2: +BEQ(ZERO_BETA_G_1_2) +DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6) +DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0) +LABEL(ZERO_BETA_G_1_2) +DSTOREC_SCATTER_4V_R_FWD(4,5,6,7,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(8,9,10,11,x5,x1,x7,x6) +// Row 3 & 4 & 5: +BEQ(ZERO_BETA_G_3_4_5) +DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(4,5,6,7,x9,x0,x7,x6) +DLOADC_GATHER_4V_R_FWD(8,9,10,11,x9,x0,x7,x6) +DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0) +DSCALEA4V(20,21,22,23,8,9,10,11,25,0) +LABEL(ZERO_BETA_G_3_4_5) +DSTOREC_SCATTER_4V_R_FWD(12,13,14,15,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(16,17,18,19,x5,x1,x7,x6) +DSTOREC_SCATTER_4V_R_FWD(20,21,22,23,x5,x1,x7,x6) LABEL(END_WRITE_MEM) : : [a] "m" (a),