ArmSVE Use Predicate in M-Direction

No need to query MR during kernel runtime.
This commit is contained in:
RuQing Xu
2022-02-05 16:56:04 +09:00
committed by Devin Matthews
parent 9cc897f374
commit 72089bb291
4 changed files with 43 additions and 74 deletions

View File

@@ -68,10 +68,10 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
uint64_t cs_c = cs_c0;
uint64_t info = 0;
uint64_t mr = bli_vl_bytes_armsve() * 2 / 8;
GEMM_UKR_SETUP_CT( c, mr, 10, false );
GEMM_UKR_SETUP_CT( c, m, 10, false );
__asm__ volatile (
" whilelo p0.s, xzr, %12 \n\t"
// " ldr x0, %[a] \n\t"
// " ldr x1, %[b] \n\t"
" mov x2, xzr \n\t"
@@ -97,7 +97,6 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
" madd x2, x16, x2, xzr \n\t" // cs_a
" madd x3, x16, x3, xzr \n\t" // rs_b
" madd %4, x16, %4, xzr \n\t" // cs_c
" ptrue p0.s \n\t"
" \n\t"
// " ldr x5, %[k_mker] \n\t" // Number of loops.
// " ldr x6, %[k_left] \n\t"
@@ -307,7 +306,7 @@ GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4)
"+r" (a_next), // %9
"+r" (b_next), // %10
"=r" (info) // %11
:
: "r" (m) // %12
: "x2","x3","x9","x16",
"z0","z1","z2","z3","z4","z5","z6","z7",
"z8","z9","z10","z11","z12","z13","z14","z15",

View File

@@ -67,10 +67,14 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
uint64_t mr = bli_vl_bytes_armsve() * 2 / 8;
GEMM_UKR_SETUP_CT( d, mr, 10, false );
GEMM_UKR_SETUP_CT( d, m, 10, false );
__asm__ volatile (
" mov x0, xzr \n\t"
" ldr x1, %[m] \n\t"
" whilelo p0.d, x0, x1 \n\t" " incd x0 \n\t"
" whilelo p1.d, x0, x1 \n\t"
" \n\t"
" ldr x0, %[a] \n\t"
" ldr x1, %[b] \n\t"
" mov x2, xzr \n\t"
@@ -96,7 +100,6 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
" madd x2, x8, x2, xzr \n\t" // cs_a
" madd x3, x8, x3, xzr \n\t" // rs_b
" madd x7, x8, x7, xzr \n\t" // cs_c
" ptrue p0.d \n\t"
" \n\t"
" ldr x4, %[k_mker] \n\t" // Number of loops.
" ldr x8, %[k_left] \n\t"
@@ -114,7 +117,7 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
" ld1rd z26.d, p0/z, [x1, 48] \n\t"
" ld1rd z27.d, p0/z, [x1, 56] \n\t"
" \n\t"
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
" \n\t"
" CCOL_PRFM: \n\t"
// " cmp x6, #1 \n\t"
@@ -149,22 +152,22 @@ CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1
" K_MKER_LOOP: \n\t"
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" \n\t"
" subs x4, x4, #1 \n\t" // Decrease counter before final replica.
" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem.
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" b K_MKER_LOOP \n\t"
" \n\t"
@@ -176,7 +179,7 @@ GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3
" cmp x8, #0 \n\t" // End of execution.
" b.eq WRITE_MEM_PREP \n\t"
" \n\t"
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0)
" ld1rd z20.d, p0/z, [x1] \n\t" // Load 8/10 of first B row.
" ld1rd z21.d, p0/z, [x1, 8] \n\t"
" ld1rd z22.d, p0/z, [x1, 16] \n\t"
@@ -255,7 +258,7 @@ GEMM_FMLA2(z18,z19,p0,z30,z31,z29)
" \n\t" // C address for storing is x5 itself.
// " cmp x6, #1 \n\t" // Preload first half of C for contiguous case.
// " b.ne WRITE_MEM \n\t"
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
" \n\t"
" WRITE_MEM: \n\t"
" \n\t"
@@ -273,35 +276,16 @@ SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1
" fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN.
" b.eq BETA_ZERO_C \n\t"
// First half of C is already loaded in this case.
// GEMM_C_FMAD_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31,x9,x7)
// GEMM_C_FMAD_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31,x9,x7)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
" \n\t"
" BETA_ZERO_C: \n\t"
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p0,x5,x7)
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7)
// " b END_WRITE_MEM \n\t"
// " \n\t"
// " WRITE_MEM_G: \n\t" // Available scratch: Z[20-30].
// " \n\t" // Here used scratch: Z[20-30] - Z30 as index.
// " mov x8, xzr \n\t"
// " incb x8 \n\t"
// " madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip.
// " index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8.
// " \n\t"
// " fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN.
// " b.eq BETA_ZERO_G \n\t"
// " \n\t"
// GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16)
// GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
// GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16)
// GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
// " \n\t"
// " BETA_ZERO_G: \n\t"
// GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x5,x7,x8,x16)
// GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x16)
// " \n\t"
// " END_WRITE_MEM: \n\t"
// " b END_EXEC \n\t"
// " \n\t"
@@ -310,7 +294,8 @@ GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p0,x5,x7)
" END_EXEC: \n\t"
" mov x0, #0 \n\t" // Return normal.
:
: [a] "m" (a),
: [m] "m" (m),
[a] "m" (a),
[b] "m" (b),
[c] "m" (c),
[rs_c] "m" (rs_c),

View File

@@ -67,10 +67,14 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
uint64_t mr = bli_vl_bytes_armsve() * 2 / 4;
GEMM_UKR_SETUP_CT( s, mr, 10, false );
GEMM_UKR_SETUP_CT( s, m, 10, false );
__asm__ volatile (
" mov x0, xzr \n\t"
" ldr x1, %[m] \n\t"
" whilelo p0.s, x0, x1 \n\t" " incw x0 \n\t"
" whilelo p1.s, x0, x1 \n\t"
" \n\t"
" ldr x0, %[a] \n\t"
" ldr x1, %[b] \n\t"
" mov x2, xzr \n\t"
@@ -96,7 +100,6 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
" madd x2, x8, x2, xzr \n\t" // cs_a
" madd x3, x8, x3, xzr \n\t" // rs_b
" madd x7, x8, x7, xzr \n\t" // cs_c
" ptrue p0.s \n\t"
" \n\t"
" ldr x4, %[k_mker] \n\t" // Number of loops.
" ldr x8, %[k_left] \n\t"
@@ -114,7 +117,7 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
" ld1rw z26.s, p0/z, [x1, 24] \n\t"
" ld1rw z27.s, p0/z, [x1, 28] \n\t"
" \n\t"
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
" \n\t"
" CCOL_PRFM: \n\t"
// " cmp x6, #1 \n\t"
@@ -149,22 +152,22 @@ CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1
" K_MKER_LOOP: \n\t"
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" \n\t"
" subs x4, x4, #1 \n\t" // Decrease counter before final replica.
" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem.
" \n\t"
" add x0, x0, x2 \n\t" // Forward A's address to the next column.
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3)
" b K_MKER_LOOP \n\t"
" \n\t"
@@ -176,7 +179,7 @@ GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3
" cmp x8, #0 \n\t" // End of execution.
" b.eq WRITE_MEM_PREP \n\t"
" \n\t"
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0)
GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p1,x0)
" ld1rw z20.s, p0/z, [x1] \n\t" // Load 8/10 of first B row.
" ld1rw z21.s, p0/z, [x1, 4] \n\t"
" ld1rw z22.s, p0/z, [x1, 8] \n\t"
@@ -260,34 +263,16 @@ SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1
" \n\t" // Here used scratch: Z[20-29].
" fcmp s31, #0.0 \n\t"
" b.eq BETA_ZERO_C \n\t"
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
" \n\t"
" BETA_ZERO_C: \n\t"
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p0,x5,x7)
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7)
// " b END_WRITE_MEM \n\t"
// " \n\t"
// " WRITE_MEM_G: \n\t" // Available scratch: Z[20-30].
// " \n\t" // Here used scratch: Z[20-30] - Z30 as index.
// " mov x8, xzr \n\t"
// " incb x8 \n\t"
// " madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip.
// " index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8.
// " \n\t"
// " fcmp s31, #0.0 \n\t"
// " b.eq BETA_ZERO_G \n\t"
// GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16)
// GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
// GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16)
// GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
// " \n\t"
// " BETA_ZERO_G: \n\t"
// GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x5,x7,x8,x16)
// GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x16)
// " \n\t"
// " END_WRITE_MEM: \n\t"
// " b END_EXEC \n\t"
// " \n\t"
@@ -296,7 +281,8 @@ GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p0,x5,x7)
" END_EXEC: \n\t"
" mov x0, #0 \n\t" // Return normal.
:
: [a] "m" (a),
: [m] "m" (m),
[a] "m" (a),
[b] "m" (b),
[c] "m" (c),
[rs_c] "m" (rs_c),

View File

@@ -68,10 +68,10 @@ void bli_zgemm_armsve_asm_2vx10_unindexed
uint64_t cs_c = cs_c0;
uint64_t info = 0;
uint64_t mr = bli_vl_bytes_armsve() * 2 / 16;
GEMM_UKR_SETUP_CT( z, mr, 10, false );
GEMM_UKR_SETUP_CT( z, m, 10, false );
__asm__ volatile (
" whilelo p0.d, xzr, %12 \n\t"
// " ldr x0, %[a] \n\t"
// " ldr x1, %[b] \n\t"
" mov x2, xzr \n\t"
@@ -97,7 +97,6 @@ void bli_zgemm_armsve_asm_2vx10_unindexed
" madd x2, x16, x2, xzr \n\t" // cs_a
" madd x3, x16, x3, xzr \n\t" // rs_b
" madd %4, x16, %4, xzr \n\t" // cs_c
" ptrue p0.d \n\t"
" \n\t"
// " ldr x5, %[k_mker] \n\t" // Number of loops.
// " ldr x6, %[k_left] \n\t"
@@ -306,7 +305,7 @@ GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4)
"+r" (a_next), // %9
"+r" (b_next), // %10
"=r" (info) // %11
:
: "r" (m) // %12
: "x2","x3","x9","x16",
"z0","z1","z2","z3","z4","z5","z6","z7",
"z8","z9","z10","z11","z12","z13","z14","z15",