Armv8: Fix 6x8 Row-Major Microkernel

- Fixed for the 6x8 kernel only; the 4x4 and 4x8 kernels are still pending.
- Installed in the firestorm config, as the benchmark appears to show better performance:
   Old:
blis_dgemm_ukr_c                     6     8   320    36.87   2.43e-17   PASS
blis_dgemm_ukr_c                     6     8   352    40.55   1.04e-17   PASS
blis_dgemm_ukr_c                     6     8   384    44.24   5.68e-17   PASS
blis_dgemm_ukr_c                     6     8   416    41.67   3.51e-17   PASS
blis_dgemm_ukr_c                     6     8   448    34.41   2.94e-17   PASS
blis_dgemm_ukr_c                     6     8   480    42.53   2.35e-17   PASS

   New:
blis_dgemm_ukr_r                     6     8   352    50.69   1.59e-17   PASS
blis_dgemm_ukr_r                     6     8   384    49.15   5.55e-17   PASS
blis_dgemm_ukr_r                     6     8   416    50.44   2.86e-17   PASS
blis_dgemm_ukr_r                     6     8   448    46.92   3.12e-17   PASS
blis_dgemm_ukr_r                     6     8   480    48.08   4.08e-17   PASS
This commit is contained in:
RuQing Xu
2021-10-03 13:14:19 +09:00
parent 9c0064f3f6
commit abc648352c
3 changed files with 87 additions and 23 deletions

View File

@@ -50,7 +50,7 @@ void bli_cntx_init_firestorm( cntx_t* cntx )
(
2,
BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armv8a_asm_8x12, FALSE,
BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8, FALSE,
BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8r, TRUE,
cntx
);

View File

@@ -96,6 +96,11 @@
DLOAD2V(V0,V1,ADDR,SHIFT) \
DLOAD2V(V2,V3,ADDR,SHIFT+32)
// Generic: load one line.
// Emits two single-lane `ld1` instructions that fill the 64-bit lanes [0] and
// [1] of vector register v<V>, one element at a time, from the address in ADDR.
// Both loads use the post-indexed addressing form, so ADDR is advanced by INC
// after each element (i.e. twice per macro expansion).
//   V    - vector register number; stringified into "v<V>".
//   ADDR - address register; modified (post-incremented) by this macro.
//   INC  - post-index increment operand, pasted verbatim into the instruction.
//          NOTE(review): presumably a register holding the element stride in
//          bytes -- confirm against the callers.
#define DLOAD1V_GATHER_ELMFWD(V,ADDR,INC) \
" ld1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \
" ld1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t"
// Store one line.
// Emits a single `str` of the 128-bit register q<V> to the address ADDR plus
// the immediate offset SHIFT. No writeback form is used, so ADDR itself is
// left unchanged.
//   V     - vector register number; stringified into "q<V>".
//   ADDR  - base address register.
//   SHIFT - immediate offset, pasted verbatim after '#'.
#define DSTORE1V(V,ADDR,SHIFT) \
" str q"#V", ["#ADDR", #"#SHIFT"] \n\t"
@@ -106,4 +111,9 @@
DSTORE2V(V0,V1,ADDR,SHIFT) \
DSTORE2V(V2,V3,ADDR,SHIFT+32)
// Generic: store one line.
// Emits two single-lane `st1` instructions that write the 64-bit lanes [0] and
// [1] of vector register v<V>, one element at a time, to the address in ADDR.
// Both stores use the post-indexed addressing form, so ADDR is advanced by INC
// after each element (i.e. twice per macro expansion).
//   V    - vector register number; stringified into "v<V>".
//   ADDR - address register; modified (post-incremented) by this macro.
//   INC  - post-index increment operand, pasted verbatim into the instruction.
//          NOTE(review): presumably a register holding the element stride in
//          bytes -- confirm against the callers.
#define DSTORE1V_SCATTER_ELMFWD(V,ADDR,INC) \
" st1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \
" st1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t"

View File

@@ -35,7 +35,6 @@
*/
#include "blis.h"
#include "assert.h"
// Label locality & misc.
#include "armv8a_asm_utils.h"
@@ -94,6 +93,24 @@
" prfm PLDL1KEEP, ["#CADDR"] \n\t" \
" add "#CADDR", "#CADDR", "#RSC" \n\t"
// For scattered storage of C.
// Gather-load into four vector registers and then step the base address.
// Sequence emitted:
//   1. Copy CADDR into the scratch register CELEM.
//   2. Fill v<C0>, v<C1>, v<C2>, v<C3> (two 64-bit lanes each, eight elements
//      in total) via DLOAD1V_GATHER_ELMFWD, which advances CELEM by CSC after
//      every element.
//   3. Add RSC to CADDR.
//   C0..C3 - destination vector register numbers.
//   CADDR  - base address register; advanced by RSC on exit.
//   CELEM  - scratch address register; overwritten by this macro.
//   CSC    - per-element increment operand (NOTE(review): presumably the
//            column stride of C in bytes -- confirm).
//   RSC    - increment added to CADDR (NOTE(review): presumably the row
//            stride of C in bytes -- confirm).
#define DLOADC_GATHER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) \
" mov "#CELEM", "#CADDR" \n\t" \
DLOAD1V_GATHER_ELMFWD(C0,CELEM,CSC) \
DLOAD1V_GATHER_ELMFWD(C1,CELEM,CSC) \
DLOAD1V_GATHER_ELMFWD(C2,CELEM,CSC) \
DLOAD1V_GATHER_ELMFWD(C3,CELEM,CSC) \
" add "#CADDR", "#CADDR", "#RSC" \n\t"
// Scatter-store four vector registers and then step the base address.
// Sequence emitted:
//   1. Copy CADDR into the scratch register CELEM.
//   2. Write v<C0>, v<C1>, v<C2>, v<C3> (two 64-bit lanes each, eight elements
//      in total) via DSTORE1V_SCATTER_ELMFWD, which advances CELEM by CSC
//      after every element.
//   3. Add RSC to CADDR.
//   C0..C3 - source vector register numbers.
//   CADDR  - base address register; advanced by RSC on exit.
//   CELEM  - scratch address register; overwritten by this macro.
//   CSC    - per-element increment operand (NOTE(review): presumably the
//            column stride of C in bytes -- confirm).
//   RSC    - increment added to CADDR (NOTE(review): presumably the row
//            stride of C in bytes -- confirm).
#define DSTOREC_SCATTER_4V_R_FWD(C0,C1,C2,C3,CADDR,CELEM,CSC,RSC) \
" mov "#CELEM", "#CADDR" \n\t" \
DSTORE1V_SCATTER_ELMFWD(C0,CELEM,CSC) \
DSTORE1V_SCATTER_ELMFWD(C1,CELEM,CSC) \
DSTORE1V_SCATTER_ELMFWD(C2,CELEM,CSC) \
DSTORE1V_SCATTER_ELMFWD(C3,CELEM,CSC) \
" add "#CADDR", "#CADDR", "#RSC" \n\t"
void bli_dgemm_armv8a_asm_6x8r
(
dim_t k0,
@@ -109,11 +126,6 @@ void bli_dgemm_armv8a_asm_6x8r
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
// This kernel is a WIP.
// I have no generic stride support at this moment.
assert( cs_c0 == 1 );
// if ( cs_c0 != 1 ) return ;
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k_mker = k0 / 4;
@@ -245,6 +257,14 @@ LABEL(PREFETCH_ABNEXT)
" prfm PLDL1STRM, [x1, 64*1] \n\t"
" prfm PLDL1STRM, [x1, 64*3] \n\t"
" \n\t"
" fmov d26, #1.0 \n\t"
" fcmp d24, d26 \n\t"
BEQ(UNIT_ALPHA)
DSCALE8V(0,1,2,3,4,5,6,7,24,0)
DSCALE8V(8,9,10,11,12,13,14,15,24,0)
DSCALE8V(16,17,18,19,20,21,22,23,24,0)
LABEL(UNIT_ALPHA)
" \n\t"
" mov x9, x5 \n\t" // C address for loading.
" \n\t" // C address for storing is x5 itself.
" cmp x7, #8 \n\t" // Check for generic storage.
@@ -252,31 +272,65 @@ BNE(WRITE_MEM_G)
//
// Contiguous C-storage.
LABEL(WRITE_MEM_R)
" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0.
" \n\t" // This conditional flag will be used
" \n\t" // multiple times for skipping load.
// Row 0:
BEQ(ZERO_BETA_R_0)
DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6)
DSCALE4V(26,27,28,29,25,0)
DSCALEA4V(26,27,28,29,0,1,2,3,24,0)
DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6)
DSCALE4V(0,1,2,3,25,0)
DSCALEA4V(0,1,2,3,4,5,6,7,24,0)
DSTOREC_4V_R_FWD(26,27,28,29,x5,0,x6)
DLOADC_4V_R_FWD(4,5,6,7,x9,0,x6)
DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6)
DSCALE8V(4,5,6,7,26,27,28,29,25,0)
DSCALEA8V(4,5,6,7,26,27,28,29,8,9,10,11,12,13,14,15,24,0)
DLOADC_4V_R_FWD(8,9,10,11,x9,0,x6)
DLOADC_4V_R_FWD(12,13,14,15,x9,0,x6)
DSCALE8V(8,9,10,11,12,13,14,15,25,0)
DSCALEA8V(8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,0)
DSCALEA4V(0,1,2,3,26,27,28,29,25,0)
LABEL(ZERO_BETA_R_0)
DSTOREC_4V_R_FWD(0,1,2,3,x5,0,x6)
// Row 1 & 2:
BEQ(ZERO_BETA_R_1_2)
DLOADC_4V_R_FWD(26,27,28,29,x9,0,x6)
DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6)
DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0)
LABEL(ZERO_BETA_R_1_2)
DSTOREC_4V_R_FWD(4,5,6,7,x5,0,x6)
DSTOREC_4V_R_FWD(26,27,28,29,x5,0,x6)
DSTOREC_4V_R_FWD(8,9,10,11,x5,0,x6)
// Row 3 & 4 & 5:
BEQ(ZERO_BETA_R_3_4_5)
DLOADC_4V_R_FWD(0,1,2,3,x9,0,x6)
DLOADC_4V_R_FWD(4,5,6,7,x9,0,x6)
DLOADC_4V_R_FWD(8,9,10,11,x9,0,x6)
DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0)
DSCALEA4V(20,21,22,23,8,9,10,11,25,0)
LABEL(ZERO_BETA_R_3_4_5)
DSTOREC_4V_R_FWD(12,13,14,15,x5,0,x6)
DSTOREC_4V_R_FWD(16,17,18,19,x5,0,x6)
DSTOREC_4V_R_FWD(20,21,22,23,x5,0,x6)
BRANCH(END_WRITE_MEM)
//
// Generic-strided C-storage.
LABEL(WRITE_MEM_G)
// TODO: Implement.
" fcmp d25, #0.0 \n\t" // Sets conditional flag whether *beta == 0.
" \n\t"
// Row 0:
BEQ(ZERO_BETA_G_0)
DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6)
DSCALEA4V(0,1,2,3,26,27,28,29,25,0)
LABEL(ZERO_BETA_G_0)
DSTOREC_SCATTER_4V_R_FWD(0,1,2,3,x5,x1,x7,x6)
// Row 1 & 2:
BEQ(ZERO_BETA_G_1_2)
DLOADC_GATHER_4V_R_FWD(26,27,28,29,x9,x0,x7,x6)
DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6)
DSCALEA8V(4,5,6,7,8,9,10,11,26,27,28,29,0,1,2,3,25,0)
LABEL(ZERO_BETA_G_1_2)
DSTOREC_SCATTER_4V_R_FWD(4,5,6,7,x5,x1,x7,x6)
DSTOREC_SCATTER_4V_R_FWD(8,9,10,11,x5,x1,x7,x6)
// Row 3 & 4 & 5:
BEQ(ZERO_BETA_G_3_4_5)
DLOADC_GATHER_4V_R_FWD(0,1,2,3,x9,x0,x7,x6)
DLOADC_GATHER_4V_R_FWD(4,5,6,7,x9,x0,x7,x6)
DLOADC_GATHER_4V_R_FWD(8,9,10,11,x9,x0,x7,x6)
DSCALEA8V(12,13,14,15,16,17,18,19,0,1,2,3,4,5,6,7,25,0)
DSCALEA4V(20,21,22,23,8,9,10,11,25,0)
LABEL(ZERO_BETA_G_3_4_5)
DSTOREC_SCATTER_4V_R_FWD(12,13,14,15,x5,x1,x7,x6)
DSTOREC_SCATTER_4V_R_FWD(16,17,18,19,x5,x1,x7,x6)
DSTOREC_SCATTER_4V_R_FWD(20,21,22,23,x5,x1,x7,x6)
LABEL(END_WRITE_MEM)
:
: [a] "m" (a),