AVX512 optimizations for CGEMM (native)

- Implemented the following AVX512 native computational kernels for CGEMM:
  row-preferential (4x24) and column-preferential (24x4).
- The implementations use a common set of macros, defined in a separate
  header, since the two variants differ solely in the matrix chosen for
  load/broadcast operations.
- Added the associated AVX512-based packing kernels, which pack 24xk and
  4xk panels of the input.
- Registered the column-preferential kernel (24x4) in the ZEN4 and ZEN5
  contexts, and updated the cache-blocking parameters accordingly.
- Removed redundant BLIS object creation and its contingencies in the
  native micro-kernel testing interface (for complex types). Added the
  required unit tests for memory and functionality checks of the new
  kernels.

AMD-Internal: [CPUPL-6498]
Change-Id: I520ff17dba4c2f9bc277bf33ba9ab4384408ffe1

Committed by Vignesh Balasubramanian
parent 6c29236166
commit 99770558bb
kernels/zen4/1m/bli_packm_zen4_asm_c24xk.c (new file, 845 lines)
@@ -0,0 +1,845 @@
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"

#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
/**
 * Shuffle pairs of scomplex elements (128-bit lanes) selected by imm8
 * from S1 and S2 (and from S3 and S4), and store the results in D1 and
 * D2 (and D3 and D4).
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
    VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
    VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
    VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
    VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))

/**
 * Unpacks and interleaves the low halves and high halves of each
 * 128-bit lane in S1 and S2 (and in S3 and S4), storing the results
 * in D1 and D2 (and D3 and D4), respectively.
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
    vunpcklpd(zmm(S1), zmm(S2), zmm(D1)) \
    vunpckhpd(zmm(S1), zmm(S2), zmm(D2)) \
    vunpcklpd(zmm(S3), zmm(S4), zmm(D3)) \
    vunpckhpd(zmm(S3), zmm(S4), zmm(D4))
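/*
 * Editorial sketch (not part of the original kernel): the two macros above
 * compose into a conventional 8x8 transpose network. A scalar model of the
 * same data movement, treating each scomplex as one 64-bit unit, would be:
 *
 *     // rows[i][j] = element j of input row i
 *     for ( int i = 0; i < 8; i++ )
 *         for ( int j = 0; j < 8; j++ )
 *             cols[j][i] = rows[i][j];
 *
 * UNPACK_LO_HIGH performs the swap within 128-bit lanes, and the
 * SHUFFLE_DATA stages then move whole 128-bit lanes across the register.
 */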
void bli_cpackm_zen4_asm_24xk
     (
       conj_t             conja,
       pack_t             schema,
       dim_t              cdim0,
       dim_t              k0,
       dim_t              k0_max,
       scomplex* restrict kappa,
       scomplex* restrict a, inc_t inca0, inc_t lda0,
       scomplex* restrict p,              inc_t ldp0,
       cntx_t*   restrict cntx
     )
{
    // This is the panel dimension assumed by the packm kernel.
    const dim_t mnr = 24;

    // Typecast local copies of integers in case dim_t and inc_t are a
    // different size than is expected by load instructions.
    // NOTE: k_iter is in blocks of 8, due to the SIMD width of scomplex
    // in a 512-bit register. This way, we can still perform AVX512 loads
    // and stores in case the matrix is in row-major format.
    const uint64_t k_iter = k0 / 8;
    const uint64_t k_left = k0 % 8;

    /**
     * Prepare the mask for k_left, since we compute in blocks of 8.
     * For the edge cases, the mask is set so that loads and stores touch
     * only the leftover elements.
     */
    uint16_t one  = 1;
    uint16_t mask = ( one << ( 2 * k_left ) ) - one;
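    /*
     * Editorial note: each scomplex occupies two 32-bit lanes (real, imag),
     * hence the factor of 2 in the shift. For example, with k_left = 3 the
     * mask is (1 << 6) - 1 = 0x003F, enabling the six low lanes -- exactly
     * three leftover complex elements -- in the masked vmovups below.
     */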
    // NOTE: For the purposes of the comments in this packm kernel, we
    // interpret inca and lda as rs_a and cs_a, respectively, and similarly
    // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading
    // this packm kernel, you should think of the operation as packing an
    // m x n micropanel, where m and n are tiny and large, respectively, and
    // where elements of each column of the packed matrix P are contiguous.
    // (This packm kernel can still be used to pack micropanels of matrix B
    // in a gemm operation.)
    const uint64_t inca = inca0;
    const uint64_t lda  = lda0;
    const uint64_t ldp  = ldp0;

    const bool gs = ( inca0 != 1 && lda0 != 1 );

    // NOTE: If/when this kernel ever supports scaling by kappa within the
    // assembly region, this constraint should be lifted.
    const bool unitk = bli_ceq1( *kappa );

    // -------------------------------------------------------------------------

    if ( cdim0 == mnr && !gs && !conja && unitk )
    {
        begin_asm()
        mov(var(mask), rdx)          // load mask
        kmovw(edx, k(2))             // move mask to k2 register

        mov(var(a), rax)             // load address of source buffer
        mov(var(a), r13)             // load address of source buffer

        mov(var(inca), r8)           // load inca
        mov(var(lda), r10)           // load lda
        lea(mem( , r8,  8), r8)      // inca *= sizeof(scomplex)
        lea(mem( , r10, 8), r10)     // lda *= sizeof(scomplex)

        mov(var(p), rbx)             // load address of p.

        lea(mem( , r8, 8), r14)      // r14 = 8*inca

        cmp(imm(8), r8)              // set ZF if (8*inca) == 8.
        jz(.CCOLUNIT)                // jump to column storage case

        // -- kappa unit, row storage on A -----------------------------------

        label(.CROWUNIT)

        lea(mem(r8,  r8, 2), r12)    // r12 = 3*inca
        lea(mem(r12, r8, 2), rcx)    // rcx = 5*inca
        lea(mem(r12, r8, 4), rdx)    // rdx = 7*inca

        mov(var(k_iter), rsi)        // i = k_iter;
        test(rsi, rsi)               // check i via logical AND.
        je(.CCONKLEFTROWU)           // if i == 0, jump to code that
                                     // contains the k_left loop.
        label(.CKITERROWU)           // MAIN LOOP (k_iter)
        /* The 24x8 block handled by each iteration of the main loop is
           broken down into 3 sets of 8x8 packing */
        /* Load the first 8 rows of the matrix */
        vmovups(mem(rax,         0), zmm6)
        vmovups(mem(rax, r8,  1, 0), zmm8)
        vmovups(mem(rax, r8,  2, 0), zmm10)
        vmovups(mem(rax, r12, 1, 0), zmm12)
        vmovups(mem(rax, r8,  4, 0), zmm14)
        vmovups(mem(rax, rcx, 1, 0), zmm16)
        vmovups(mem(rax, r12, 2, 0), zmm18)
        vmovups(mem(rax, rdx, 1, 0), zmm20)

        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row1
              zmm8  --> row2
              zmm10 --> row3
              zmm12 --> row4
              zmm14 --> row5
              zmm16 --> row6
              zmm18 --> row7
              zmm20 --> row8

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)

        /* Store the 8 transposed columns */
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        vmovups(zmm2, mem(rbx, 2*192))
        vmovups(zmm6, mem(rbx, 3*192))
        vmovups(zmm1, mem(rbx, 4*192))
        vmovups(zmm5, mem(rbx, 5*192))
        vmovups(zmm3, mem(rbx, 6*192))
        vmovups(zmm8, mem(rbx, 7*192))
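        /* Editorial note: the packed panel is column-major with ldp = 24
           scomplex elements, i.e. 24 * 8 = 192 bytes per packed column,
           hence the n*192 store offsets. The second and third 8x8 sets
           below land at byte offsets +64 and +128 within each column. */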
        add(r14, rax)                // a += 8*inca;

        /* Load another 8 rows of the matrix */
        vmovups(mem(rax,         0), zmm6)
        vmovups(mem(rax, r8,  1, 0), zmm8)
        vmovups(mem(rax, r8,  2, 0), zmm10)
        vmovups(mem(rax, r12, 1, 0), zmm12)
        vmovups(mem(rax, r8,  4, 0), zmm14)
        vmovups(mem(rax, rcx, 1, 0), zmm16)
        vmovups(mem(rax, r12, 2, 0), zmm18)
        vmovups(mem(rax, rdx, 1, 0), zmm20)

        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row9
              zmm8  --> row10
              zmm10 --> row11
              zmm12 --> row12
              zmm14 --> row13
              zmm16 --> row14
              zmm18 --> row15
              zmm20 --> row16

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8
                     (rows 9-16 of each packed column)
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)

        /* Store the 8 transposed columns */
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        vmovups(zmm2, mem(rbx, 2*192 + 64))
        vmovups(zmm6, mem(rbx, 3*192 + 64))
        vmovups(zmm1, mem(rbx, 4*192 + 64))
        vmovups(zmm5, mem(rbx, 5*192 + 64))
        vmovups(zmm3, mem(rbx, 6*192 + 64))
        vmovups(zmm8, mem(rbx, 7*192 + 64))
        add(r14, rax)                // a += 8*inca;

        /* Load the final 8 rows of the matrix */
        vmovups(mem(rax,         0), zmm6)
        vmovups(mem(rax, r8,  1, 0), zmm8)
        vmovups(mem(rax, r8,  2, 0), zmm10)
        vmovups(mem(rax, r12, 1, 0), zmm12)
        vmovups(mem(rax, r8,  4, 0), zmm14)
        vmovups(mem(rax, rcx, 1, 0), zmm16)
        vmovups(mem(rax, r12, 2, 0), zmm18)
        vmovups(mem(rax, rdx, 1, 0), zmm20)

        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row17
              zmm8  --> row18
              zmm10 --> row19
              zmm12 --> row20
              zmm14 --> row21
              zmm16 --> row22
              zmm18 --> row23
              zmm20 --> row24

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8
                     (rows 17-24 of each packed column)
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)

        /* Store the 8 transposed columns */
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        vmovups(zmm2, mem(rbx, 2*192 + 128))
        vmovups(zmm6, mem(rbx, 3*192 + 128))
        vmovups(zmm1, mem(rbx, 4*192 + 128))
        vmovups(zmm5, mem(rbx, 5*192 + 128))
        vmovups(zmm3, mem(rbx, 6*192 + 128))
        vmovups(zmm8, mem(rbx, 7*192 + 128))
        add(imm(8*8), r13)           // a (saved copy) += 8 cols * 8 bytes (lda == 1 here)
        mov(r13, rax)
        add(imm(8*8*24), rbx)        // p += 8*ldp (8 cols * 24 * 8 bytes)

        dec(rsi)                     // i -= 1;
        jne(.CKITERROWU)             // iterate again if i != 0.

        label(.CCONKLEFTROWU)

        mov(var(k_left), rsi)        // i = k_left;
        test(rsi, rsi)               // check i via logical AND.
        je(.CDONE)                   // if i == 0, we're done; jump to end.
                                     // else, we prepare to enter the k_left loop.

        label(.CKLEFTROWU)           // EDGE LOOP (k_left)
        LABEL(.UPDATEL1)
        /* Move the first 8 x k_left block of data */
        vmovups(mem(rax,         0), zmm6  MASK_KZ(2))
        vmovups(mem(rax, r8,  1, 0), zmm8  MASK_KZ(2))
        vmovups(mem(rax, r8,  2, 0), zmm10 MASK_KZ(2))
        vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
        vmovups(mem(rax, r8,  4, 0), zmm14 MASK_KZ(2))
        vmovups(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2))
        vmovups(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2))
        vmovups(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2))

        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row1 (masked load)
              zmm8  --> row2 (masked load)
              zmm10 --> row3 (masked load)
              zmm12 --> row4 (masked load)
              zmm14 --> row5 (masked load)
              zmm16 --> row6 (masked load)
              zmm18 --> row7 (masked load)
              zmm20 --> row8 (masked load)

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)

        add(r14, rax)                // a += 8*inca
        cmp(imm(7), rsi)
        JZ(.UPDATE7L1)
        cmp(imm(6), rsi)
        JZ(.UPDATE6L1)
        cmp(imm(5), rsi)
        JZ(.UPDATE5L1)
        cmp(imm(4), rsi)
        JZ(.UPDATE4L1)
        cmp(imm(3), rsi)
        JZ(.UPDATE3L1)
        cmp(imm(2), rsi)
        JZ(.UPDATE2L1)
        cmp(imm(1), rsi)
        JZ(.UPDATE1L1)
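        /* Editorial note: k_left is in [1, 7] here, so exactly one of the
           comparisons above matches. Each .UPDATEnL1 block stores only the
           n packed columns that exist, then jumps to .UPDATEL2 to handle
           the next 8 rows of the 24-row panel. */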
        LABEL(.UPDATE7L1)
        // Update 8x7 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        vmovups(zmm2, mem(rbx, 2*192))
        vmovups(zmm6, mem(rbx, 3*192))
        vmovups(zmm1, mem(rbx, 4*192))
        vmovups(zmm5, mem(rbx, 5*192))
        vmovups(zmm3, mem(rbx, 6*192))
        jmp(.UPDATEL2)

        LABEL(.UPDATE6L1)
        // Update 8x6 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        vmovups(zmm2, mem(rbx, 2*192))
        vmovups(zmm6, mem(rbx, 3*192))
        vmovups(zmm1, mem(rbx, 4*192))
        vmovups(zmm5, mem(rbx, 5*192))
        jmp(.UPDATEL2)

        LABEL(.UPDATE5L1)
        // Update 8x5 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        vmovups(zmm2, mem(rbx, 2*192))
        vmovups(zmm6, mem(rbx, 3*192))
        vmovups(zmm1, mem(rbx, 4*192))
        jmp(.UPDATEL2)

        LABEL(.UPDATE4L1)
        // Update 8x4 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        vmovups(zmm2, mem(rbx, 2*192))
        vmovups(zmm6, mem(rbx, 3*192))
        jmp(.UPDATEL2)

        LABEL(.UPDATE3L1)
        // Update 8x3 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        vmovups(zmm2, mem(rbx, 2*192))
        jmp(.UPDATEL2)

        LABEL(.UPDATE2L1)
        // Update 8x2 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        vmovups(zmm4, mem(rbx, 1*192))
        jmp(.UPDATEL2)

        LABEL(.UPDATE1L1)
        // Update 8x1 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192))
        jmp(.UPDATEL2)
        LABEL(.UPDATEL2)
        /* Move the next 8 x k_left block of data */
        vmovups(mem(rax,         0), zmm6  MASK_KZ(2))
        vmovups(mem(rax, r8,  1, 0), zmm8  MASK_KZ(2))
        vmovups(mem(rax, r8,  2, 0), zmm10 MASK_KZ(2))
        vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
        vmovups(mem(rax, r8,  4, 0), zmm14 MASK_KZ(2))
        vmovups(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2))
        vmovups(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2))
        vmovups(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2))

        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row9  (masked load)
              zmm8  --> row10 (masked load)
              zmm10 --> row11 (masked load)
              zmm12 --> row12 (masked load)
              zmm14 --> row13 (masked load)
              zmm16 --> row14 (masked load)
              zmm18 --> row15 (masked load)
              zmm20 --> row16 (masked load)

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8
                     (rows 9-16 of each packed column)
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)

        add(r14, rax)                // a += 8*inca
        cmp(imm(7), rsi)
        JZ(.UPDATE7L2)
        cmp(imm(6), rsi)
        JZ(.UPDATE6L2)
        cmp(imm(5), rsi)
        JZ(.UPDATE5L2)
        cmp(imm(4), rsi)
        JZ(.UPDATE4L2)
        cmp(imm(3), rsi)
        JZ(.UPDATE3L2)
        cmp(imm(2), rsi)
        JZ(.UPDATE2L2)
        cmp(imm(1), rsi)
        JZ(.UPDATE1L2)
        LABEL(.UPDATE7L2)
        // Update 8x7 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        vmovups(zmm2, mem(rbx, 2*192 + 64))
        vmovups(zmm6, mem(rbx, 3*192 + 64))
        vmovups(zmm1, mem(rbx, 4*192 + 64))
        vmovups(zmm5, mem(rbx, 5*192 + 64))
        vmovups(zmm3, mem(rbx, 6*192 + 64))
        jmp(.UPDATEL3)

        LABEL(.UPDATE6L2)
        // Update 8x6 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        vmovups(zmm2, mem(rbx, 2*192 + 64))
        vmovups(zmm6, mem(rbx, 3*192 + 64))
        vmovups(zmm1, mem(rbx, 4*192 + 64))
        vmovups(zmm5, mem(rbx, 5*192 + 64))
        jmp(.UPDATEL3)

        LABEL(.UPDATE5L2)
        // Update 8x5 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        vmovups(zmm2, mem(rbx, 2*192 + 64))
        vmovups(zmm6, mem(rbx, 3*192 + 64))
        vmovups(zmm1, mem(rbx, 4*192 + 64))
        jmp(.UPDATEL3)

        LABEL(.UPDATE4L2)
        // Update 8x4 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        vmovups(zmm2, mem(rbx, 2*192 + 64))
        vmovups(zmm6, mem(rbx, 3*192 + 64))
        jmp(.UPDATEL3)

        LABEL(.UPDATE3L2)
        // Update 8x3 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        vmovups(zmm2, mem(rbx, 2*192 + 64))
        jmp(.UPDATEL3)

        LABEL(.UPDATE2L2)
        // Update 8x2 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        vmovups(zmm4, mem(rbx, 1*192 + 64))
        jmp(.UPDATEL3)

        LABEL(.UPDATE1L2)
        // Update 8x1 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 64))
        jmp(.UPDATEL3)
        LABEL(.UPDATEL3)
        /* Move the final 8 x k_left block of data */
        vmovups(mem(rax,         0), zmm6  MASK_KZ(2))
        vmovups(mem(rax, r8,  1, 0), zmm8  MASK_KZ(2))
        vmovups(mem(rax, r8,  2, 0), zmm10 MASK_KZ(2))
        vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
        vmovups(mem(rax, r8,  4, 0), zmm14 MASK_KZ(2))
        vmovups(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2))
        vmovups(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2))
        vmovups(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2))

        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row17 (masked load)
              zmm8  --> row18 (masked load)
              zmm10 --> row19 (masked load)
              zmm12 --> row20 (masked load)
              zmm14 --> row21 (masked load)
              zmm16 --> row22 (masked load)
              zmm18 --> row23 (masked load)
              zmm20 --> row24 (masked load)

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8
                     (rows 17-24 of each packed column)
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
        cmp(imm(7), rsi)
        JZ(.UPDATE7L3)
        cmp(imm(6), rsi)
        JZ(.UPDATE6L3)
        cmp(imm(5), rsi)
        JZ(.UPDATE5L3)
        cmp(imm(4), rsi)
        JZ(.UPDATE4L3)
        cmp(imm(3), rsi)
        JZ(.UPDATE3L3)
        cmp(imm(2), rsi)
        JZ(.UPDATE2L3)
        cmp(imm(1), rsi)
        JZ(.UPDATE1L3)
        LABEL(.UPDATE7L3)
        // Update 8x7 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        vmovups(zmm2, mem(rbx, 2*192 + 128))
        vmovups(zmm6, mem(rbx, 3*192 + 128))
        vmovups(zmm1, mem(rbx, 4*192 + 128))
        vmovups(zmm5, mem(rbx, 5*192 + 128))
        vmovups(zmm3, mem(rbx, 6*192 + 128))
        jmp(.CDONE)

        LABEL(.UPDATE6L3)
        // Update 8x6 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        vmovups(zmm2, mem(rbx, 2*192 + 128))
        vmovups(zmm6, mem(rbx, 3*192 + 128))
        vmovups(zmm1, mem(rbx, 4*192 + 128))
        vmovups(zmm5, mem(rbx, 5*192 + 128))
        jmp(.CDONE)

        LABEL(.UPDATE5L3)
        // Update 8x5 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        vmovups(zmm2, mem(rbx, 2*192 + 128))
        vmovups(zmm6, mem(rbx, 3*192 + 128))
        vmovups(zmm1, mem(rbx, 4*192 + 128))
        jmp(.CDONE)

        LABEL(.UPDATE4L3)
        // Update 8x4 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        vmovups(zmm2, mem(rbx, 2*192 + 128))
        vmovups(zmm6, mem(rbx, 3*192 + 128))
        jmp(.CDONE)

        LABEL(.UPDATE3L3)
        // Update 8x3 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        vmovups(zmm2, mem(rbx, 2*192 + 128))
        jmp(.CDONE)

        LABEL(.UPDATE2L3)
        // Update 8x2 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        vmovups(zmm4, mem(rbx, 1*192 + 128))
        jmp(.CDONE)

        LABEL(.UPDATE1L3)
        // Update 8x1 tile to destination buffer
        vmovups(zmm0, mem(rbx, 0*192 + 128))
        jmp(.CDONE)
        // -- column storage on A --------------------------------------------

        label(.CCOLUNIT)

        mov(var(ldp), r8)            // load ldp
        lea(mem(, r8, 8), r8)        // r8 *= sizeof(scomplex)
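        /* Editorial note: in the column-stored case (inca == 1) each source
           column of 24 scomplex elements is already contiguous (192 bytes),
           so packing is a straight streaming copy of three ZMM registers
           per column, unrolled 8x below to match k_iter. */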
        mov(var(k_iter), rsi)        // i = k_iter;
        test(rsi, rsi)               // check i via logical AND.
        je(.CCONKLEFTCOLU)           // if i == 0, jump to code that
                                     // contains the k_left loop.

        label(.CKITERCOLU)           // MAIN LOOP (k_iter)

        /* Load and store 3 ZMM registers (24 elements) per column */
        /* Unroll-1 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-2 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-3 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-4 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-5 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-6 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-7 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-8 */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))
        add(r10, rax)
        add(r8, rbx)
        dec(rsi)                     // i -= 1;
        jne(.CKITERCOLU)             // iterate again if i != 0.

        label(.CCONKLEFTCOLU)

        mov(var(k_left), rsi)        // i = k_left;
        test(rsi, rsi)               // check i via logical AND.
        je(.CDONE)                   // if i == 0, we're done; jump to end.
                                     // else, we prepare to enter the k_left loop.

        label(.CKLEFTCOLU)           // EDGE LOOP (k_left)

        /* Load and store 3 ZMM registers (24 elements) */
        vmovups(mem(rax), zmm6)
        vmovups(mem(rax, 64), zmm8)
        vmovups(mem(rax, 128), zmm10)
        vmovups(zmm6, mem(rbx, 0))
        vmovups(zmm8, mem(rbx, 64))
        vmovups(zmm10, mem(rbx, 128))

        add(r10, rax)
        add(r8, rbx)

        dec(rsi)                     // i -= 1;
        jne(.CKLEFTCOLU)             // iterate again if i != 0.

        label(.CDONE)
        end_asm(
          : // output operands (none)
          : // input operands
            [mask]   "m" (mask),
            [k_iter] "m" (k_iter),
            [k_left] "m" (k_left),
            [a]      "m" (a),
            [inca]   "m" (inca),
            [lda]    "m" (lda),
            [p]      "m" (p),
            [ldp]    "m" (ldp)
          : // register clobber list
            "rax", "rbx", "rcx", "rdx", "rsi",
            "r8", "r10", "r12", "r13", "r14",
            "xmm0", "xmm1", "xmm2", "xmm3",
            "zmm0", "zmm1", "zmm2", "zmm3",
            "zmm4", "zmm5", "zmm6", "zmm7",
            "zmm8", "zmm10", "zmm12", "zmm14",
            "zmm16", "zmm18", "zmm20", "zmm30",
            "zmm31", "k2", "memory"
        )
    }
    else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk )
    {
        PASTEMAC(cscal2m,BLIS_TAPI_EX_SUF)
        (
          0,
          BLIS_NONUNIT_DIAG,
          BLIS_DENSE,
          ( trans_t )conja,
          cdim0,
          k0,
          kappa,
          a, inca0, lda0,
          p, 1,     ldp0,
          cntx,
          NULL
        );

        if ( cdim0 < mnr )
        {
            // Handle zero-filling along the "long" edge of the micropanel.

            const dim_t i      = cdim0;
            const dim_t m_edge = mnr - cdim0;
            const dim_t n_edge = k0_max;
            scomplex* restrict p_edge = p + (i  )*1;

            bli_cset0s_mxn
            (
              m_edge,
              n_edge,
              p_edge, 1, ldp
            );
        }
    }

    if ( k0 < k0_max )
    {
        // Handle zero-filling along the "short" (far) edge of the micropanel.

        const dim_t j      = k0;
        const dim_t m_edge = mnr;
        const dim_t n_edge = k0_max - k0;
        scomplex* restrict p_edge = p + (j  )*ldp;

        bli_cset0s_mxn
        (
          m_edge,
          n_edge,
          p_edge, 1, ldp
        );
    }
}
kernels/zen4/1m/bli_packm_zen4_asm_c4xk.c (new file, 498 lines)
@@ -0,0 +1,498 @@
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"

#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
/**
 * Shuffle pairs of scomplex elements (128-bit lanes) selected by imm8
 * from S1 and S2 (and from S3 and S4), and store the results in D1 and
 * D2 (and D3 and D4).
 * S1 : 1 9 3 11 5 13 7 15
 * S2 : 2 10 4 12 6 14 8 16
 * D1 : 1 9 5 13 2 10 6 14
 * D2 : 3 11 7 15 4 12 8 16
 */
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
    VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
    VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
    VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
    VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4))

/**
 * Unpacks and interleaves the low halves and high halves of each
 * 128-bit lane in S1 and S2 (and in S3 and S4), storing the results
 * in D1 and D2 (and D3 and D4), respectively.
 * S1 : 1 2 3 4 5 6 7 8
 * S2 : 9 10 11 12 13 14 15 16
 * D1 : 1 9 3 11 5 13 7 15
 * D2 : 2 10 4 12 6 14 8 16
 */
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
    vunpcklpd(zmm(S1), zmm(S2), zmm(D1)) \
    vunpckhpd(zmm(S1), zmm(S2), zmm(D2)) \
    vunpcklpd(zmm(S3), zmm(S4), zmm(D3)) \
    vunpckhpd(zmm(S3), zmm(S4), zmm(D4))
void bli_cpackm_zen4_asm_4xk
     (
       conj_t             conja,
       pack_t             schema,
       dim_t              cdim0,
       dim_t              k0,
       dim_t              k0_max,
       scomplex* restrict kappa,
       scomplex* restrict a, inc_t inca0, inc_t lda0,
       scomplex* restrict p,              inc_t ldp0,
       cntx_t*   restrict cntx
     )
{
    // This is the panel dimension assumed by the packm kernel.
    const dim_t mnr = 4;

    // Typecast local copies of integers in case dim_t and inc_t are a
    // different size than is expected by load instructions.
    // NOTE: k_iter is in blocks of 8, due to the SIMD width of scomplex
    // in a 512-bit register. This way, we can still perform AVX512 loads
    // and stores in case the matrix is in row-major format.
    const uint64_t k_iter = k0 / 8;
    const uint64_t k_left = k0 % 8;

    /**
     * Prepare the mask for k_left, since we compute in blocks of 8.
     * For the edge cases, the mask is set so that loads and stores touch
     * only the leftover elements.
     */
    uint16_t one  = 1;
    uint16_t mask = ( one << ( 2 * k_left ) ) - one;
    // NOTE: For the purposes of the comments in this packm kernel, we
    // interpret inca and lda as rs_a and cs_a, respectively, and similarly
    // interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading
    // this packm kernel, you should think of the operation as packing an
    // m x n micropanel, where m and n are tiny and large, respectively, and
    // where elements of each column of the packed matrix P are contiguous.
    // (This packm kernel can still be used to pack micropanels of matrix B
    // in a gemm operation.)
    const uint64_t inca = inca0;
    const uint64_t lda  = lda0;
    const uint64_t ldp  = ldp0;

    const bool gs = ( inca0 != 1 && lda0 != 1 );

    // NOTE: If/when this kernel ever supports scaling by kappa within the
    // assembly region, this constraint should be lifted.
    const bool unitk = bli_ceq1( *kappa );

    // -------------------------------------------------------------------------

    if ( cdim0 == mnr && !gs && !conja && unitk )
    {
        begin_asm()
        mov(var(mask), rdx)          // load mask
        kmovw(edx, k(2))             // move mask to k2 register

        mov(var(a), rax)             // load address of source buffer
        mov(var(a), r13)             // load address of source buffer

        mov(var(inca), r8)           // load inca
        mov(var(lda), r10)           // load lda
        lea(mem( , r8,  8), r8)      // inca *= sizeof(scomplex)
        lea(mem( , r10, 8), r10)     // lda *= sizeof(scomplex)

        mov(var(p), rbx)             // load address of p.

        lea(mem( , r8, 8), r14)      // r14 = 8*inca

        cmp(imm(8), r8)              // set ZF if (8*inca) == 8.
        jz(.CCOLUNIT)                // jump to column storage case

        // -- kappa unit, row storage on A -----------------------------------

        label(.CROWUNIT)

        lea(mem(r8, r8, 2), r12)     // r12 = 3*inca

        mov(var(k_iter), rsi)        // i = k_iter;
        test(rsi, rsi)               // check i via logical AND.
        je(.CCONKLEFTROWU)           // if i == 0, jump to code that
                                     // contains the k_left loop.
        label(.CKITERROWU)           // MAIN LOOP (k_iter)
        /**
         * Load the first 4 rows of the matrix.
         * Set 4 additional registers to zero.
         * Transpose the resulting 8x8 tile (the 4x8 tile extended with
         * 0.0 padding) and store it back to the destination buffer.
         */
        vmovups(mem(rax,         0), zmm6)
        vmovups(mem(rax, r8,  1, 0), zmm8)
        vmovups(mem(rax, r8,  2, 0), zmm10)
        vmovups(mem(rax, r12, 1, 0), zmm12)
        vxorps(zmm14, zmm14, zmm14)
        vxorps(zmm16, zmm16, zmm16)
        vxorps(zmm18, zmm18, zmm18)
        vxorps(zmm20, zmm20, zmm20)
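        /* Editorial note: zeroing zmm14-zmm20 lets this kernel reuse the
           same 8x8 UNPACK/SHUFFLE transpose network as the 24xk kernel;
           the four zero "rows" simply land in the upper halves of the
           transposed registers, which are never stored. */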
        /* Transpose the 8x8 matrix onto another set of 8 registers */
        /*
            Input :
              zmm6  --> row1
              zmm8  --> row2
              zmm10 --> row3
              zmm12 --> row4
              zmm14 --> 0.0f
              zmm16 --> 0.0f
              zmm18 --> 0.0f
              zmm20 --> 0.0f

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8

            Every column (register) will have its upper 256-bit lane equal
            to 0.0f. Thus, we only store the YMM registers to the
            destination buffer.
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)

        /* Store the 256-bit lanes (YMM) onto the destination */
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        vmovups(ymm2, mem(rbx, 2*32))
        vmovups(ymm6, mem(rbx, 3*32))
        vmovups(ymm1, mem(rbx, 4*32))
        vmovups(ymm5, mem(rbx, 5*32))
        vmovups(ymm3, mem(rbx, 6*32))
        vmovups(ymm8, mem(rbx, 7*32))
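        /* Editorial note: the packed panel here is column-major with
           ldp = 4 scomplex elements, i.e. 4 * 8 = 32 bytes per packed
           column, hence the n*32 store offsets above. */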
        add(imm(8*8), r13)           // a (saved copy) += 8 cols * 8 bytes (lda == 1 here)
        mov(r13, rax)
        add(imm(8*8*4), rbx)         // p += 8*ldp (8 cols * 4 * 8 bytes)

        dec(rsi)                     // i -= 1;
        jne(.CKITERROWU)             // iterate again if i != 0.

        label(.CCONKLEFTROWU)

        mov(var(k_left), rsi)        // i = k_left;
        test(rsi, rsi)               // check i via logical AND.
        je(.CDONE)                   // if i == 0, we're done; jump to end.
                                     // else, we prepare to enter the k_left loop.

        label(.CKLEFTROWU)           // EDGE LOOP (k_left)
        LABEL(.UPDATEL1)
        /* Move the first 4 x k_left block of data */
        vmovups(mem(rax,         0), zmm6  MASK_KZ(2))
        vmovups(mem(rax, r8,  1, 0), zmm8  MASK_KZ(2))
        vmovups(mem(rax, r8,  2, 0), zmm10 MASK_KZ(2))
        vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
        vxorps(zmm14, zmm14, zmm14)
        vxorps(zmm16, zmm16, zmm16)
        vxorps(zmm18, zmm18, zmm18)
        vxorps(zmm20, zmm20, zmm20)

        /*
            Input :
              zmm6  --> row1 (masked load)
              zmm8  --> row2 (masked load)
              zmm10 --> row3 (masked load)
              zmm12 --> row4 (masked load)
              zmm14 --> 0.0f
              zmm16 --> 0.0f
              zmm18 --> 0.0f
              zmm20 --> 0.0f

            Output (after transpose) :
              zmm0  zmm4  zmm2  zmm6  zmm1  zmm5  zmm3  zmm8
               |     |     |     |     |     |     |     |
               V     V     V     V     V     V     V     V
              col1  col2  col3  col4  col5  col6  col7  col8

            Every column (register) will have its upper 256-bit lane equal
            to 0.0f. Thus, we only store the YMM registers to the
            destination buffer.
        */
        UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
        SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
        UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
        SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
        SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
        SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
        cmp(imm(7), rsi)
        JZ(.UPDATE7L1)
        cmp(imm(6), rsi)
        JZ(.UPDATE6L1)
        cmp(imm(5), rsi)
        JZ(.UPDATE5L1)
        cmp(imm(4), rsi)
        JZ(.UPDATE4L1)
        cmp(imm(3), rsi)
        JZ(.UPDATE3L1)
        cmp(imm(2), rsi)
        JZ(.UPDATE2L1)
        cmp(imm(1), rsi)
        JZ(.UPDATE1L1)
        LABEL(.UPDATE7L1)
        // Update 4x7 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        vmovups(ymm2, mem(rbx, 2*32))
        vmovups(ymm6, mem(rbx, 3*32))
        vmovups(ymm1, mem(rbx, 4*32))
        vmovups(ymm5, mem(rbx, 5*32))
        vmovups(ymm3, mem(rbx, 6*32))
        jmp(.CDONE)

        LABEL(.UPDATE6L1)
        // Update 4x6 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        vmovups(ymm2, mem(rbx, 2*32))
        vmovups(ymm6, mem(rbx, 3*32))
        vmovups(ymm1, mem(rbx, 4*32))
        vmovups(ymm5, mem(rbx, 5*32))
        jmp(.CDONE)

        LABEL(.UPDATE5L1)
        // Update 4x5 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        vmovups(ymm2, mem(rbx, 2*32))
        vmovups(ymm6, mem(rbx, 3*32))
        vmovups(ymm1, mem(rbx, 4*32))
        jmp(.CDONE)

        LABEL(.UPDATE4L1)
        // Update 4x4 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        vmovups(ymm2, mem(rbx, 2*32))
        vmovups(ymm6, mem(rbx, 3*32))
        jmp(.CDONE)

        LABEL(.UPDATE3L1)
        // Update 4x3 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        vmovups(ymm2, mem(rbx, 2*32))
        jmp(.CDONE)

        LABEL(.UPDATE2L1)
        // Update 4x2 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        vmovups(ymm4, mem(rbx, 1*32))
        jmp(.CDONE)

        LABEL(.UPDATE1L1)
        // Update 4x1 tile to destination buffer
        vmovups(ymm0, mem(rbx, 0*32))
        jmp(.CDONE)
        // -- column storage on A --------------------------------------------

        label(.CCOLUNIT)

        mov(var(ldp), r8)            // load ldp
        lea(mem(, r8, 8), r8)        // r8 *= sizeof(scomplex)

        mov(var(k_iter), rsi)        // i = k_iter;
        test(rsi, rsi)               // check i via logical AND.
        je(.CCONKLEFTCOLU)           // if i == 0, jump to code that
                                     // contains the k_left loop.

        label(.CKITERCOLU)           // MAIN LOOP (k_iter)
        /* Load/store a full 4-element column using a YMM register */
        /* Unroll-1 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-2 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-3 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-4 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-5 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-6 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-7 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        /* Unroll-8 */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)
        dec(rsi)                     // i -= 1;
        jne(.CKITERCOLU)             // iterate again if i != 0.

        label(.CCONKLEFTCOLU)

        mov(var(k_left), rsi)        // i = k_left;
        test(rsi, rsi)               // check i via logical AND.
        je(.CDONE)                   // if i == 0, we're done; jump to end.
                                     // else, we prepare to enter the k_left loop.

        label(.CKLEFTCOLU)           // EDGE LOOP (k_left)

        /* Load/store a 4-element column using a YMM register */
        vmovups(mem(rax), ymm6)
        vmovups(ymm6, mem(rbx))
        add(r10, rax)
        add(r8, rbx)

        dec(rsi)                     // i -= 1;
        jne(.CKLEFTCOLU)             // iterate again if i != 0.

        label(.CDONE)
        end_asm(
          : // output operands (none)
          : // input operands
            [mask]   "m" (mask),
            [k_iter] "m" (k_iter),
            [k_left] "m" (k_left),
            [a]      "m" (a),
            [inca]   "m" (inca),
            [lda]    "m" (lda),
            [p]      "m" (p),
            [ldp]    "m" (ldp)
          : // register clobber list
            "rax", "rbx", "rcx", "rdx", "rsi",
            "r8", "r10", "r12", "r13", "r14",
            "xmm0", "xmm1", "xmm2", "xmm3",
            "ymm0", "ymm1", "ymm2", "ymm3",
            "ymm4", "ymm5", "ymm6", "ymm8",
            "zmm0", "zmm1", "zmm2", "zmm3",
            "zmm4", "zmm5", "zmm6", "zmm7",
            "zmm8", "zmm10", "zmm12", "zmm14",
            "zmm16", "zmm18", "zmm20", "zmm30",
            "zmm31", "k2", "memory"
        )
    }
    else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk )
    {
        PASTEMAC(cscal2m,BLIS_TAPI_EX_SUF)
        (
          0,
          BLIS_NONUNIT_DIAG,
          BLIS_DENSE,
          ( trans_t )conja,
          cdim0,
          k0,
          kappa,
          a, inca0, lda0,
          p, 1,     ldp0,
          cntx,
          NULL
        );

        if ( cdim0 < mnr )
        {
            // Handle zero-filling along the "long" edge of the micropanel.

            const dim_t i      = cdim0;
            const dim_t m_edge = mnr - cdim0;
            const dim_t n_edge = k0_max;
            scomplex* restrict p_edge = p + (i  )*1;

            bli_cset0s_mxn
            (
              m_edge,
              n_edge,
              p_edge, 1, ldp
            );
        }
    }

    if ( k0 < k0_max )
    {
        // Handle zero-filling along the "short" (far) edge of the micropanel.

        const dim_t j      = k0;
        const dim_t m_edge = mnr;
        const dim_t n_edge = k0_max - k0;
        scomplex* restrict p_edge = p + (j  )*ldp;

        bli_cset0s_mxn
        (
          m_edge,
          n_edge,
          p_edge, 1, ldp
        );
    }
}
kernels/zen4/3/bli_cgemm_zen4_asm_24x4.c (new file, 799 lines)
@@ -0,0 +1,799 @@
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"
#include "bli_cgemm_zen4_asm_macros.h"

#define LOOP_ALIGN ALIGN32

/* Minimum number of iterations required for prefetching C */
#define TAIL_ITER 6
// This array is used to support the ADDSUB instruction.
static float fma_vec[16] __attribute__((aligned(64)))
    = { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f };

// This is an array used for the scatter/gather instructions.
static int64_t offsets[24] __attribute__((aligned(64))) =
    { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
      14, 15, 16, 17, 18, 19, 20, 21, 22, 23 };
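/*
 * Editorial note: AVX512 provides no 512-bit VADDSUBPS, so the kernel
 * emulates it with a fused multiply-add/subtract against this all-ones
 * vector: with b = 1.0f, fmaddsub(a, b, c) reduces to (a - c) in even
 * lanes and (a + c) in odd lanes -- exactly the alternating subtract/add
 * that complex multiplication needs. The offsets array presumably feeds
 * the index registers of gather/scatter accesses when C has a general
 * (non-unit) stride.
 */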
/*
    Register usage :
      ZMM(0)  - ZMM(2)  : Load A
      ZMM(28) - ZMM(31) : Broadcast B
      ZMM(3)  - ZMM(14) : Accumulate real scaling
      ZMM(15) - ZMM(26) : Accumulate imag scaling / Load C

    Total registers used : 31
*/
void bli_cgemm_zen4_asm_24x4(
      dim_t k0,
      scomplex *restrict alpha,
      scomplex *restrict a,
      scomplex *restrict b,
      scomplex *restrict beta,
      scomplex *restrict c, inc_t rs_c0, inc_t cs_c0,
      auxinfo_t *data,
      cntx_t *restrict cntx
    )
{
    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);

    /* Cast all the integers to the same type */
    const uint64_t k    = k0;
    uint64_t       rs_c = rs_c0 * 8;   // rs_c = rs_c0 * 8 (sizeof(scomplex))
    uint64_t       cs_c = cs_c0 * 8;   // cs_c = cs_c0 * 8 (sizeof(scomplex))

    /* Store the address of the fma_vec array, to be used in the computation */
    const float *fmaPtr = &fma_vec[0];
    /* Store the address of offsets, to be used for general-stride computation */
    const int64_t *offsetPtr = &offsets[0];

    /* Determine the alpha and beta multiplication types */
    /* This is done as part of optimizing the alpha and beta scaling */
    uint64_t alpha_mul_type = BLIS_MUL_DEFAULT;
    uint64_t beta_mul_type  = BLIS_MUL_DEFAULT;
    /* Set alpha_mul_type and beta_mul_type, based on special cases
       of alpha and beta. */
    if ( alpha->imag == 0.0 )
    {
        if ( alpha->real == 1.0 )
            alpha_mul_type = BLIS_MUL_ONE;
        else if ( alpha->real == -1.0 )
            alpha_mul_type = BLIS_MUL_MINUS_ONE;
    }

    if ( beta->imag == 0.0 )
    {
        if ( beta->real == 1.0 )
            beta_mul_type = BLIS_MUL_ONE;
        else if ( beta->real == -1.0 )
            beta_mul_type = BLIS_MUL_MINUS_ONE;
        else if ( beta->real == 0.0 )
            beta_mul_type = BLIS_MUL_ZERO;
    }
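    /*
     * Editorial sketch (example values only, not the actual BLIS enum
     * definitions): the classification above maps, e.g.,
     *   alpha =  1 + 0i  ->  BLIS_MUL_ONE        (skip alpha scaling)
     *   alpha = -1 + 0i  ->  BLIS_MUL_MINUS_ONE  (negate instead of multiply)
     *   alpha =  2 + 3i  ->  BLIS_MUL_DEFAULT    (full complex multiply)
     * so the assembly below branches once per call rather than performing
     * a general complex multiply in every case.
     */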
    /* Start of the assembly code section */
    BEGIN_ASM()

    /* Set the accumulation registers to zero */
    SET_ZERO()

    /* Load the value of k into RSI */
    MOV(VAR(k), RSI)

    /* Load the addresses of A, B and C */
    MOV(VAR(a), RAX)    // RAX = addr of A
    MOV(VAR(b), RBX)    // RBX = addr of B
    MOV(VAR(c), RCX)    // RCX = addr of C

    /* Load R9 with the address of C, to be used during prefetch */
    MOV(RCX, R9)

    /* Load the column stride of C, since this is a column-major kernel */
    /* NOTE : cs_c has already been scaled by the size of the datatype */
    MOV(VAR(cs_c), R10) // R10 = cs_c = 8 * cs_c0

    /* Unroll by a factor of 4 along the k-dimension */
    MOV(RSI, RDI)
    AND(IMM(3), RSI)    // RSI = k % 4 (k_fringe)
    SAR(IMM(2), RDI)    // RDI = k / 4 (k_iter)

    /* The k-loop is divided into 4 parts, to have a fixed distance for prefetching C */
    /* The k-loop is divided as follows :
       1. .CK_BEFORE_PREFETCH : k/4 - 4 - TAIL_ITER iterations (before C prefetch)
       2. .CK_PREFETCH        : 4 iterations (prefetches C)
       3. .CK_AFTER_PREFETCH  : TAIL_ITER iterations (after C prefetch; the prefetch distance)
       4. .CK_FRINGE          : k % 4 iterations
    */
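    /*
     * Editorial worked example: for k = 100, k_iter = 25 and k_fringe = 0.
     * .CK_BEFORE_PREFETCH then runs 25 - 4 - 6 = 15 iterations,
     * .CK_PREFETCH runs 4 (touching one 192-byte column of C per
     * iteration), and .CK_AFTER_PREFETCH runs the remaining 6, so the
     * prefetches land a fixed TAIL_ITER = 6 unrolled iterations ahead of
     * the C updates.
     */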
    LABEL(.CK_BEFORE_PREFETCH)
    /* Check for entering the k-loop (before prefetch) */
    SUB(IMM(4 + TAIL_ITER), RDI)
    /* Jump to the k-loop (prefetch) if k/4 <= 4 + TAIL_ITER */
    JLE(.CK_PREFETCH)

    /* k-loop (unrolled) to perform A*B */
    LOOP_ALIGN
    LABEL(.CK_ITER_BP) // Unrolled iteration (B)efore (P)refetching

    /* Perform the rank-1 update 4 times (based on the unroll factor) */
    SUB_ITER_24x4(0, RAX, RBX)
    SUB_ITER_24x4(1, RAX, RBX)
    SUB(IMM(1), RDI)   // RDI (iterator) -= 1
    SUB_ITER_24x4(2, RAX, RBX)
    SUB_ITER_24x4(3, RAX, RBX)

    /* Adjust the addresses of A and B for the next set */
    LEA(MEM(RAX, 4 * 24 * 8), RAX) // RAX = RAX + 4 * MR * 8 (loads)
    LEA(MEM(RBX, 4 * 4 * 8), RBX)  // RBX = RBX + 4 * NR * 8 (broadcasts)

    /* Loop back if RDI != 0 */
    JNZ(.CK_ITER_BP)
    LABEL(.CK_PREFETCH)
    /* Check for entering the k-loop (for C prefetch) */
    ADD(IMM(4), RDI)
    /* Jump to the k-loop (after prefetch) if k/4 <= TAIL_ITER */
    JLE(.CK_AFTER_PREFETCH)

    /* k-loop (unrolled) to perform A*B */
    LOOP_ALIGN
    LABEL(.CK_ITER_P) // Unrolled iteration with (P)refetching

    /* Perform the rank-1 update 4 times (based on the unroll factor) */
    /* Also prefetch C */
    PREFETCHW0(MEM(R9))
    SUB_ITER_24x4(0, RAX, RBX)
    PREFETCHW0(MEM(R9, 64))
    SUB_ITER_24x4(1, RAX, RBX)
    PREFETCHW0(MEM(R9, 128))
    SUB(IMM(1), RDI)   // RDI (iterator) -= 1
    SUB_ITER_24x4(2, RAX, RBX)
    SUB_ITER_24x4(3, RAX, RBX)

    /* Adjust the addresses of A and B for the next set */
    LEA(MEM(RAX, 4 * 24 * 8), RAX) // RAX = RAX + 4 * MR * 8 (loads)
    LEA(MEM(RBX, 4 * 4 * 8), RBX)  // RBX = RBX + 4 * NR * 8 (broadcasts)
    LEA(MEM(R9, R10, 1), R9)       // R9 = R9 + cs_c

    /* Loop back if RDI != 0 */
    JNZ(.CK_ITER_P)
    LABEL(.CK_AFTER_PREFETCH)
    /* Check for entering the k-loop (after C prefetch) */
    ADD(IMM(0 + TAIL_ITER), RDI)
    /* Jump to the k-fringe loop if k/4 <= 0 */
    JLE(.CK_FRINGE)

    /* k-loop (unrolled) to perform A*B */
    LOOP_ALIGN
    LABEL(.CK_ITER_AP) // Unrolled iteration (A)fter (P)refetching

    /* Perform the rank-1 update 4 times (based on the unroll factor) */
    SUB_ITER_24x4(0, RAX, RBX)
    SUB_ITER_24x4(1, RAX, RBX)
    SUB(IMM(1), RDI)   // RDI (iterator) -= 1
    SUB_ITER_24x4(2, RAX, RBX)
    SUB_ITER_24x4(3, RAX, RBX)

    /* Adjust the addresses of A and B for the next set */
    LEA(MEM(RAX, 4 * 24 * 8), RAX) // RAX = RAX + 4 * MR * 8 (loads)
    LEA(MEM(RBX, 4 * 4 * 8), RBX)  // RBX = RBX + 4 * NR * 8 (broadcasts)

    /* Loop back if RDI != 0 */
    JNZ(.CK_ITER_AP)
    LABEL(.CK_FRINGE)
    /* Check for entering the k-loop (fringe) */
    TEST(RSI, RSI)
    JE(.POSTACCUM)

    /* k-loop (fringe) to perform A*B */
    LOOP_ALIGN
    LABEL(.CK_ITER_FRINGE)

    /* Perform the rank-1 update */
    SUB_ITER_24x4(0, RAX, RBX)
    SUB(IMM(1), RSI)   // RSI (iterator) -= 1

    /* Adjust the addresses of A and B for the next set */
    /* new_addr = old_addr + ( unroll_factor * {MR or NR} * size_of_type ) */
    LEA(MEM(RAX, 24 * 8), RAX) // RAX = RAX + MR * 8 (loads)
    LEA(MEM(RBX, 4 * 8), RBX)  // RBX = RBX + NR * 8 (broadcasts)

    /* Loop until RSI becomes 0 */
    JNZ(.CK_ITER_FRINGE)
    LABEL(.POSTACCUM)

    /* The registers ZMM(15) to ZMM(26) contain the FMA results computed
       using the imaginary components of elements of B. We shuffle their
       even and odd indices:
         SRC: ZMM(15) = ( Ar0*Bi0, Ai0*Bi0, Ar1*Bi0, Ai1*Bi0, ... )
         DST: ZMM(15) = ( Ai0*Bi0, Ar0*Bi0, Ai1*Bi0, Ar1*Bi0, ... )
       Similarly for the other registers.
    */
    PERMUTE(15, 16, 17)
    PERMUTE(18, 19, 20)
    PERMUTE(21, 22, 23)
    PERMUTE(24, 25, 26)

    /* Load ZMM(0) with 1.0f, for the reduction */
    MOV(VAR(fmaPtr), R14)
    VMOVAPS(MEM(R14), ZMM(0))

    /* Reduce the result using the real/imag accumulators, per complex arithmetic:
         SRC: ZMM(3)  = ( Ar0*Br0, Ai0*Br0, Ar1*Br0, Ai1*Br0, ... )
              ZMM(15) = ( Ai0*Bi0, Ar0*Bi0, Ai1*Bi0, Ar1*Bi0, ... )
         DST: ZMM(3)  = ( Ar0*Br0 - Ai0*Bi0, Ai0*Br0 + Ar0*Bi0, ... )
       Similarly done for the other registers.
    */
    FMADDSUB(3, 15, 4, 16, 5, 17)
    FMADDSUB(6, 18, 7, 19, 8, 20)
    FMADDSUB(9, 21, 10, 22, 11, 23)
    FMADDSUB(12, 24, 13, 25, 14, 26)
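    /*
     * Editorial worked example: for a = 1+2i and b = 3+4i, the real-side
     * accumulator holds ( Ar*Br, Ai*Br ) = ( 3, 6 ) and the permuted
     * imag-side accumulator holds ( Ai*Bi, Ar*Bi ) = ( 8, 4 ); the
     * alternating subtract/add then yields ( 3-8, 6+4 ) = ( -5, 10 ),
     * i.e. exactly (1+2i)*(3+4i) = -5+10i.
     */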
    /*
        The result of A*B (micro-tile) is a 24x4 matrix (column-major), as follows :

                      Column-1  Column-2  Column-3  Column-4
        Rows(1-8)     ZMM(3)    ZMM(6)    ZMM(9)    ZMM(12)
        Rows(9-16)    ZMM(4)    ZMM(7)    ZMM(10)   ZMM(13)
        Rows(17-24)   ZMM(5)    ZMM(8)    ZMM(11)   ZMM(14)
    */
    LABEL(.ALPHA_SCALING)
    /*
        Check alpha_mul_type, to jump to the required code section.
        Intermediate result (IR) = alpha*(A*B)
        If alpha == ( 1.0, 0.0 )       => BLIS_MUL_ONE
            IR = A*B
        else if alpha != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
            IR = alpha*(A*B), using complex multiplication
        else                           => BLIS_MUL_MINUS_ONE
            IR = 0.0 - A*B, using subtraction
    */
    MOV(VAR(alpha_mul_type), R14)
    CMP(IMM(1), R14)   // Check if alpha == 1.0
    /* Skip alpha scaling and jump to beta scaling */
    JE(.BETA_SCALING)

    CMP(IMM(2), R14)   // Check for the general case (alpha != +/-1.0)
    /* Jump to the general case of alpha scaling */
    JE(.ALPHA_GENERAL)

    /* Alpha scaling when alpha == -1.0 */
    LABEL(.ALPHA_MINUS_ONE)
    /* Set ZMM(1) to 0.0f, and subtract the registers from ZMM(1) */
    VXORPS(ZMM(1), ZMM(1), ZMM(1))
    /* ZMM(3) = ZMM(1) - ZMM(3) = 0.0f - A*B
       Similarly done for the other registers */
    ALPHA_MINUS_ONE(3, 1, 4, 1, 5, 1)
    ALPHA_MINUS_ONE(6, 1, 7, 1, 8, 1)
    ALPHA_MINUS_ONE(9, 1, 10, 1, 11, 1)
    ALPHA_MINUS_ONE(12, 1, 13, 1, 14, 1)

    /* Jump to beta scaling */
    JMP(.BETA_SCALING)
/* Alpha scaling when alpha != 1.0 and alpha != -1.0 */
|
||||
LABEL(.ALPHA_GENERAL)
|
||||
/* Load alpha onto a ZMM register */
|
||||
MOV(VAR(alpha), RAX)
|
||||
/* Broadcast the real and imag components of alpha onto the registers */
|
||||
VBROADCASTSS(MEM(RAX, 0), ZMM(1))
|
||||
VBROADCASTSS(MEM(RAX, 4), ZMM(2))
|
||||
|
||||
/* Scale the result of A*B with alpha */
|
||||
/* ZMM(15) = alphai * ZMM(3)
|
||||
ZMM(3) = alphar * ZMM(3)
|
||||
ZMM(3) = fmaddsub(ZMM(3), permute(ZMM(15)))
|
||||
|
||||
Similarly done for other pairs of registers */
|
||||
ALPHA_DEFAULT(3, 15, 4, 16, 5, 17)
|
||||
ALPHA_DEFAULT(6, 18, 7, 19, 8, 20)
|
||||
ALPHA_DEFAULT(9, 21, 10, 22, 11, 23)
|
||||
ALPHA_DEFAULT(12, 24, 13, 25, 14, 26)
|
||||
|
||||
/* Perform beta scaling */
|
||||
LABEL(.BETA_SCALING)
|
||||
/* Load the row and column strides of C */
|
||||
MOV(VAR(rs_c), RDI)
|
||||
MOV(VAR(cs_c), RSI)
|
||||
|
||||
CMP(IMM(8), RDI) // Check if C is column stored
|
||||
|
||||
JNE(.ROWSTORED) // Jump to row stored
|
||||
|
||||
LABEL(.COLSTORED)
|
||||
/*
|
||||
Check for beta_mul_type, to jump to the required code-section
|
||||
Intermediate C = beta*C + IR, where IR = alpha*A*B
|
||||
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
|
||||
C = IR, skip beta-scaling
|
||||
else if beta == ( 1.0, 0.0 ) => BLIS_MUL_ONE
|
||||
C = C + IR, using addition
|
||||
else if, beta != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
|
||||
C = beta*C + IR, using complex multiplication
|
||||
else => BLIS_MUL_MINUS_ONE
|
||||
C = ( 0.0 - C ) + IR, using subtraction
|
||||
*/
|
||||
MOV(VAR(beta_mul_type), R14)
|
||||
CMP(IMM(0), R14) // Check if beta = 0.0
|
||||
/* Skip beta scaling and jump to store */
|
||||
JE(.BETA_ZERO_COL)
|
||||
|
||||
CMP(IMM(1), R14) // Check if beta = 1.0
|
||||
/* Jump to beta = 1.0 case */
|
||||
JE(.BETA_ONE_COL)
|
||||
|
||||
CMP(IMM(2), R14) // Check if alpha != -1.0
|
||||
/* Jump to the general case of alpha scaling */
|
||||
JE(.BETA_DEFAULT_COL)
|
||||
|
||||
/* Beta scaling when beta == -1.0 */
|
||||
LABEL(.BETA_MINUS_ONE_COL)
|
||||
|
||||
/* Perform C = alpha*A*B - C */
|
||||
/* ZMM(15) = load(C)
|
||||
ZMM(15) = ZMM(3) - ZMM(15) = alpha*A*B - C
|
||||
store(ZMM(15))
|
||||
|
||||
Similarly done for other registers */
|
||||
BETA_MINUS_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_MINUS_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_MINUS_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_MINUS_ONE_PRIMARY(12, 24, 13, 25, 14, 26)
|
||||
|
||||
JMP(.END)
|
||||
|
||||
/* Beta scaling when beta == -1.0 */
|
||||
LABEL(.BETA_ONE_COL)
|
||||
|
||||
/* Perform C = C + alpha*A*B */
|
||||
/* ZMM(15) = load(C)
|
||||
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + C
|
||||
store(ZMM(15))
|
||||
|
||||
Similarly done for other registers */
|
||||
BETA_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_ONE_PRIMARY(12, 24, 13, 25, 14, 26)
|
||||
|
||||
JMP(.END)
|
||||
|
||||
/* Beta scaling for generic case */
|
||||
LABEL(.BETA_DEFAULT_COL)
|
||||
/* Load beta onto a ZMM register */
|
||||
MOV(VAR(beta), RBX)
|
||||
/* Broadcast the real and imag components of beta onto the registers */
|
||||
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
|
||||
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
|
||||
|
||||
/* Perform C = beta*C + alpha*A*B */
|
||||
/* ZMM(15) = load(C)
|
||||
Perform beta scaling of ZMM(15)(similar to alpha scaling)
|
||||
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + beta*C
|
||||
store(ZMM(15))
|
||||
|
||||
Similarly done for other pairs of registers */
|
||||
BETA_DEFAULT_PRIMARY(3, 15, 4, 16, 5, 17)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_DEFAULT_PRIMARY(6, 18, 7, 19, 8, 20)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_DEFAULT_PRIMARY(9, 21, 10, 22, 11, 23)
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
BETA_DEFAULT_PRIMARY(12, 24, 13, 25, 14, 26)
|
||||
|
||||
JMP(.END)
|
||||
|
||||
LABEL(.BETA_ZERO_COL)
|
||||
/* This code-section is taken if we want to skip scaling */
|
||||
VMOVUPS(ZMM(3), MEM(RCX))
|
||||
VMOVUPS(ZMM(4), MEM(RCX, 64))
|
||||
VMOVUPS(ZMM(5), MEM(RCX, 128))
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
|
||||
VMOVUPS(ZMM(6), MEM(RCX))
|
||||
VMOVUPS(ZMM(7), MEM(RCX, 64))
|
||||
VMOVUPS(ZMM(8), MEM(RCX, 128))
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
|
||||
VMOVUPS(ZMM(9), MEM(RCX))
|
||||
VMOVUPS(ZMM(10), MEM(RCX, 64))
|
||||
VMOVUPS(ZMM(11), MEM(RCX, 128))
|
||||
LEA((RCX, R10, 1), RCX)
|
||||
|
||||
VMOVUPS(ZMM(12), MEM(RCX))
|
||||
VMOVUPS(ZMM(13), MEM(RCX, 64))
|
||||
VMOVUPS(ZMM(14), MEM(RCX, 128))
|
||||
|
||||
JMP(.END)
|
||||
|
||||
LABEL(.ROWSTORED)
|
||||
/* Check for general stride of C */
|
||||
CMP(IMM(8), RSI) // Check if C is row stored
|
||||
JNE(.GENERALSTRIDE) // Jump to general stride
|
||||
|
||||
/* This code-section is taken if C is row-stored */
|
||||
/*
|
||||
Check for beta_mul_type, to jump to the required code-section
|
||||
Intermediate C = beta*C + IR, where IR = alpha*A*B
|
||||
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
|
||||
C = IR, skip beta-scaling
|
||||
else => BLIS_MUL_DEFAULT
|
||||
C = beta*C + IR, using complex multiplication
|
||||
*/
|
||||
MOV(VAR(beta_mul_type), R14)
|
||||
CMP(IMM(0), R14) // Check if beta = 0.0
|
||||
/* Skip beta scaling and jump to store */
|
||||
JE(.BETA_ZERO_ROW)
|
||||
|
||||
LABEL(.BETA_DEFAULT_ROW)
|
||||
/* Load beta onto a ZMM register */
|
||||
MOV(VAR(beta), RBX)
|
||||
/* Broadcast the real and imag components of beta onto the registers */
|
||||
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
|
||||
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
|
||||
|
||||
/* We need to transpose the 24x4 block of alpha*A*B,
|
||||
in steps of 8x4.
|
||||
We use an 8x8 transpose routine with additional
|
||||
registers.
|
||||
|
||||
Input for transpose:
|
||||
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
|
||||
Rows(1-8) ZMM(3) ZMM(6) ZMM(9) ZMM(12) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
|
||||
*/
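
/* Illustrative sketch (not part of the kernel): the effect of one
TRANSPOSE_8X8 call below, in scalar form over an 8x8 tile of scomplex.
'in' and 'out' are hypothetical names.

scomplex in[8][8], out[8][8];
for (int i = 0; i < 8; i++)
for (int j = 0; j < 8; j++)
out[j][i] = in[i][j];
Only 4 of the 8 input columns hold alpha*A*B here; the other 4 (ZMM(28)
to ZMM(31)) are scratch, so only the leading 256 bits of each transposed
register are used when touching C. */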

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Scale C by beta and compute the result */
/* This is done one row at a time */
BETA_DEFAULT_SECONDARY(3, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(6, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(9, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(12, 21, 22)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RDI, 1), RCX)

/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(9-16) ZMM(4) ZMM(7) ZMM(10) ZMM(13) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Scale C by beta and compute the result */
/* This is done one row at a time */
BETA_DEFAULT_SECONDARY(4, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(7, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(10, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(13, 21, 22)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RDI, 1), RCX)

/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(17-24) ZMM(5) ZMM(8) ZMM(11) ZMM(14) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Scale C by beta and compute the result */
/* This is done one row at a time */
BETA_DEFAULT_SECONDARY(5, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(8, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(11, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(14, 21, 22)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)

JMP(.END)

LABEL(.BETA_ZERO_ROW)
/* We need to transpose the 24x4 block of alpha*A*B,
in steps of 8x4.
We use an 8x8 transpose routine with additional
registers.

Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(1-8) ZMM(3) ZMM(6) ZMM(9) ZMM(12) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Set 4 additional registers to 0.0 */
VXORPS(ZMM(28), ZMM(28), ZMM(28))
VXORPS(ZMM(29), ZMM(29), ZMM(29))
VXORPS(ZMM(30), ZMM(30), ZMM(30))
VXORPS(ZMM(31), ZMM(31), ZMM(31))

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(3), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(6), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(9), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(12), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RDI, 1), RCX)

/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(9-16) ZMM(4) ZMM(7) ZMM(10) ZMM(13) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(4), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(7), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(10), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(13), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RDI, 1), RCX)

/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(17-24) ZMM(5) ZMM(8) ZMM(11) ZMM(14) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(5), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(8), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(11), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(14), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))

JMP(.END)

LABEL(.GENERALSTRIDE)
/* This code-section is taken if C has general stride */
/*
In case of general strides for C, we need to load/store C
using gather/scatter instructions.
Visualizing C(8x4):
---------------------------------------------
| C00 | C10 | ... | C30 |
| | (rs_c) | | (rs_c) | ... | | (rs_c) |
| C01 | C11 | ... | C31 |
| | (rs_c) | | (rs_c) | ... | | (rs_c) |
| C02 | C12 | ... | C32 |
| . | . | ... | . |
| . | . | ... | . |
| . | . | ... | . |
| C07 | C17 | ... | C37 |
---------------------------------------------

Loading C :
Gather all elements of C column-wise onto ZMM registers

Compute with C(based on beta):
Similar to column-stored case, perform beta scaling and add to
alpha*A*B

Storing C :
Scatter the result one column at a time, using ZMM registers
*/
MOV(VAR(offsetPtr), R9) // Load address of offsets
VPBROADCASTQ(RDI, ZMM(31)) // Broadcast rs_c onto a register
VPMULLQ(MEM(R9), ZMM(31), ZMM(28)) // ZMM28 = { 0*rs_c, 1*rs_c, 2*rs_c, 3*rs_c, ... }
VPMULLQ(MEM(R9, 64), ZMM(31), ZMM(29)) // ZMM29 = { 8*rs_c, 9*rs_c, 10*rs_c, 11*rs_c, ... }
VPMULLQ(MEM(R9, 128), ZMM(31), ZMM(30)) // ZMM30 = { 16*rs_c, 17*rs_c, 18*rs_c, 19*rs_c, ... }
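/* Illustrative sketch (not part of the kernel): the offset vectors built
by the three VPMULLQ instructions above, in scalar form. 'idx' is a
hypothetical name; 'offsets' is the static array defined in this file.

int64_t idx[24];
for (int i = 0; i < 24; ++i)
idx[i] = offsets[i] * (int64_t)rs_c; // rs_c is already scaled to bytes
idx[0..7], idx[8..15] and idx[16..23] land in ZMM(28), ZMM(29) and
ZMM(30), and index the 24 rows of one column of C for gather/scatter. */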
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_GENERIC)

LABEL(.BETA_DEFAULT_GENERIC)
/* Load the address of beta */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))

/* Compute C = beta*C + alpha*A*B, and store to C */
BETA_DEFAULT_GENERAL(3, 4, 5)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_GENERAL(6, 7, 8)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_GENERAL(9, 10, 11)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_GENERAL(12, 13, 14)

JMP(.END)

LABEL(.BETA_ZERO_GENERIC)
/* Store the result onto C, one column at a time */
BETA_ZERO_GENERAL(3, 4, 5)
LEA((RCX, RSI, 1), RCX)
BETA_ZERO_GENERAL(6, 7, 8)
LEA((RCX, RSI, 1), RCX)
BETA_ZERO_GENERAL(9, 10, 11)
LEA((RCX, RSI, 1), RCX)
BETA_ZERO_GENERAL(12, 13, 14)

LABEL(.END)
VZEROUPPER()

end_asm(
: // output operands (none)
: // input operands
[k] "m"(k),
[a] "m"(a),
[b] "m"(b),
[c] "m"(c),
[rs_c] "m"(rs_c),
[cs_c] "m"(cs_c),
[fmaPtr] "m"(fmaPtr),
[offsetPtr] "m"(offsetPtr),
[alpha_mul_type] "m"(alpha_mul_type),
[beta_mul_type] "m"(beta_mul_type),
[alpha] "m"(alpha),
[beta] "m"(beta)
: // register clobber list
"rax", "rbx", "rcx", "rdi", "rsi", "r9", "r10", "r12", "r14",
"k0", "k1", "k2", "k3", "k4",
"ymm0", "ymm1", "ymm2", "ymm3",
"ymm4", "ymm5", "ymm6", "ymm7",
"ymm8", "ymm9", "ymm10", "ymm11",
"ymm12", "ymm13", "ymm14", "ymm15",
"ymm16", "ymm17", "ymm18", "ymm19",
"ymm20", "ymm21", "ymm22", "ymm28",
"ymm29", "ymm30", "ymm31",
"zmm0", "zmm1", "zmm2",
"zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8",
"zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14",
"zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20",
"zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory"
)

AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
}
834
kernels/zen4/3/bli_cgemm_zen4_asm_4x24.c
Normal file
@@ -0,0 +1,834 @@
/*

BLIS
An object-based framework for developing high-performance BLAS-like
libraries.

Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"
#include "bli_cgemm_zen4_asm_macros.h"

#define LOOP_ALIGN ALIGN32

/* Minimum number of iterations required for prefetching C */
#define TAIL_ITER 6

// This array is used to support the ADDSUB instruction.
static float fma_vec[16] __attribute__((aligned(64)))
= {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};

// This array is used for the scatter/gather instructions.
static int64_t offsets[24] __attribute__((aligned(64))) =
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23 };

/*
Register usage :
ZMM(0) - ZMM(2) : Load B
ZMM(28) - ZMM(31) : Broadcast A
ZMM(3) - ZMM(14) : Accumulate real_scaling
ZMM(15) - ZMM(26) : Accumulate imag_scaling / Load C

Total registers used : 31
*/
void bli_cgemm_zen4_asm_4x24(
dim_t k0,
scomplex *restrict alpha,
scomplex *restrict a,
scomplex *restrict b,
scomplex *restrict beta,
scomplex *restrict c, inc_t rs_c0, inc_t cs_c0,
auxinfo_t *data,
cntx_t *restrict cntx
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);

/* Casting all the integers to the same type */
const uint64_t k = k0;
uint64_t rs_c = rs_c0 * 8; // rs_c = rs_c0 * 8(size of scomplex datatype)
uint64_t cs_c = cs_c0 * 8; // cs_c = cs_c0 * 8(size of scomplex datatype)

/* Storing the address of the fma_vec array, to be used in the computation */
const float *fmaPtr = &fma_vec[0];
/* Storing the address of offsets, to be used for general-stride computation */
const int64_t *offsetPtr = &offsets[0];

/* Determining the alpha and beta multiplication types */
/* This is done as part of optimizing the alpha and beta scaling */
uint64_t alpha_mul_type = BLIS_MUL_DEFAULT;
uint64_t beta_mul_type = BLIS_MUL_DEFAULT;

/* Setting alpha_mul_type and beta_mul_type, based on special cases
of alpha and beta. */
if ( alpha->imag == 0.0 )
{
if ( alpha->real == 1.0 )
alpha_mul_type = BLIS_MUL_ONE;
else if ( alpha->real == -1.0 )
alpha_mul_type = BLIS_MUL_MINUS_ONE;
}

if ( beta->imag == 0.0 )
{
if ( beta->real == 1.0 )
beta_mul_type = BLIS_MUL_ONE;
else if ( beta->real == -1.0 )
beta_mul_type = BLIS_MUL_MINUS_ONE;
else if ( beta->real == 0.0 )
beta_mul_type = BLIS_MUL_ZERO;
}
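
/* Note (assumption, for readability): the assembly below compares
alpha_mul_type/beta_mul_type against the immediates 0, 1 and 2; this
relies on the BLIS_MUL_* constants enumerating as BLIS_MUL_ZERO = 0,
BLIS_MUL_ONE = 1 and BLIS_MUL_DEFAULT = 2, with BLIS_MUL_MINUS_ONE
handled as the fall-through case. */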

/* Start of the assembly code-section */
BEGIN_ASM()

/* Setting the registers to zero */
SET_ZERO()

/* Loading the value of k in RSI */
MOV(VAR(k), RSI)

/* Loading the addresses of A, B and C */
MOV(VAR(a), RAX) // RAX = addr of A
MOV(VAR(b), RBX) // RBX = addr of B
MOV(VAR(c), RCX) // RCX = addr of C

/* Load R9 with address of C to be used during prefetch */
MOV(RCX, R9)

/* Loading row-stride of C, since this is a row major kernel */
/* NOTE : rs_c has already been scaled by the size of the datatype */
MOV(VAR(rs_c), R10) // R10 = rs_c = 8 * rs_c0

/* Unrolling by a factor of 4, along k-dimension */
MOV(RSI, RDI)
AND(IMM(3), RSI) // RSI = k % 4(k_fringe)
SAR(IMM(2), RDI) // RDI = k / 4(k_iter)

/* The k-loop is divided into 4 parts, to have a fixed distance for prefetching C */
/* The k-loop is divided as follows :
1. .CK_BEFORE_PREFETCH : k/4 - 4 - TAIL_ITER iterations(before C prefetch)
2. .CK_PREFETCH : 4 iterations(prefetches C)
3. .CK_AFTER_PREFETCH : TAIL_ITER iterations(after C prefetch(prefetch distance))
4. .CK_FRINGE : k % 4 iterations
*/
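
/* Illustrative sketch (not part of the kernel): the same partitioning in
scalar form; 'it', 'rank1_x4', 'rank1_x1' and 'prefetch_C' are
hypothetical names.

int64_t it = k / 4;
it -= (4 + TAIL_ITER);
while (it > 0) { rank1_x4(); --it; } // .CK_ITER_BP
it += 4;
while (it > 0) { prefetch_C(); rank1_x4(); --it; } // .CK_ITER_P
it += TAIL_ITER;
while (it > 0) { rank1_x4(); --it; } // .CK_ITER_AP
for (int64_t f = k % 4; f > 0; --f) rank1_x1(); // .CK_ITER_FRINGE
*/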

LABEL(.CK_BEFORE_PREFETCH)
/* Check for entering the k-loop(before prefetch) */
SUB(IMM(4 + TAIL_ITER), RDI)
/* Jump to k-loop(prefetch) if k/4 <= 4 + TAIL_ITER */
JLE(.CK_PREFETCH)

/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_BP) // Unrolled iteration (B)efore (P)refetching

/* Performing rank-1 update 4 times(based on unroll) */
SUB_ITER_24x4(0, RBX, RAX)
SUB_ITER_24x4(1, RBX, RAX)
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RBX, RAX)
SUB_ITER_24x4(3, RBX, RAX)

/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RBX, 4 * 24 * 8), RBX) // RBX = RBX + 4 * NR * 8(loads)
LEA(MEM(RAX, 4 * 4 * 8), RAX) // RAX = RAX + 4 * MR * 8(broadcasts)

/* Jump if RDI != 0 */
JNZ(.CK_ITER_BP)

LABEL(.CK_PREFETCH)
/* Check for entering the k-loop(for C prefetch) */
ADD(IMM(4), RDI)
/* Jump to k-loop(after prefetch) if k/4 <= TAIL_ITER */
JLE(.CK_AFTER_PREFETCH)

/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_P) // Unrolled iteration with (P)refetching

/* Performing rank-1 update 4 times(based on unroll) */
/* Also prefetch C */
PREFETCHW0(MEM(R9))
SUB_ITER_24x4(0, RBX, RAX)
PREFETCHW0(MEM(R9, 64))
SUB_ITER_24x4(1, RBX, RAX)
PREFETCHW0(MEM(R9, 128))
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RBX, RAX)
SUB_ITER_24x4(3, RBX, RAX)

/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RBX, 4 * 24 * 8), RBX) // RBX = RBX + 4 * NR * 8(loads)
LEA(MEM(RAX, 4 * 4 * 8), RAX) // RAX = RAX + 4 * MR * 8(broadcasts)
LEA(MEM(R9, R10, 1), R9) // R9 = R9 + rs_c

/* Jump if RDI != 0 */
JNZ(.CK_ITER_P)

LABEL(.CK_AFTER_PREFETCH)
/* Check for entering the k-loop(after C prefetch) */
ADD(IMM(0 + TAIL_ITER), RDI)
/* Jump to k-loop(fringe) if k/4 <= 0 */
JLE(.CK_FRINGE)

/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_AP) // Unrolled iteration (A)fter (P)refetching

/* Performing rank-1 update 4 times(based on unroll) */
SUB_ITER_24x4(0, RBX, RAX)
SUB_ITER_24x4(1, RBX, RAX)
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RBX, RAX)
SUB_ITER_24x4(3, RBX, RAX)

/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RBX, 4 * 24 * 8), RBX) // RBX = RBX + 4 * NR * 8(loads)
LEA(MEM(RAX, 4 * 4 * 8), RAX) // RAX = RAX + 4 * MR * 8(broadcasts)

/* Jump if RDI != 0 */
JNZ(.CK_ITER_AP)

LABEL(.CK_FRINGE)
/* Check for entering the k-loop(fringe) */
TEST(RSI, RSI)
JE(.POSTACCUM)

/* K-loop(fringe) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_FRINGE)

/* Performing rank-1 update */
SUB_ITER_24x4(0, RBX, RAX)
SUB(IMM(1), RSI) // RSI(iterator) -= 1

/* Adjusting the addresses of A and B for the next set */
/* new_addr = old_addr + ( unroll_factor * {MR or NR} * size_of_type ) */
LEA(MEM(RBX, 24 * 8), RBX) // RBX = RBX + NR * 8(loads)
LEA(MEM(RAX, 4 * 8), RAX) // RAX = RAX + MR * 8(broadcasts)

/* Jump until RSI becomes 0 */
JNZ(.CK_ITER_FRINGE)

LABEL(.POSTACCUM)

/* The registers from ZMM(15) to ZMM(26) contain the FMA ops
using imaginary components from elements in A matrix.
We should shuffle them( even and odd indices )
SRC: ZMM(15) = ( Br0*Ai0, Bi0*Ai0, Br1*Ai0, Bi1*Ai0, ... )
DST: ZMM(15) = ( Bi0*Ai0, Br0*Ai0, Bi1*Ai0, Br1*Ai0, ... )
Similarly for the other registers
*/
PERMUTE(15, 16, 17)
PERMUTE(18, 19, 20)
PERMUTE(21, 22, 23)
PERMUTE(24, 25, 26)

/* Loading ZMM(0) with 1.0f, for reduction */
MOV(VAR(fmaPtr), R14)
VMOVAPS(MEM(R14), ZMM(0))

/* Reducing the result using real/imag accumulators, for complex arithmetic
SRC: ZMM(3) = ( Br0*Ar0, Bi0*Ar0, Br1*Ar0, Bi1*Ar0, ... )
ZMM(15) = ( Bi0*Ai0, Br0*Ai0, Bi1*Ai0, Br1*Ai0, ... )
DST: ZMM(3) = ( Br0*Ar0 - Bi0*Ai0, Bi0*Ar0 + Br0*Ai0, ... )
Similarly done for the other registers
*/
FMADDSUB(3, 15, 4, 16, 5, 17)
FMADDSUB(6, 18, 7, 19, 8, 20)
FMADDSUB(9, 21, 10, 22, 11, 23)
FMADDSUB(12, 24, 13, 25, 14, 26)

/*
The result of A*B(micro-tile) is a 4x24 matrix(row-major), as follows :

Cols(1-8) Cols(9-16) Cols(17-24)
Row-1 ZMM(3) ZMM(4) ZMM(5)
Row-2 ZMM(6) ZMM(7) ZMM(8)
Row-3 ZMM(9) ZMM(10) ZMM(11)
Row-4 ZMM(12) ZMM(13) ZMM(14)
*/

LABEL(.ALPHA_SCALING)
/*
Check for alpha_mul_type, to jump to the required code-section
Intermediate result(IR) = alpha*(A*B)
If alpha == ( 1.0, 0.0 ) => BLIS_MUL_ONE
IR = A*B
else if alpha != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
IR = alpha*(A*B), using complex multiplication
else => BLIS_MUL_MINUS_ONE
IR = 0.0 - A*B, using subtraction
*/
MOV(VAR(alpha_mul_type), R14)
CMP(IMM(1), R14) // Check if alpha = 1.0
/* Skip alpha scaling and jump to beta scaling */
JE(.BETA_SCALING)

CMP(IMM(2), R14) // Check if alpha != +/-1.0
/* Jump to the general case of alpha scaling */
JE(.ALPHA_GENERAL)

/* Alpha scaling when alpha == -1.0 */
LABEL(.ALPHA_MINUS_ONE)
/* Set ZMM(1) to 0.0f, and subtract the registers from ZMM(1) */
VXORPS(ZMM(1), ZMM(1), ZMM(1))
/* ZMM(3) = ZMM(1) - ZMM(3) = 0.0f - A*B
Similarly done for other registers */
ALPHA_MINUS_ONE(3, 1, 4, 1, 5, 1)
ALPHA_MINUS_ONE(6, 1, 7, 1, 8, 1)
ALPHA_MINUS_ONE(9, 1, 10, 1, 11, 1)
ALPHA_MINUS_ONE(12, 1, 13, 1, 14, 1)

/* Jump to beta scaling */
JMP(.BETA_SCALING)

/* Alpha scaling when alpha != 1.0 and alpha != -1.0 */
LABEL(.ALPHA_GENERAL)
/* Load the address of alpha */
MOV(VAR(alpha), RAX)
/* Broadcast the real and imag components of alpha onto the registers */
VBROADCASTSS(MEM(RAX, 0), ZMM(1))
VBROADCASTSS(MEM(RAX, 4), ZMM(2))

/* Scale the result of A*B with alpha */
/* ZMM(15) = alphai * ZMM(3)
ZMM(3) = alphar * ZMM(3)
ZMM(3) = fmaddsub(ZMM(3), permute(ZMM(15)))

Similarly done for other pairs of registers */
ALPHA_DEFAULT(3, 15, 4, 16, 5, 17)
ALPHA_DEFAULT(6, 18, 7, 19, 8, 20)
ALPHA_DEFAULT(9, 21, 10, 22, 11, 23)
ALPHA_DEFAULT(12, 24, 13, 25, 14, 26)

/* Perform beta scaling */
LABEL(.BETA_SCALING)
/* Load the row and column strides of C */
MOV(VAR(rs_c), RDI)
MOV(VAR(cs_c), RSI)

CMP(IMM(8), RSI) // Check if C is row stored

JNE(.COLSTORED) // Jump to column stored

LABEL(.ROWSTORED)
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else if beta == ( 1.0, 0.0 ) => BLIS_MUL_ONE
C = C + IR, using addition
else if beta != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
else => BLIS_MUL_MINUS_ONE
C = ( 0.0 - C ) + IR, using subtraction
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_ROW)

CMP(IMM(1), R14) // Check if beta = 1.0
/* Jump to beta = 1.0 case */
JE(.BETA_ONE_ROW)

CMP(IMM(2), R14) // Check if beta != +/-1.0
/* Jump to the general case of beta scaling */
JE(.BETA_DEFAULT_ROW)

/* Beta scaling when beta == -1.0 */
LABEL(.BETA_MINUS_ONE_ROW)

/* Perform C = alpha*A*B - C */
/* ZMM(15) = load(C)
ZMM(15) = ZMM(3) - ZMM(15) = alpha*A*B - C
store(ZMM(15))

Similarly done for other registers */
BETA_MINUS_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(12, 24, 13, 25, 14, 26)

JMP(.END)

/* Beta scaling when beta == 1.0 */
LABEL(.BETA_ONE_ROW)

/* Perform C = C + alpha*A*B */
/* ZMM(15) = load(C)
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + C
store(ZMM(15))

Similarly done for other registers */
BETA_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(12, 24, 13, 25, 14, 26)

JMP(.END)

/* Beta scaling for generic case */
LABEL(.BETA_DEFAULT_ROW)
/* Load the address of beta */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))

/* Perform C = beta*C + alpha*A*B */
/* ZMM(15) = load(C)
Perform beta scaling of ZMM(15)(similar to alpha scaling)
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + beta*C
store(ZMM(15))

Similarly done for other pairs of registers */
BETA_DEFAULT_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(12, 24, 13, 25, 14, 26)

JMP(.END)

LABEL(.BETA_ZERO_ROW)
/* This code-section is taken if we want to skip beta scaling */
VMOVUPS(ZMM(3), MEM(RCX))
VMOVUPS(ZMM(4), MEM(RCX, 64))
VMOVUPS(ZMM(5), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)

VMOVUPS(ZMM(6), MEM(RCX))
VMOVUPS(ZMM(7), MEM(RCX, 64))
VMOVUPS(ZMM(8), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)

VMOVUPS(ZMM(9), MEM(RCX))
VMOVUPS(ZMM(10), MEM(RCX, 64))
VMOVUPS(ZMM(11), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)

VMOVUPS(ZMM(12), MEM(RCX))
VMOVUPS(ZMM(13), MEM(RCX, 64))
VMOVUPS(ZMM(14), MEM(RCX, 128))

JMP(.END)

LABEL(.COLSTORED)
/* Check for general stride of C */
CMP(IMM(8), RDI) // Check if C is col stored
JNE(.GENERALSTRIDE) // Jump to general stride

/* This code-section is taken if C is col-stored */
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_COL)

LABEL(.BETA_DEFAULT_COL)
/* Load the address of beta */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))

/* We need to transpose the 4x24 block of alpha*A*B,
in steps of 4x8.
We use an 8x8 transpose routine with additional
registers.

Input for transpose:
Columns(1-8)
Row-1 ZMM(3)
Row-2 ZMM(6)
Row-3 ZMM(9)
Row-4 ZMM(12)
Row-5 ZMM(28)
Row-6 ZMM(29)
Row-7 ZMM(30)
Row-8 ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Scale C by beta and compute the result */
/* This is done one column at a time */
BETA_DEFAULT_SECONDARY(3, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(6, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(9, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(12, 21, 22)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RSI, 1), RCX)

/*
Input for transpose:
Columns(9-16)
Row-1 ZMM(4)
Row-2 ZMM(7)
Row-3 ZMM(10)
Row-4 ZMM(13)
Row-5 ZMM(28)
Row-6 ZMM(29)
Row-7 ZMM(30)
Row-8 ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Scale C by beta and compute the result */
/* This is done one column at a time */
BETA_DEFAULT_SECONDARY(4, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(7, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(10, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(13, 21, 22)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RSI, 1), RCX)

/*
Input for transpose:
Columns(17-24)
Row-1 ZMM(5)
Row-2 ZMM(8)
Row-3 ZMM(11)
Row-4 ZMM(14)
Row-5 ZMM(28)
Row-6 ZMM(29)
Row-7 ZMM(30)
Row-8 ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Scale C by beta and compute the result */
/* This is done one column at a time */
BETA_DEFAULT_SECONDARY(5, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(8, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(11, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(14, 21, 22)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)

JMP(.END)

LABEL(.BETA_ZERO_COL)
/* We need to transpose the 4x24 block of alpha*A*B,
in steps of 4x8.
We use an 8x8 transpose routine with additional
registers.

Input for transpose:
Columns(1-8)
Row-1 ZMM(3)
Row-2 ZMM(6)
Row-3 ZMM(9)
Row-4 ZMM(12)
Row-5 ZMM(28)
Row-6 ZMM(29)
Row-7 ZMM(30)
Row-8 ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(3), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(6), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(9), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(12), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RSI, 1), RCX)

/*
Input for transpose:
Columns(9-16)
Row-1 ZMM(4)
Row-2 ZMM(7)
Row-3 ZMM(10)
Row-4 ZMM(13)
Row-5 ZMM(28)
Row-6 ZMM(29)
Row-7 ZMM(30)
Row-8 ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(4), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(7), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(10), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(13), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RSI, 1), RCX)

/*
Input for transpose:
Columns(17-24)
Row-1 ZMM(5)
Row-2 ZMM(8)
Row-3 ZMM(11)
Row-4 ZMM(14)
Row-5 ZMM(28)
Row-6 ZMM(29)
Row-7 ZMM(30)
Row-8 ZMM(31)
*/

/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)

/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(5), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(8), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(11), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(14), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))

JMP(.END)

LABEL(.GENERALSTRIDE)
/* This code-section is taken if C has general stride */
/*
In case of general strides for C, we need to load/store C
using gather/scatter instructions.
Visualizing C(4x8):
---------------------------------------------
| C00 --(cs_c)-- C10 --(cs_c)-- C20 ... C70 |
| ------------------------------------------|
| C01 --(cs_c)-- C11 --(cs_c)-- C21 ... C71 |
| ------------------------------------------|
| C02 --(cs_c)-- C12 --(cs_c)-- C22 ... C72 |
| ------------------------------------------|
| C03 --(cs_c)-- C13 --(cs_c)-- C23 ... C73 |
---------------------------------------------

Loading C :
Gather all elements of C row-wise onto ZMM registers

Compute with C(based on beta):
Similar to row-stored case, perform beta scaling and add to
alpha*A*B

Storing C :
Scatter the result one row at a time, using ZMM registers
*/
MOV(VAR(offsetPtr), R9) // Load address of offsets
VPBROADCASTQ(RSI, ZMM(31)) // Broadcast cs_c onto a register
VPMULLQ(MEM(R9), ZMM(31), ZMM(28)) // ZMM28 = { 0*cs_c, 1*cs_c, 2*cs_c, 3*cs_c, ... }
VPMULLQ(MEM(R9, 64), ZMM(31), ZMM(29)) // ZMM29 = { 8*cs_c, 9*cs_c, 10*cs_c, 11*cs_c, ... }
VPMULLQ(MEM(R9, 128), ZMM(31), ZMM(30)) // ZMM30 = { 16*cs_c, 17*cs_c, 18*cs_c, 19*cs_c, ... }
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_GENERIC)

LABEL(.BETA_DEFAULT_GENERIC)
/* Load the address of beta */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))

/* Compute C = beta*C + alpha*A*B, and store to C */
BETA_DEFAULT_GENERAL(3, 4, 5)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_GENERAL(6, 7, 8)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_GENERAL(9, 10, 11)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_GENERAL(12, 13, 14)

JMP(.END)

LABEL(.BETA_ZERO_GENERIC)
/* Store the result onto C, one row at a time */
BETA_ZERO_GENERAL(3, 4, 5)
LEA((RCX, RDI, 1), RCX)
BETA_ZERO_GENERAL(6, 7, 8)
LEA((RCX, RDI, 1), RCX)
BETA_ZERO_GENERAL(9, 10, 11)
LEA((RCX, RDI, 1), RCX)
BETA_ZERO_GENERAL(12, 13, 14)

LABEL(.END)
VZEROUPPER()

end_asm(
: // output operands (none)
: // input operands
[k] "m"(k),
[a] "m"(a),
[b] "m"(b),
[c] "m"(c),
[rs_c] "m"(rs_c),
[cs_c] "m"(cs_c),
[fmaPtr] "m"(fmaPtr),
[offsetPtr] "m"(offsetPtr),
[alpha_mul_type] "m"(alpha_mul_type),
[beta_mul_type] "m"(beta_mul_type),
[alpha] "m"(alpha),
[beta] "m"(beta)
: // register clobber list
"rax", "rbx", "rcx", "rdi", "rsi", "r9", "r10", "r12", "r14",
"k0", "k1", "k2", "k3", "k4",
"ymm0", "ymm1", "ymm2", "ymm3",
"ymm4", "ymm5", "ymm6", "ymm7",
"ymm8", "ymm9", "ymm10", "ymm11",
"ymm12", "ymm13", "ymm14", "ymm15",
"ymm16", "ymm17", "ymm18", "ymm19",
"ymm20", "ymm21", "ymm22", "ymm28",
"ymm29", "ymm30", "ymm31",
"zmm0", "zmm1", "zmm2",
"zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8",
"zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14",
"zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20",
"zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory"
)

AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
}
446
kernels/zen4/3/bli_cgemm_zen4_asm_macros.h
Normal file
@@ -0,0 +1,446 @@
/*

BLIS
An object-based framework for developing high-performance BLAS-like
libraries.

Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"

/* NOTE : This file contains the macros used to implement CGEMM native
kernels of the following (MR, NR) pairs :
Column-major : (24, 4)
Row-major : (4, 24)
The macro for micro-tile computation(SUB_ITER_24x4) accepts the load
and broadcast addresses as parameters. Thus, we can use this macro
in both the 24x4 column-major kernel and the 4x24 row-major kernel,
by passing the appropriate load and broadcast address registers.
*/

/* Macro to set all the registers to zero */
#define SET_ZERO() \
VXORPS(ZMM(0), ZMM(0), ZMM(0)) \
VXORPS(ZMM(1), ZMM(1), ZMM(1)) \
VXORPS(ZMM(2), ZMM(2), ZMM(2)) \
VXORPS(ZMM(3), ZMM(3), ZMM(3)) \
VXORPS(ZMM(4), ZMM(4), ZMM(4)) \
VXORPS(ZMM(5), ZMM(5), ZMM(5)) \
VXORPS(ZMM(6), ZMM(6), ZMM(6)) \
VXORPS(ZMM(7), ZMM(7), ZMM(7)) \
VXORPS(ZMM(8), ZMM(8), ZMM(8)) \
VXORPS(ZMM(9), ZMM(9), ZMM(9)) \
VXORPS(ZMM(10), ZMM(10), ZMM(10)) \
VXORPS(ZMM(11), ZMM(11), ZMM(11)) \
VXORPS(ZMM(12), ZMM(12), ZMM(12)) \
VXORPS(ZMM(13), ZMM(13), ZMM(13)) \
VXORPS(ZMM(14), ZMM(14), ZMM(14)) \
VXORPS(ZMM(15), ZMM(15), ZMM(15)) \
VXORPS(ZMM(16), ZMM(16), ZMM(16)) \
VXORPS(ZMM(17), ZMM(17), ZMM(17)) \
VXORPS(ZMM(18), ZMM(18), ZMM(18)) \
VXORPS(ZMM(19), ZMM(19), ZMM(19)) \
VXORPS(ZMM(20), ZMM(20), ZMM(20)) \
VXORPS(ZMM(21), ZMM(21), ZMM(21)) \
VXORPS(ZMM(22), ZMM(22), ZMM(22)) \
VXORPS(ZMM(23), ZMM(23), ZMM(23)) \
VXORPS(ZMM(24), ZMM(24), ZMM(24)) \
VXORPS(ZMM(25), ZMM(25), ZMM(25)) \
VXORPS(ZMM(26), ZMM(26), ZMM(26)) \
VXORPS(ZMM(27), ZMM(27), ZMM(27)) \
VXORPS(ZMM(28), ZMM(28), ZMM(28)) \
VXORPS(ZMM(29), ZMM(29), ZMM(29)) \
VXORPS(ZMM(30), ZMM(30), ZMM(30)) \
VXORPS(ZMM(31), ZMM(31), ZMM(31)) \

/* Macro to perform the rank-1 update using loads and broadcasts from the matrices(24x4) */
/* RL represents the register that has the load address, RB represents the broadcast address */
/* For the sake of comments, let's assume RL = A, and RB = B(column-major kernel) */
#define SUB_ITER_24x4(n, RL, RB) \
VMOVAPS(MEM(RL, (24 * n + 0) * 8), ZMM(0)) /* ZMM(0) = A[0:7][n] */ \
VMOVAPS(MEM(RL, (24 * n + 8) * 8), ZMM(1)) /* ZMM(1) = A[8:15][n] */ \
VMOVAPS(MEM(RL, (24 * n + 16) * 8), ZMM(2)) /* ZMM(2) = A[16:23][n] */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 0) * 4), ZMM(28)) /* ZMM(28) = Real(B[n][0]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 1) * 4), ZMM(29)) /* ZMM(29) = Imag(B[n][0]) */ \
VFMADD231PS(ZMM(0), ZMM(28), ZMM(3)) /* ZMM(3) += A[0:7][n] * Real(B[n][0]) */ \
VFMADD231PS(ZMM(1), ZMM(28), ZMM(4)) /* ZMM(4) += A[8:15][n] * Real(B[n][0]) */ \
VFMADD231PS(ZMM(2), ZMM(28), ZMM(5)) /* ZMM(5) += A[16:23][n] * Real(B[n][0]) */ \
VFMADD231PS(ZMM(0), ZMM(29), ZMM(15)) /* ZMM(15) += A[0:7][n] * Imag(B[n][0]) */ \
VFMADD231PS(ZMM(1), ZMM(29), ZMM(16)) /* ZMM(16) += A[8:15][n] * Imag(B[n][0]) */ \
VFMADD231PS(ZMM(2), ZMM(29), ZMM(17)) /* ZMM(17) += A[16:23][n] * Imag(B[n][0]) */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 2) * 4), ZMM(30)) /* ZMM(30) = Real(B[n][1]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 3) * 4), ZMM(31)) /* ZMM(31) = Imag(B[n][1]) */ \
VFMADD231PS(ZMM(0), ZMM(30), ZMM(6)) /* ZMM(6) += A[0:7][n] * Real(B[n][1]) */ \
VFMADD231PS(ZMM(1), ZMM(30), ZMM(7)) /* ZMM(7) += A[8:15][n] * Real(B[n][1]) */ \
VFMADD231PS(ZMM(2), ZMM(30), ZMM(8)) /* ZMM(8) += A[16:23][n] * Real(B[n][1]) */ \
VFMADD231PS(ZMM(0), ZMM(31), ZMM(18)) /* ZMM(18) += A[0:7][n] * Imag(B[n][1]) */ \
VFMADD231PS(ZMM(1), ZMM(31), ZMM(19)) /* ZMM(19) += A[8:15][n] * Imag(B[n][1]) */ \
VFMADD231PS(ZMM(2), ZMM(31), ZMM(20)) /* ZMM(20) += A[16:23][n] * Imag(B[n][1]) */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 4) * 4), ZMM(28)) /* ZMM(28) = Real(B[n][2]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 5) * 4), ZMM(29)) /* ZMM(29) = Imag(B[n][2]) */ \
VFMADD231PS(ZMM(0), ZMM(28), ZMM(9)) /* ZMM(9) += A[0:7][n] * Real(B[n][2]) */ \
VFMADD231PS(ZMM(1), ZMM(28), ZMM(10)) /* ZMM(10) += A[8:15][n] * Real(B[n][2]) */ \
VFMADD231PS(ZMM(2), ZMM(28), ZMM(11)) /* ZMM(11) += A[16:23][n] * Real(B[n][2]) */ \
VFMADD231PS(ZMM(0), ZMM(29), ZMM(21)) /* ZMM(21) += A[0:7][n] * Imag(B[n][2]) */ \
VFMADD231PS(ZMM(1), ZMM(29), ZMM(22)) /* ZMM(22) += A[8:15][n] * Imag(B[n][2]) */ \
VFMADD231PS(ZMM(2), ZMM(29), ZMM(23)) /* ZMM(23) += A[16:23][n] * Imag(B[n][2]) */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 6) * 4), ZMM(30)) /* ZMM(30) = Real(B[n][3]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 7) * 4), ZMM(31)) /* ZMM(31) = Imag(B[n][3]) */ \
VFMADD231PS(ZMM(0), ZMM(30), ZMM(12)) /* ZMM(12) += A[0:7][n] * Real(B[n][3]) */ \
VFMADD231PS(ZMM(1), ZMM(30), ZMM(13)) /* ZMM(13) += A[8:15][n] * Real(B[n][3]) */ \
VFMADD231PS(ZMM(2), ZMM(30), ZMM(14)) /* ZMM(14) += A[16:23][n] * Real(B[n][3]) */ \
VFMADD231PS(ZMM(0), ZMM(31), ZMM(24)) /* ZMM(24) += A[0:7][n] * Imag(B[n][3]) */ \
VFMADD231PS(ZMM(1), ZMM(31), ZMM(25)) /* ZMM(25) += A[8:15][n] * Imag(B[n][3]) */ \
VFMADD231PS(ZMM(2), ZMM(31), ZMM(26)) /* ZMM(26) += A[16:23][n] * Imag(B[n][3]) */ \
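
/* Illustrative sketch (not part of the build): one SUB_ITER_24x4 step in
scalar form, for the column-major view. The names (a_col, b_row, acc_re,
acc_im) are hypothetical.

for (int i = 0; i < 24; ++i) // one packed column of A
for (int j = 0; j < 4; ++j) { // one packed row of B
acc_re[i][j].real += a_col[i].real * b_row[j].real;
acc_re[i][j].imag += a_col[i].imag * b_row[j].real;
acc_im[i][j].real += a_col[i].real * b_row[j].imag;
acc_im[i][j].imag += a_col[i].imag * b_row[j].imag;
}
The real-broadcast products land in ZMM(3)-ZMM(14), the imag-broadcast
products in ZMM(15)-ZMM(26); PERMUTE + FMADDSUB later combine the two. */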

/* Macro to scale the registers */
/* 'B' represents the broadcasted register, to be used for scaling */
#define SCALE(B, R1, O1, R2, O2, R3, O3) \
VMULPS(ZMM(B), ZMM(R1), ZMM(O1)) /* ZMM(O1) = ZMM(B) * ZMM(R1) */ \
VMULPS(ZMM(B), ZMM(R2), ZMM(O2)) /* ZMM(O2) = ZMM(B) * ZMM(R2) */ \
VMULPS(ZMM(B), ZMM(R3), ZMM(O3)) /* ZMM(O3) = ZMM(B) * ZMM(R3) */ \

/* Macro to shuffle even and odd indexed elements in a ZMM register */
#define PERMUTE(I1, I2, I3) \
/* For col major kernel: ZMM(I1) = { Ai0.Bi0, Ar0.Bi0, Ai1.Bi0, Ar1.Bi0, ... }
For row major kernel: ZMM(I1) = { Bi0.Ai0, Br0.Ai0, Bi1.Ai0, Br1.Ai0, ... }.
Similarly done for the other registers */ \
VPERMILPS(IMM(0xB1), ZMM(I1), ZMM(I1)) \
VPERMILPS(IMM(0xB1), ZMM(I2), ZMM(I2)) \
VPERMILPS(IMM(0xB1), ZMM(I3), ZMM(I3)) \

/* Macro to reduce a real and imag accumulator pair, as per complex arithmetic */
/* Macro assumes that ZMM(0) has 1.0f broadcasted in it */
#define FMADDSUB(R1, I1, R2, I2, R3, I3) \
/* ZMM(R1) = ZMM(R1) * 1.0f -/+ ZMM(I1) (subtract in even lanes, add in odd lanes)
For col major kernel: ZMM(R1) = { Ar0.Br0 - Ai0.Bi0, Ai0.Br0 + Ar0.Bi0 }
For row major kernel: ZMM(R1) = { Br0.Ar0 - Bi0.Ai0, Bi0.Ar0 + Br0.Ai0 }
Similarly done for the other pairs */ \
VFMADDSUB132PS(ZMM(0), ZMM(I1), ZMM(R1)) \
VFMADDSUB132PS(ZMM(0), ZMM(I2), ZMM(R2)) \
VFMADDSUB132PS(ZMM(0), ZMM(I3), ZMM(R3)) \
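
/* Illustrative sketch (not part of the build): per pair of float lanes
(even = real, odd = imag), VFMADDSUB132PS with a 1.0f multiplier behaves
as below; 'acc' and 'im' are hypothetical names.

acc[2*i + 0] = acc[2*i + 0] * 1.0f - im[2*i + 0]; // even lane: subtract
acc[2*i + 1] = acc[2*i + 1] * 1.0f + im[2*i + 1]; // odd lane: add
This is exactly the (ar*br - ai*bi, ai*br + ar*bi) reduction of a complex
multiply, once PERMUTE has swapped the lanes of the imag accumulator. */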
|
||||
|
||||
/* Macro to handle alpha scaling when alpha is -1.0f */
|
||||
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
|
||||
of alpha already broadcasted */
|
||||
/* ZMM(S1) = ZMM(S2) = ZMM(S3) = 0.0f, for alpha-scaling */
|
||||
#define ALPHA_MINUS_ONE(R1, S1, R2, S2, R3, S3) \
|
||||
VSUBPS(ZMM(R1), ZMM(S1), ZMM(R1)) /* ZMM(R1) = 0.0f - ZMM(R1) */ \
|
||||
VSUBPS(ZMM(R2), ZMM(S2), ZMM(R2)) /* ZMM(R2) = 0.0f - ZMM(R2) */ \
|
||||
VSUBPS(ZMM(R3), ZMM(S3), ZMM(R3)) /* ZMM(R3) = 0.0f - ZMM(R3) */ \
|
||||
|
||||
/* Macro to handle alpha scaling in the generic case */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of alpha already broadcasted */
#define ALPHA_DEFAULT(R1, I1, R2, I2, R3, I3) \
/* Scale with real and imag components of alpha */ \
/* Assume ZMM(R1) = { Ar0, Ai0, ... } */ \
/* ZMM(I1) = { Ar0.alphai, Ai0.alphai, ... }
Similarly done for other registers */ \
SCALE(2, R1, I1, R2, I2, R3, I3) \
/* ZMM(R1) = { Ar0.alphar, Ai0.alphar, ... }
Similarly done for other registers */ \
SCALE(1, R1, R1, R2, R2, R3, R3) \
\
/* Shuffle the imag accumulators for reduction */ \
/* ZMM(I1) = { Ai0.alphai, Ar0.alphai, ... }
Similarly done for other registers */ \
PERMUTE(I1, I2, I3) \
\
/* Reduce using the fmaddsub instruction */ \
/* ZMM(R1) = { Ar0.alphar - Ai0.alphai, Ai0.alphar + Ar0.alphai, ... }
Similarly done for other registers */ \
FMADDSUB(R1, I1, R2, I2, R3, I3) \

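/* For reference, a scalar equivalent of the ALPHA_DEFAULT sequence
(SCALE, PERMUTE, FMADDSUB) for a single accumulator element x; an
illustrative sketch, guarded out of compilation. */
#if 0
static scomplex alpha_default_model( scomplex x, scomplex alpha )
{
    /* SCALE(2, ...): I = { x.real*alpha.imag, x.imag*alpha.imag } */
    float i_re = x.real * alpha.imag, i_im = x.imag * alpha.imag;
    /* SCALE(1, ...): R = { x.real*alpha.real, x.imag*alpha.real } */
    float r_re = x.real * alpha.real, r_im = x.imag * alpha.real;
    /* PERMUTE: swap the pair -> I = { x.imag*alpha.imag, x.real*alpha.imag } */
    float t = i_re; i_re = i_im; i_im = t;
    /* FMADDSUB: even lane subtracts, odd lane adds */
    scomplex y = { r_re - i_re, r_im + i_im };  /* y = alpha * x */
    return y;
}
#endif
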
/* Macro to handle beta scaling when beta is 1.0f, with primary storage */
#define BETA_ONE_PRIMARY(R1, C1, R2, C2, R3, C3) \
/* Load C onto the registers */ \
VMOVUPS(MEM(RCX), ZMM(C1)) \
VMOVUPS(MEM(RCX, 64), ZMM(C2)) \
VMOVUPS(MEM(RCX, 128), ZMM(C3)) \
\
/* Add C to the result of A*B */ \
VADDPS(ZMM(C1), ZMM(R1), ZMM(C1)) /* ZMM(C1) = ZMM(R1) + ZMM(C1) */ \
VADDPS(ZMM(C2), ZMM(R2), ZMM(C2)) /* ZMM(C2) = ZMM(R2) + ZMM(C2) */ \
VADDPS(ZMM(C3), ZMM(R3), ZMM(C3)) /* ZMM(C3) = ZMM(R3) + ZMM(C3) */ \
\
/* Store the results onto C */ \
VMOVUPS(ZMM(C1), MEM(RCX)) \
VMOVUPS(ZMM(C2), MEM(RCX, 64)) \
VMOVUPS(ZMM(C3), MEM(RCX, 128)) \

/* Macro to handle beta scaling when beta is -1.0f, with primary storage */
#define BETA_MINUS_ONE_PRIMARY(R1, C1, R2, C2, R3, C3) \
/* Load C onto the registers */ \
VMOVUPS(MEM(RCX), ZMM(C1)) \
VMOVUPS(MEM(RCX, 64), ZMM(C2)) \
VMOVUPS(MEM(RCX, 128), ZMM(C3)) \
\
/* Subtract C from the result of alpha*A*B (using the pre-existing macro) */ \
/* ZMM(C1) = ZMM(R1) - ZMM(C1)
Similarly done for other registers */ \
ALPHA_MINUS_ONE(C1, R1, C2, R2, C3, R3) \
\
/* Store the results onto C */ \
VMOVUPS(ZMM(C1), MEM(RCX)) \
VMOVUPS(ZMM(C2), MEM(RCX, 64)) \
VMOVUPS(ZMM(C3), MEM(RCX, 128)) \

/* Macro to scale a set of 3 registers with beta, with primary storage */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of beta already broadcasted */
/* Macro uses C1...C3 to load C, and R1...R3 contain the result of alpha*A*B */
/* Macro uses ZMM(28) - ZMM(30) for the beta*C computation */
#define BETA_DEFAULT_PRIMARY(R1, C1, R2, C2, R3, C3) \
/* Load C onto the registers */ \
VMOVUPS(MEM(RCX), ZMM(C1)) \
VMOVUPS(MEM(RCX, 64), ZMM(C2)) \
VMOVUPS(MEM(RCX, 128), ZMM(C3)) \
\
/* Scale C by beta (using the pre-existing macro) */ \
/* Assume ZMM(C1) = { Cr0, Ci0, ... } */ \
/* ZMM(C1) = { Cr0.betar - Ci0.betai, Ci0.betar + Cr0.betai, ... }
Similarly done for other registers */ \
ALPHA_DEFAULT(C1, 28, C2, 29, C3, 30) \
\
/* Add beta*C to the result of alpha*A*B */ \
VADDPS(ZMM(C1), ZMM(R1), ZMM(C1)) /* ZMM(C1) = ZMM(R1) + ZMM(C1) */ \
VADDPS(ZMM(C2), ZMM(R2), ZMM(C2)) /* ZMM(C2) = ZMM(R2) + ZMM(C2) */ \
VADDPS(ZMM(C3), ZMM(R3), ZMM(C3)) /* ZMM(C3) = ZMM(R3) + ZMM(C3) */ \
\
/* Store the results onto C */ \
VMOVUPS(ZMM(C1), MEM(RCX)) \
VMOVUPS(ZMM(C2), MEM(RCX, 64)) \
VMOVUPS(ZMM(C3), MEM(RCX, 128)) \

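/* For reference, a scalar equivalent of the contiguous beta path above:
c := (alpha*A*B) + beta*c per element. The kernel applies this to 24
scomplex elements at a time across three ZMM registers. An illustrative
sketch, guarded out of compilation. */
#if 0
static void beta_default_model( scomplex* c, const scomplex* ab,
                                scomplex beta, int n )
{
    for ( int i = 0; i < n; ++i )
    {
        float cr = c[i].real * beta.real - c[i].imag * beta.imag;
        float ci = c[i].imag * beta.real + c[i].real * beta.imag;
        c[i].real = ab[i].real + cr;
        c[i].imag = ab[i].imag + ci;
    }
}
#endif
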
/* Macro to perform an 8x8 transpose of 64-bit elements */
/* Transpose is in-place (R0-R7); T0-T7 are temporary registers */
#define TRANSPOSE_8X8(R0, R1, R2, R3, R4, R5, R6, R7, \
T0, T1, T2, T3, T4, T5, T6, T7) \
/*
Let's consider the following case:
ZMM(R0) = { 0, 1, 2, 3, 4, 5, 6, 7 }
ZMM(R1) = { 8, 9, 10, 11, 12, 13, 14, 15 }
.
.
.
ZMM(R7) = { 56, 57, 58, 59, 60, 61, 62, 63 }

Expected output:
ZMM(R0) = { 0, 8, 16, 24, 32, 40, 48, 56 }
ZMM(R1) = { 1, 9, 17, 25, 33, 41, 49, 57 }
.
.
.
ZMM(R7) = { 7, 15, 23, 31, 39, 47, 55, 63 }.
*/ \
/* Inputs : ZMM(R0) = { 0, 1, 2, 3, 4, 5, 6, 7 }
ZMM(R1) = { 8, 9, 10, 11, 12, 13, 14, 15 }
ZMM(R2) = { 16, 17, 18, 19, 20, 21, 22, 23 }
ZMM(R3) = { 24, 25, 26, 27, 28, 29, 30, 31 }
...
Outputs: ZMM(T0) = { 0, 8, 2, 10, 4, 12, 6, 14 }
ZMM(R1) = { 1, 9, 3, 11, 5, 13, 7, 15 }
ZMM(T1) = { 16, 24, 18, 26, 20, 28, 22, 30 }
ZMM(R3) = { 17, 25, 19, 27, 21, 29, 23, 31 }
... */ \
VUNPCKLPD(ZMM(R1), ZMM(R0), ZMM(T0)) \
VUNPCKHPD(ZMM(R1), ZMM(R0), ZMM(R1)) \
VUNPCKLPD(ZMM(R3), ZMM(R2), ZMM(T1)) \
VUNPCKHPD(ZMM(R3), ZMM(R2), ZMM(R3)) \
VUNPCKLPD(ZMM(R5), ZMM(R4), ZMM(T2)) \
VUNPCKHPD(ZMM(R5), ZMM(R4), ZMM(R5)) \
VUNPCKLPD(ZMM(R7), ZMM(R6), ZMM(T3)) \
VUNPCKHPD(ZMM(R7), ZMM(R6), ZMM(R7)) \
\
/* Moving the contents of temporary registers
to input registers for reuse */ \
/* Output: ZMM(R0) = { 0, 8, 2, 10, 4, 12, 6, 14 }
ZMM(R2) = { 16, 24, 18, 26, 20, 28, 22, 30 }
ZMM(R4) = { 32, 40, 34, 42, 36, 44, 38, 46 }
ZMM(R6) = { 48, 56, 50, 58, 52, 60, 54, 62 } */ \
VMOVAPD(ZMM(T0), ZMM(R0)) \
VMOVAPD(ZMM(T1), ZMM(R2)) \
VMOVAPD(ZMM(T2), ZMM(R4)) \
VMOVAPD(ZMM(T3), ZMM(R6)) \
\
/* Inputs : ZMM(R0) = { 0, 8, 2, 10, 4, 12, 6, 14 }
ZMM(R2) = { 16, 24, 18, 26, 20, 28, 22, 30 }
ZMM(R4) = { 32, 40, 34, 42, 36, 44, 38, 46 }
ZMM(R6) = { 48, 56, 50, 58, 52, 60, 54, 62 }
Outputs : ZMM(T0) = { 0, 8, 4, 12, 16, 24, 20, 28 }
ZMM(T1) = { 32, 40, 36, 44, 48, 56, 52, 60 }
ZMM(T2) = { 2, 10, 6, 14, 18, 26, 22, 30 }
ZMM(T3) = { 34, 42, 38, 46, 50, 58, 54, 62 } */ \
VSHUFF64X2(IMM(0x88), ZMM(R2), ZMM(R0), ZMM(T0)) \
VSHUFF64X2(IMM(0x88), ZMM(R6), ZMM(R4), ZMM(T1)) \
VSHUFF64X2(IMM(0xDD), ZMM(R2), ZMM(R0), ZMM(T2)) \
VSHUFF64X2(IMM(0xDD), ZMM(R6), ZMM(R4), ZMM(T3)) \
\
/* Inputs : ZMM(R1) = { 1, 9, 3, 11, 5, 13, 7, 15 }
ZMM(R3) = { 17, 25, 19, 27, 21, 29, 23, 31 }
ZMM(R5) = { 33, 41, 35, 43, 37, 45, 39, 47 }
ZMM(R7) = { 49, 57, 51, 59, 53, 61, 55, 63 }
Outputs : ZMM(T4) = { 1, 9, 5, 13, 17, 25, 21, 29 }
ZMM(T5) = { 33, 41, 37, 45, 49, 57, 53, 61 }
ZMM(T6) = { 3, 11, 7, 15, 19, 27, 23, 31 }
ZMM(T7) = { 35, 43, 39, 47, 51, 59, 55, 63 } */ \
VSHUFF64X2(IMM(0x88), ZMM(R3), ZMM(R1), ZMM(T4)) \
VSHUFF64X2(IMM(0x88), ZMM(R7), ZMM(R5), ZMM(T5)) \
VSHUFF64X2(IMM(0xDD), ZMM(R3), ZMM(R1), ZMM(T6)) \
VSHUFF64X2(IMM(0xDD), ZMM(R7), ZMM(R5), ZMM(T7)) \
\
/* Inputs : ZMM(T0) = { 0, 8, 4, 12, 16, 24, 20, 28 }
ZMM(T1) = { 32, 40, 36, 44, 48, 56, 52, 60 }
ZMM(T2) = { 2, 10, 6, 14, 18, 26, 22, 30 }
ZMM(T3) = { 34, 42, 38, 46, 50, 58, 54, 62 }

Outputs : ZMM(R0) = { 0, 8, 16, 24, 32, 40, 48, 56 }
ZMM(R2) = { 2, 10, 18, 26, 34, 42, 50, 58 }
ZMM(R4) = { 4, 12, 20, 28, 36, 44, 52, 60 }
ZMM(R6) = { 6, 14, 22, 30, 38, 46, 54, 62 } */ \
VSHUFF64X2(IMM(0x88), ZMM(T1), ZMM(T0), ZMM(R0)) \
VSHUFF64X2(IMM(0x88), ZMM(T3), ZMM(T2), ZMM(R2)) \
VSHUFF64X2(IMM(0xDD), ZMM(T1), ZMM(T0), ZMM(R4)) \
VSHUFF64X2(IMM(0xDD), ZMM(T3), ZMM(T2), ZMM(R6)) \
\
/* Inputs : ZMM(T4) = { 1, 9, 5, 13, 17, 25, 21, 29 }
ZMM(T5) = { 33, 41, 37, 45, 49, 57, 53, 61 }
ZMM(T6) = { 3, 11, 7, 15, 19, 27, 23, 31 }
ZMM(T7) = { 35, 43, 39, 47, 51, 59, 55, 63 }

Outputs : ZMM(R1) = { 1, 9, 17, 25, 33, 41, 49, 57 }
ZMM(R3) = { 3, 11, 19, 27, 35, 43, 51, 59 }
ZMM(R5) = { 5, 13, 21, 29, 37, 45, 53, 61 }
ZMM(R7) = { 7, 15, 23, 31, 39, 47, 55, 63 } */ \
VSHUFF64X2(IMM(0x88), ZMM(T5), ZMM(T4), ZMM(R1)) \
VSHUFF64X2(IMM(0x88), ZMM(T7), ZMM(T6), ZMM(R3)) \
VSHUFF64X2(IMM(0xDD), ZMM(T5), ZMM(T4), ZMM(R5)) \
VSHUFF64X2(IMM(0xDD), ZMM(T7), ZMM(T6), ZMM(R7)) \

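/* For reference, a scalar model of TRANSPOSE_8X8 (an illustrative sketch,
guarded out of compilation): each scomplex occupies one 64-bit lane, so
the macro amounts to an in-place 8x8 transpose of 64-bit elements. The
assembly achieves it in three shuffle stages (VUNPCK*PD on 64-bit pairs,
then two rounds of VSHUFF64X2 on 128-bit lane groups) instead of scalar
swaps. */
#if 0
static void transpose_8x8_model( uint64_t m[8][8] )
{
    for ( int i = 0; i < 8; ++i )
        for ( int j = i + 1; j < 8; ++j )
        {
            uint64_t t = m[i][j];
            m[i][j]    = m[j][i];
            m[j][i]    = t;
        }
}
#endif
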
/* Macro to scale a register with beta, with secondary storage */
/* Macro uses C1 to load C, and R1 contains the result of alpha*A*B */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of beta already broadcasted */
#define BETA_DEFAULT_SECONDARY(R1, C1, T1) \
/* Load C onto the register */ \
VMOVUPS(MEM(RCX), YMM(C1)) \
\
/* Scale C by beta */ \
/* Assume YMM(C1) = { Cr0, Ci0, ... } */ \
/* YMM(C1) = { Cr0.betar - Ci0.betai, Ci0.betar + Cr0.betai, ... } */ \
VMULPS(YMM(2), YMM(C1), YMM(T1)) \
VMULPS(YMM(1), YMM(C1), YMM(C1)) \
VPERMILPS(IMM(0xB1), YMM(T1), YMM(T1)) \
VFMADDSUB132PS(YMM(0), YMM(T1), YMM(C1)) \
\
/* Add beta*C to the result of alpha*A*B */ \
VADDPS(YMM(C1), YMM(R1), YMM(C1)) /* YMM(C1) = YMM(R1) + YMM(C1) */ \
\
/* Store the result onto C */ \
VMOVUPS(YMM(C1), MEM(RCX)) \

/* Macro to scale a register with beta, with general strides */
/* Macro gets alpha*A*B in R1, R2, R3 for a column */
/* Macro uses ZMM(15) - ZMM(17) to gather C */
/* Macro uses ZMM(18) - ZMM(20) as temporary registers */
/* Macro assumes that ZMM(28) - ZMM(30) have the addresses of C for
gather/scatter (one column at a time) */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of beta already broadcasted */
#define BETA_DEFAULT_GENERAL(R1, R2, R3) \
/* Set the masks to 1 */ \
KXNORW(K(0), K(0), K(1)) \
KXNORW(K(0), K(0), K(2)) \
KXNORW(K(0), K(0), K(3)) \
\
/* Gather elements from C, one column at a time */ \
VGATHERQPD(MEM(RCX, ZMM(28), 1), ZMM(15) MASK_K(1)) \
VGATHERQPD(MEM(RCX, ZMM(29), 1), ZMM(16) MASK_K(2)) \
VGATHERQPD(MEM(RCX, ZMM(30), 1), ZMM(17) MASK_K(3)) \
\
/* Scale C by beta */ \
/* Assume ZMM(15) = { Cr0, Ci0, ... } */ \
/* ZMM(15) = { Cr0.betar - Ci0.betai, Ci0.betar + Cr0.betai, ... } */ \
VMULPS(ZMM(2), ZMM(15), ZMM(18)) \
VMULPS(ZMM(1), ZMM(15), ZMM(15)) \
VPERMILPS(IMM(0xB1), ZMM(18), ZMM(18)) \
VFMADDSUB132PS(ZMM(0), ZMM(18), ZMM(15)) \
\
/* Scale C by beta */ \
/* Assume ZMM(16) = { Cr1, Ci1, ... } */ \
/* ZMM(16) = { Cr1.betar - Ci1.betai, Ci1.betar + Cr1.betai, ... } */ \
VMULPS(ZMM(2), ZMM(16), ZMM(19)) \
VMULPS(ZMM(1), ZMM(16), ZMM(16)) \
VPERMILPS(IMM(0xB1), ZMM(19), ZMM(19)) \
VFMADDSUB132PS(ZMM(0), ZMM(19), ZMM(16)) \
\
/* Scale C by beta */ \
/* Assume ZMM(17) = { Cr2, Ci2, ... } */ \
/* ZMM(17) = { Cr2.betar - Ci2.betai, Ci2.betar + Cr2.betai, ... } */ \
VMULPS(ZMM(2), ZMM(17), ZMM(20)) \
VMULPS(ZMM(1), ZMM(17), ZMM(17)) \
VPERMILPS(IMM(0xB1), ZMM(20), ZMM(20)) \
VFMADDSUB132PS(ZMM(0), ZMM(20), ZMM(17)) \
\
/* Add beta*C to the result of alpha*A*B */ \
VADDPS(ZMM(15), ZMM(R1), ZMM(15)) /* ZMM(15) = ZMM(R1) + ZMM(15) */ \
VADDPS(ZMM(16), ZMM(R2), ZMM(16)) /* ZMM(16) = ZMM(R2) + ZMM(16) */ \
VADDPS(ZMM(17), ZMM(R3), ZMM(17)) /* ZMM(17) = ZMM(R3) + ZMM(17) */ \
\
/* Reset the masks to 1 */ \
KXNORW(K(0), K(0), K(1)) \
KXNORW(K(0), K(0), K(2)) \
KXNORW(K(0), K(0), K(3)) \
\
/* Scatter the result to C, one column at a time */ \
VSCATTERQPD(ZMM(15), MEM(RCX, ZMM(28), 1) MASK_K(1)) \
VSCATTERQPD(ZMM(16), MEM(RCX, ZMM(29), 1) MASK_K(2)) \
VSCATTERQPD(ZMM(17), MEM(RCX, ZMM(30), 1) MASK_K(3)) \

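/* For reference, a scalar equivalent of the gather/scatter beta path
above, with C accessed at a general element stride (here called rs_c,
in scomplex units, purely for illustration); a sketch guarded out of
compilation. */
#if 0
static void beta_general_model( scomplex* c, int rs_c, const scomplex* ab,
                                scomplex beta, int m )
{
    for ( int i = 0; i < m; ++i )
    {
        scomplex* cij = &c[ i * rs_c ];              /* the "gather" step */
        float cr = cij->real * beta.real - cij->imag * beta.imag;
        float ci = cij->imag * beta.real + cij->real * beta.imag;
        cij->real = ab[i].real + cr;
        cij->imag = ab[i].imag + ci;                 /* the "scatter" step */
    }
}
#endif
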
/* Macro to store alpha*A*B onto C, with general strides */
/* Macro gets alpha*A*B in R1, R2, R3 for a column */
/* Macro assumes that ZMM(28) - ZMM(30) have the addresses of C for
scatter (one column at a time) */
#define BETA_ZERO_GENERAL(R1, R2, R3) \
/* Set the masks to 1 */ \
KXNORW(K(0), K(0), K(1)) \
KXNORW(K(0), K(0), K(2)) \
KXNORW(K(0), K(0), K(3)) \
\
/* Scatter the result to C, one column at a time */ \
VSCATTERQPD(ZMM(R1), MEM(RCX, ZMM(28), 1) MASK_K(1)) \
VSCATTERQPD(ZMM(R2), MEM(RCX, ZMM(29), 1) MASK_K(2)) \
VSCATTERQPD(ZMM(R3), MEM(RCX, ZMM(30), 1) MASK_K(3))

@@ -168,6 +168,8 @@ PACKM_KER_PROT( double, d, packm_zen4_asm_8xk )
PACKM_KER_PROT( double, d, packm_zen4_asm_24xk )
PACKM_KER_PROT( double, d, packm_zen4_asm_32xk )
PACKM_KER_PROT( double, d, packm_32xk_zen4_ref )
PACKM_KER_PROT( scomplex, c, packm_zen4_asm_24xk )
PACKM_KER_PROT( scomplex, c, packm_zen4_asm_4xk )
PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_12xk )
PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_4xk )

@@ -176,6 +178,8 @@ GEMM_UKR_PROT( double, d, gemm_avx512_asm_8x24 )
GEMM_UKR_PROT( double, d, gemm_zen4_asm_32x6 )
GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_12x4 )
GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_4x12 )
GEMM_UKR_PROT( scomplex, c, gemm_zen4_asm_24x4 )
GEMM_UKR_PROT( scomplex, c, gemm_zen4_asm_4x24 )

// dgemm native macro kernel
void bli_dgemm_avx512_asm_8x24_macro_kernel