Files
blis/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c
Field G. Van Zee e9da6425e2 Allow use of 1m with mixing of row/col-pref ukrs.
Details:
- Fixed a bug that broke the use of 1m for dcomplex when the single-
  precision real and double-precision real ukernels had opposing I/O
  preferences (row-preferential sgemm ukernel + column-preferential
  dgemm ukernel, or vice versa). The fix involved adjusting the API
  to bli_cntx_set_ind_blkszs() so that the induced method context init
  function (e.g., bli_cntx_init_<subconfig>_ind()) could call that
  function for only one datatype at a time. This allowed the blocksize
  scaling (which varies depending on whether we're doing 1m_r or 1m_c)
  to happen on a per-datatype basis. This fixes issue #557. Thanks to
  Devin Matthews and RuQing Xu for helping discover and report this bug.
- The aforementioned 1m fix required moving the 1m_r/1m_c logic from
  bli_cntx_ref.c into a new function, bli_l3_set_schemas(), which is
  called from each level-3 _front() function. The pack_t schemas in the
  cntx_t were also removed entirely, along with the associated accessor
  functions. This in turn required updating the trsm1m-related virtual
  ukernels to read the pack schema for B from the auxinfo_t struct
  rather than the context. This also required slight tweaks to
  bli_gemm_md.c.
- Repositioned the logic for transposing the operation to accommodate
  the microkernel IO preference. This mostly only affects gemm. Thanks
  to Devin Matthews for his help with this.
- Updated dpackm pack ukernels in the 'armsve' kernel set to avoid
  querying pack_t schemas from the context.
- Removed the num_t dt argument from the ind_cntx_init_ft type defined
  in bli_gks.c. The context initialization functions for induced methods
  were previously passed a dt argument, but I can no longer figure out
  *why* they were passed this value. To reduce confusion, I've removed
  the dt argument (including also from the function definition +
  prototype).
- Commented out setting of cntx_t schemas in bli_cntx_ind_stage.c. This
  breaks high-level implementations of 3m and 4m, but this is okay since
  those implementations will be removed very soon.
- Removed some older blocks of preprocessor-disabled code.
- Comment update to test_libblis.c.
2021-10-13 14:15:38 -05:00

363 lines
13 KiB
C

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2021, The University of Tokyo
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "armsve512_asm_transpose_d8x8.h"
// Pack a (cdim x n) double-precision micro-panel of A into the buffer p,
// which holds 16-row columns separated by ldp elements.
//
// Assumption:
//   SVE vector length = 512 bits (8 doubles per vector; the code below
//   moves each 16-element packed column as two vectors).
//
// Behavior, as implemented below:
// - Fast path (inline assembly): taken when cdim == 16, A is not general
//   strided (inca == 1 or lda == 1) and *kappa == 1. The panel is copied
//   verbatim, 8 columns per main-loop iteration plus a one-column-at-a-time
//   remainder loop.
//     * inca == 1 ("A stored in columns"): contiguous vector loads/stores
//       with software prefetch ahead of the current column.
//     * inca != 1 ("A stored in rows", hence lda == 1 since A is not general
//       strided): 16 rows x 8 columns are loaded, transposed in registers
//       with the 8x8 transpose macros, and stored as 8 packed columns. The
//       remainder columns use gather loads with an index vector of stride
//       inca.
// - Fallback path: bli_dscal2m_ex() performs the scaled copy, and rows
//   cdim..15 of the first n_max packed columns are then zeroed.
// - In both cases, columns n..n_max-1 of p are zeroed when n < n_max.
//
// Parameters:
//   conja   - conjugation flag; only forwarded to bli_dscal2m_ex().
//   schema  - pack schema; only consulted when _A64FX is defined.
//   cdim_   - number of valid rows in the panel (mnr is fixed at 16).
//   n_      - number of columns to pack.
//   n_max_  - number of columns in the packed panel, including zero padding.
//   kappa   - scalar applied to A (fast path requires it to equal one).
//   a       - source matrix, with row stride inca_ and column stride lda_.
//   p       - destination panel, with unit row stride and column stride ldp_.
//   cntx    - context; only forwarded to bli_dscal2m_ex().
void bli_dpackm_armsve512_asm_16xk
     (
       conj_t           conja,
       pack_t           schema,
       dim_t            cdim_,
       dim_t            n_,
       dim_t            n_max_,
       double* restrict kappa,
       double* restrict a, inc_t inca_, inc_t lda_,
       double* restrict p,              inc_t ldp_,
       cntx_t* restrict cntx
     )
{
    const int64_t cdim  = cdim_;
    const int64_t mnr   = 16;
    const int64_t n     = n_;
    const int64_t n_max = n_max_;
    const int64_t inca  = inca_;
    const int64_t lda   = lda_;
    const int64_t ldp   = ldp_;
    // General stride: neither rows nor columns of A are contiguous.
    const bool    gs    = inca != 1 && lda != 1;
    const bool    unitk = bli_deq1( *kappa );

#ifdef _A64FX
    {
        // Infer whether A or B is being packed and record that in the top
        // byte of the destination address (bits 56 and up).
        // NOTE(review): how the tag values 0x1 / 0x2 are consumed is not
        // visible in this file; presumably an A64FX-specific address tag.
        // The explicit pointer cast is required: assigning a bare integer
        // to a pointer is a constraint violation in C.
        if ( schema == BLIS_PACKED_ROWS )
            p = ( double* )( ( ( uint64_t )0x1 << 56 ) | ( uint64_t )p );
        if ( schema == BLIS_PACKED_COLUMNS )
            p = ( double* )( ( ( uint64_t )0x2 << 56 ) | ( uint64_t )p );
    }
#endif

    if ( cdim == mnr && !gs && unitk )
    {
        // Main loop handles 8 columns per iteration; the rest go one by one.
        uint64_t n_mker = n / 8;
        uint64_t n_left = n % 8;

        // Register roles inside the assembly:
        //   x0 = current A address      x1 = current P address
        //   x2 = ldp in bytes           x3 = lda in bytes
        //   x4 = inca in bytes          x8 = main-loop count
        //   x9 = remainder count        x17 = prefetch distance (8 columns)
        //   x5-x7, x10-x16, x18 = per-iteration address temporaries
        // NOTE(review): x18 is reserved as a platform register on some
        // AArch64 ABIs — confirm this kernel is only built where it is free.
        __asm__ volatile (
        "mov x0, %[a] \n\t"
        "mov x1, %[p] \n\t"
        "mov x2, %[ldp] \n\t"
        "mov x3, %[lda] \n\t"
        "mov x4, %[inca] \n\t"
        // Flags set here are consumed by the b.ne below; none of the
        // instructions in between modify the condition flags.
        "cmp x4, #1 \n\t"
        // Skips by sizeof(double).
        "mov x8, #8 \n\t"
        "madd x2, x2, x8, xzr \n\t"
        "madd x3, x3, x8, xzr \n\t"
        "madd x4, x4, x8, xzr \n\t"
        // "mov x8, 0x8 \n\t" // Control#0 for A address.
        // "mov x8, 0x24 \n\t" // Higher 6bit for Control#0:
        // "lsl x8, x8, #58 \n\t" // Valid|Strong|Strong|Alloc|Load|Strong
        // "orr x8, x8, x3 \n\t" // Stride.
        // "msr S3_3_C11_C6_0, x8 \n\t" // Write system register.
        // Loop constants.
        "mov x8, %[n_mker] \n\t"
        "mov x9, %[n_left] \n\t"
        "ptrue p0.d \n\t"
        "b.ne .AROWSTOR \n\t"
        // A stored in columns.
        " .ACOLSTOR: \n\t"
        // Prefetch distance.
        "mov x17, #8 \n\t"
        "madd x17, x17, x3, xzr \n\t"
#ifdef _A64FX
        "mov x16, 0x6 \n\t" // Disable hardware prefetch for A.
        "lsl x16, x16, #60 \n\t"
        "orr x0, x0, x16 \n\t"
#endif
        // "add x5, x0, x3 \n\t"
        // "add x6, x5, x3 \n\t"
        // "add x7, x6, x3 \n\t"
        // "prfm PLDL1STRM, [x0] \n\t"
        // "prfm PLDL1STRM, [x5] \n\t"
        // "prfm PLDL1STRM, [x6] \n\t"
        // "prfm PLDL1STRM, [x7] \n\t"
        // "add x18, x7, x3 \n\t"
        // "add x5, x18, x3 \n\t"
        // "add x6, x5, x3 \n\t"
        // "add x7, x6, x3 \n\t"
        // "prfm PLDL1STRM, [x18] \n\t"
        // "prfm PLDL1STRM, [x5] \n\t"
        // "prfm PLDL1STRM, [x6] \n\t"
        // "prfm PLDL1STRM, [x7] \n\t"
        " .ACOLSTORMKER: \n\t"
        "cmp x8, xzr \n\t"
        "b.eq .ACOLSTORMKEREND \n\t"
        // Addresses of source columns 1-3 and destination columns 1-7.
        "add x5, x0, x3 \n\t"
        "add x6, x5, x3 \n\t"
        "add x7, x6, x3 \n\t"
        "add x10, x1, x2 \n\t"
        "add x11, x10, x2 \n\t"
        "add x12, x11, x2 \n\t"
        "add x13, x12, x2 \n\t"
        "add x14, x13, x2 \n\t"
        "add x15, x14, x2 \n\t"
        "add x16, x15, x2 \n\t"
        // Load source columns 0-3 (two vectors each).
        "ld1d z0.d, p0/z, [x0] \n\t"
        "ld1d z1.d, p0/z, [x0, #1, mul vl] \n\t"
        "ld1d z2.d, p0/z, [x5] \n\t"
        "ld1d z3.d, p0/z, [x5, #1, mul vl] \n\t"
        "ld1d z4.d, p0/z, [x6] \n\t"
        "ld1d z5.d, p0/z, [x6, #1, mul vl] \n\t"
        "ld1d z6.d, p0/z, [x7] \n\t"
        "ld1d z7.d, p0/z, [x7, #1, mul vl] \n\t"
        // Prefetch the same columns one prefetch distance ahead.
        "add x18, x17, x0 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        "add x18, x17, x5 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        "add x18, x17, x6 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        "add x18, x17, x7 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        // Advance to source columns 4-7 and load them.
        "add x0, x7, x3 \n\t"
        "add x5, x0, x3 \n\t"
        "add x6, x5, x3 \n\t"
        "add x7, x6, x3 \n\t"
        "ld1d z8.d, p0/z, [x0] \n\t"
        "ld1d z9.d, p0/z, [x0, #1, mul vl] \n\t"
        "ld1d z10.d, p0/z, [x5] \n\t"
        "ld1d z11.d, p0/z, [x5, #1, mul vl] \n\t"
        "ld1d z12.d, p0/z, [x6] \n\t"
        "ld1d z13.d, p0/z, [x6, #1, mul vl] \n\t"
        "ld1d z14.d, p0/z, [x7] \n\t"
        "ld1d z15.d, p0/z, [x7, #1, mul vl] \n\t"
        "add x18, x17, x0 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        "add x18, x17, x5 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        "add x18, x17, x6 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        "add x18, x17, x7 \n\t"
        "prfm PLDL1STRM, [x18] \n\t"
        // Store all 8 packed columns.
        "st1d z0.d, p0, [x1] \n\t"
        "st1d z1.d, p0, [x1, #1, mul vl] \n\t"
        "st1d z2.d, p0, [x10] \n\t"
        "st1d z3.d, p0, [x10, #1, mul vl] \n\t"
        "st1d z4.d, p0, [x11] \n\t"
        "st1d z5.d, p0, [x11, #1, mul vl] \n\t"
        "st1d z6.d, p0, [x12] \n\t"
        "st1d z7.d, p0, [x12, #1, mul vl] \n\t"
        "st1d z8.d, p0, [x13] \n\t"
        "st1d z9.d, p0, [x13, #1, mul vl] \n\t"
        "st1d z10.d, p0, [x14] \n\t"
        "st1d z11.d, p0, [x14, #1, mul vl] \n\t"
        "st1d z12.d, p0, [x15] \n\t"
        "st1d z13.d, p0, [x15, #1, mul vl] \n\t"
        "st1d z14.d, p0, [x16] \n\t"
        "st1d z15.d, p0, [x16, #1, mul vl] \n\t"
        "add x0, x7, x3 \n\t"
        "add x1, x16, x2 \n\t"
        "sub x8, x8, #1 \n\t"
        "b .ACOLSTORMKER \n\t"
        " .ACOLSTORMKEREND: \n\t"
        // Remainder: copy one column per iteration.
        " .ACOLSTORLEFT: \n\t"
        "cmp x9, xzr \n\t"
        "b.eq .UNITKDONE \n\t"
        "ld1d z0.d, p0/z, [x0] \n\t"
        "ld1d z1.d, p0/z, [x0, #1, mul vl] \n\t"
        "st1d z0.d, p0, [x1] \n\t"
        "st1d z1.d, p0, [x1, #1, mul vl] \n\t"
        "add x0, x0, x3 \n\t"
        "add x1, x1, x2 \n\t"
        "sub x9, x9, #1 \n\t"
        "b .ACOLSTORLEFT \n\t"
        // A stored in rows.
        " .AROWSTOR: \n\t"
        // Prepare predicates for in-reg transpose.
        SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(x16,p0,p1,p2,p3,p8,p4,p6)
        " .AROWSTORMKER: \n\t" // X[10-16] for A here not P. Be careful.
        "cmp x8, xzr \n\t"
        "b.eq .AROWSTORMKEREND \n\t"
        // Load 8 consecutive elements from each of rows 0-7.
        "add x10, x0, x4 \n\t"
        "add x11, x10, x4 \n\t"
        "add x12, x11, x4 \n\t"
        "add x13, x12, x4 \n\t"
        "add x14, x13, x4 \n\t"
        "add x15, x14, x4 \n\t"
        "add x16, x15, x4 \n\t"
        "ld1d z0.d, p0/z, [x0] \n\t"
        "ld1d z1.d, p0/z, [x10] \n\t"
        "ld1d z2.d, p0/z, [x11] \n\t"
        "ld1d z3.d, p0/z, [x12] \n\t"
        "ld1d z4.d, p0/z, [x13] \n\t"
        "ld1d z5.d, p0/z, [x14] \n\t"
        "ld1d z6.d, p0/z, [x15] \n\t"
        "ld1d z7.d, p0/z, [x16] \n\t"
        // Load 8 consecutive elements from each of rows 8-15.
        "add x5, x16, x4 \n\t"
        "add x10, x5, x4 \n\t"
        "add x11, x10, x4 \n\t"
        "add x12, x11, x4 \n\t"
        "add x13, x12, x4 \n\t"
        "add x14, x13, x4 \n\t"
        "add x15, x14, x4 \n\t"
        "add x16, x15, x4 \n\t"
        "ld1d z16.d, p0/z, [x5] \n\t"
        "ld1d z17.d, p0/z, [x10] \n\t"
        "ld1d z18.d, p0/z, [x11] \n\t"
        "ld1d z19.d, p0/z, [x12] \n\t"
        "ld1d z20.d, p0/z, [x13] \n\t"
        "ld1d z21.d, p0/z, [x14] \n\t"
        "ld1d z22.d, p0/z, [x15] \n\t"
        "ld1d z23.d, p0/z, [x16] \n\t"
        // Transpose first 8 rows.
        SVE512_IN_REG_TRANSPOSE_d8x8(z8,z9,z10,z11,z12,z13,z14,z15,z0,z1,z2,z3,z4,z5,z6,z7,p0,p1,p2,p3,p8,p4,p6)
        // Transpose last 8 rows.
        SVE512_IN_REG_TRANSPOSE_d8x8(z24,z25,z26,z27,z28,z29,z30,z31,z16,z17,z18,z19,z20,z21,z22,z23,p0,p1,p2,p3,p8,p4,p6)
        // x10-x16 now address destination columns 1-7 in P.
        "add x10, x1, x2 \n\t"
        "add x11, x10, x2 \n\t"
        "add x12, x11, x2 \n\t"
        "add x13, x12, x2 \n\t"
        "add x14, x13, x2 \n\t"
        "add x15, x14, x2 \n\t"
        "add x16, x15, x2 \n\t"
        // Each packed column = transposed rows 0-7 followed by rows 8-15.
        "st1d z8.d, p0, [x1] \n\t"
        "st1d z24.d, p0, [x1, #1, mul vl] \n\t"
        "st1d z9.d, p0, [x10] \n\t"
        "st1d z25.d, p0, [x10, #1, mul vl] \n\t"
        "st1d z10.d, p0, [x11] \n\t"
        "st1d z26.d, p0, [x11, #1, mul vl] \n\t"
        "st1d z11.d, p0, [x12] \n\t"
        "st1d z27.d, p0, [x12, #1, mul vl] \n\t"
        "st1d z12.d, p0, [x13] \n\t"
        "st1d z28.d, p0, [x13, #1, mul vl] \n\t"
        "st1d z13.d, p0, [x14] \n\t"
        "st1d z29.d, p0, [x14, #1, mul vl] \n\t"
        "st1d z14.d, p0, [x15] \n\t"
        "st1d z30.d, p0, [x15, #1, mul vl] \n\t"
        "st1d z15.d, p0, [x16] \n\t"
        "st1d z31.d, p0, [x16, #1, mul vl] \n\t"
        // Advance A by 8 columns (8 doubles = 64 bytes) and P by 8 columns.
        "add x0, x0, #64 \n\t"
        "add x1, x16, x2 \n\t"
        "sub x8, x8, #1 \n\t"
        "b .AROWSTORMKER \n\t"
        " .AROWSTORMKEREND: \n\t"
        "mov x4, %[inca] \n\t" // Restore unshifted inca.
        "index z30.d, xzr, x4 \n\t" // Generate index.
        "lsl x4, x4, #3 \n\t" // Shift again.
        "lsl x5, x4, #3 \n\t" // Virtual column vl.
        // Remainder: gather one column (16 rows, stride inca) per iteration.
        " .AROWSTORLEFT: \n\t"
        "cmp x9, xzr \n\t"
        "b.eq .UNITKDONE \n\t"
        "add x6, x0, x5 \n\t"
        "ld1d z0.d, p0/z, [x0, z30.d, lsl #3] \n\t"
        "ld1d z1.d, p0/z, [x6, z30.d, lsl #3] \n\t"
        "st1d z0.d, p0, [x1] \n\t"
        "st1d z1.d, p0, [x1, #1, mul vl] \n\t"
        "add x1, x1, x2 \n\t"
        "add x0, x0, #8 \n\t"
        "sub x9, x9, #1 \n\t"
        "b .AROWSTORLEFT \n\t"
        " .UNITKDONE: \n\t"
        "mov x0, #0 \n\t"
        :
        : [a]      "r" (a),
          [p]      "r" (p),
          [lda]    "r" (lda),
          [ldp]    "r" (ldp),
          [inca]   "r" (inca),
          [n_mker] "r" (n_mker),
          [n_left] "r" (n_left)
        // Every register written above must be listed. The row-stored path
        // uses z16-z31 and p8 (via the transpose macros and the index
        // vector), "cc" covers the cmp instructions, and "memory" tells the
        // compiler that the packed buffer is written behind its back.
        : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
          "x8", "x9", "x10","x11","x12","x13","x14","x15",
          "x16","x17","x18",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
          "z8", "z9", "z10","z11","z12","z13","z14","z15",
          "z16","z17","z18","z19","z20","z21","z22","z23",
          "z24","z25","z26","z27","z28","z29","z30","z31",
          "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8",
          "cc", "memory"
        );
    }
    else // Partial panel (cdim < mnr), general stride, or non-unit kappa.
    {
        // Generic scaled copy of the cdim x n panel into p.
        bli_dscal2m_ex
        (
          0,
          BLIS_NONUNIT_DIAG,
          BLIS_DENSE,
          ( trans_t )conja,
          cdim,
          n,
          kappa,
          a, inca, lda,
          p, 1,    ldp,
          cntx,
          NULL
        );

        // Zero rows cdim..mnr-1 of the packed panel. When cdim == mnr the
        // edge has zero rows (m_edge == 0).
        {
            const dim_t      i      = cdim;
            const dim_t      m_edge = mnr - i;
            const dim_t      n_edge = n_max;
            double* restrict p_edge = p + (i  )*1;

            bli_dset0s_mxn
            (
              m_edge,
              n_edge,
              p_edge, 1, ldp
            );
        }
    }

    // Zero the trailing columns n..n_max-1 of the packed panel.
    if ( n < n_max )
    {
        const dim_t      j      = n;
        const dim_t      m_edge = mnr;
        const dim_t      n_edge = n_max - j;
        double* restrict p_edge = p + (j  )*ldp;

        bli_dset0s_mxn
        (
          m_edge,
          n_edge,
          p_edge, 1, ldp
        );
    }
}