AVX512 optimizations for CGEMM (Native)

- Implemented the following AVX512 native
  computational kernels for CGEMM:
  Row-preferential    : 4x24
  Column-preferential : 24x4

- The implementations share a common set of macros,
  defined in a separate header, since the two kernels
  differ only in which matrix is chosen for the
  load/broadcast operations.

- Added the associated AVX512 based packing kernels,
  packing 24xk and 4xk panels of input.

- Registered the column-preferential kernel (24x4) in the
  ZEN4 and ZEN5 contexts. Also updated the cache-blocking
  parameters.

- Removed redundant BLIS object creation and its contingencies
  in the native micro-kernel testing interface (for complex types).
  Added the required unit tests for memory and functionality
  checks of the new kernels.

AMD-Internal: [CPUPL-6498]
Change-Id: I520ff17dba4c2f9bc277bf33ba9ab4384408ffe1
Vignesh Balasubramanian
2025-02-18 11:15:53 +05:30
committed by Vignesh Balasubramanian
parent 6c29236166
commit 99770558bb
11 changed files with 3491 additions and 30 deletions


@@ -0,0 +1,845 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
/**
* Shuffle pairs of scomplex elements (128-bit lanes) selected by imm8
* from S1 and S2, and store the results in D1 and D2 (and likewise
* from S3/S4 into D3/D4), e.g. :
* S1 : 1 9 3 11 5 13 7 15
* S2 : 2 10 4 12 6 14 8 16
* D1 : 1 9 5 13 2 10 6 14
* D2 : 3 11 7 15 4 12 8 16
*/
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
/**
* Unpack and interleave the low halves and high halves of each
* 128-bit lane in S1 and S2, storing the results in D1 and D2
* respectively (and likewise from S3/S4 into D3/D4).
* S1 : 1 2 3 4 5 6 7 8
* S2 : 9 10 11 12 13 14 15 16
* D1 : 1 9 3 11 5 13 7 15
* D2 : 2 10 4 12 6 14 8 16
*/
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd(zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd(zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd(zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd(zmm(S3), zmm(S4), zmm(D4))
void bli_cpackm_zen4_asm_24xk
(
conj_t conja,
pack_t schema,
dim_t cdim0,
dim_t k0,
dim_t k0_max,
scomplex* restrict kappa,
scomplex* restrict a, inc_t inca0, inc_t lda0,
scomplex* restrict p, inc_t ldp0,
cntx_t* restrict cntx
)
{
// This is the panel dimension assumed by the packm kernel.
const dim_t mnr = 24;
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
// NOTE : k_iter is in blocks of 8, since a 512-bit register holds
// 8 scomplex elements. This way, we can still perform full AVX512
// loads and stores when the matrix is in row-major format.
const uint64_t k_iter = k0 / 8;
const uint64_t k_left = k0 % 8;
/**
* Preparing the mask for k_left, since we are computing in blocks of 8.
* For the edge cases, mask is set to load and store only the leftover elements.
*/
uint16_t one = 1;
uint16_t mask = ( one << ( 2 * k_left ) ) - one;
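/*
 * Illustrative example (not part of the kernel): one 512-bit register holds
 * 8 scomplex elements = 16 float lanes, so the mask covers 2*k_left lanes :
 *   k_left = 3  ->  mask = ( 1 << 6 ) - 1  = 0x003F (3 scomplex, 6 float lanes)
 *   k_left = 7  ->  mask = ( 1 << 14 ) - 1 = 0x3FFF (7 scomplex, 14 float lanes)
 *   k_left = 0  ->  mask = 0x0000 (the edge loop is skipped entirely)
 */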
// NOTE: For the purposes of the comments in this packm kernel, we
// interpret inca and lda as rs_a and cs_a, respectively, and similarly
// interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading
// this packm kernel, you should think of the operation as packing an
// m x n micropanel, where m and n are tiny and large, respectively, and
// where elements of each column of the packed matrix P are contiguous.
// (This packm kernel can still be used to pack micropanels of matrix B
// in a gemm operation.)
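/*
 * Illustrative scalar reference (for exposition only, not part of the kernel)
 * of what the fast path below produces for the unit-kappa, non-conjugated
 * case, using the inca/lda/ldp interpretation from the NOTE above and rs_p = 1 :
 *
 *   for ( dim_t j = 0; j < k0; ++j )
 *     for ( dim_t i = 0; i < 24; ++i )
 *       p[ i + j * ldp ] = a[ i * inca + j * lda ];
 */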
const uint64_t inca = inca0;
const uint64_t lda = lda0;
const uint64_t ldp = ldp0;
const bool gs = ( inca0 != 1 && lda0 != 1 );
// NOTE: If/when this kernel ever supports scaling by kappa within the
// assembly region, this constraint should be lifted.
const bool unitk = bli_ceq1( *kappa );
// -------------------------------------------------------------------------
if ( cdim0 == mnr && !gs && !conja && unitk )
{
begin_asm()
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(a), rax) // load address of source buffer
mov(var(a), r13) // load address of source buffer
mov(var(inca), r8) // load inca
mov(var(lda), r10) // load lda
lea(mem( , r8, 8), r8) // inca *= sizeof(scomplex)
lea(mem( , r10, 8), r10) // lda *= sizeof(scomplex)
mov(var(p), rbx) // load address of p.
lea(mem( , r8, 8), r14) // r14 = 8*inca
cmp(imm(8), r8) // set ZF if (8*inca) == 8.
jz(.CCOLUNIT) // jump to column storage case
// -- kappa unit, row storage on A -------------------------------
label(.CROWUNIT)
lea(mem(r8, r8, 2), r12) // r12 = 3*inca
lea(mem(r12, r8, 2), rcx) // rcx = 5*inca
lea(mem(r12, r8, 4), rdx) // rdx = 7*inca
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.CCONKLEFTROWU) // if i == 0, jump to code that
// contains the k_left loop.
label(.CKITERROWU) // MAIN LOOP (k_iter)
/* The 24x8 block in every iteration is broken down into 3
sets of 8x8 packing in every iteration of the main loop */
/* Load first 8 rows of matrix */
vmovups(mem(rax, 0), zmm6)
vmovups(mem(rax, r8, 1, 0), zmm8)
vmovups(mem(rax, r8, 2, 0), zmm10)
vmovups(mem(rax, r12, 1, 0), zmm12)
vmovups(mem(rax, r8, 4, 0), zmm14)
vmovups(mem(rax, rcx, 1, 0), zmm16)
vmovups(mem(rax, r12, 2, 0), zmm18)
vmovups(mem(rax, rdx, 1, 0), zmm20)
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row1
zmm8 --> row2
zmm10 --> row3
zmm12 --> row4
zmm14 --> row5
zmm16 --> row6
zmm18 --> row7
zmm20 --> row8
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col1 col2 col3 col4 col5 col6 col7 col8
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
/* Store the 8 rows post transpose */
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
vmovups(zmm2, mem(rbx, 2*192))
vmovups(zmm6, mem(rbx, 3*192))
vmovups(zmm1, mem(rbx, 4*192))
vmovups(zmm5, mem(rbx, 5*192))
vmovups(zmm3, mem(rbx, 6*192))
vmovups(zmm8, mem(rbx, 7*192))
add(r14, rax) // a += 8*inca;
/* Load another 8 rows of matrix */
vmovups(mem(rax, 0), zmm6)
vmovups(mem(rax, r8, 1, 0), zmm8)
vmovups(mem(rax, r8, 2, 0), zmm10)
vmovups(mem(rax, r12, 1, 0), zmm12)
vmovups(mem(rax, r8, 4, 0), zmm14)
vmovups(mem(rax, rcx, 1, 0), zmm16)
vmovups(mem(rax, r12, 2, 0), zmm18)
vmovups(mem(rax, rdx, 1, 0), zmm20)
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row9
zmm8 --> row10
zmm10 --> row11
zmm12 --> row12
zmm14 --> row13
zmm16 --> row14
zmm18 --> row15
zmm20 --> row16
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col9 col10 col11 col12 col13 col14 col15 col16
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
/* Store the 8 rows post transpose */
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
vmovups(zmm2, mem(rbx, 2*192 + 64))
vmovups(zmm6, mem(rbx, 3*192 + 64))
vmovups(zmm1, mem(rbx, 4*192 + 64))
vmovups(zmm5, mem(rbx, 5*192 + 64))
vmovups(zmm3, mem(rbx, 6*192 + 64))
vmovups(zmm8, mem(rbx, 7*192 + 64))
add(r14, rax) // a += 8*inca;
/* Load the final 8 rows of matrix */
vmovups(mem(rax, 0), zmm6)
vmovups(mem(rax, r8, 1, 0), zmm8)
vmovups(mem(rax, r8, 2, 0), zmm10)
vmovups(mem(rax, r12, 1, 0), zmm12)
vmovups(mem(rax, r8, 4, 0), zmm14)
vmovups(mem(rax, rcx, 1, 0), zmm16)
vmovups(mem(rax, r12, 2, 0), zmm18)
vmovups(mem(rax, rdx, 1, 0), zmm20)
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row17
zmm8 --> row18
zmm10 --> row19
zmm12 --> row20
zmm14 --> row21
zmm16 --> row22
zmm18 --> row23
zmm20 --> row24
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col17 col18 col19 col20 col21 col22 col23 col24
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
/* Store the 8 rows post transpose */
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
vmovups(zmm2, mem(rbx, 2*192 + 128))
vmovups(zmm6, mem(rbx, 3*192 + 128))
vmovups(zmm1, mem(rbx, 4*192 + 128))
vmovups(zmm5, mem(rbx, 5*192 + 128))
vmovups(zmm3, mem(rbx, 6*192 + 128))
vmovups(zmm8, mem(rbx, 7*192 + 128))
add(imm(8*8), r13)
mov(r13, rax) // a += 8*8*lda
add(imm(8*8*24), rbx) // p += 8*ldp
dec(rsi) // i -= 1;
jne(.CKITERROWU) // iterate again if i != 0.
label(.CCONKLEFTROWU)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.CDONE) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.CKLEFTROWU) // EDGE LOOP (k_left)
LABEL(.UPDATEL1)
/* Move the first 8xk_left block of data */
vmovups(mem(rax, 0), zmm6 MASK_KZ(2))
vmovups(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmovups(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2))
vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
vmovups(mem(rax, r8, 4, 0), zmm14 MASK_KZ(2))
vmovups(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2))
vmovups(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2))
vmovups(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2))
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row1(masked loads)
zmm8 --> row2(masked loads)
zmm10 --> row3(masked loads)
zmm12 --> row4(masked loads)
zmm14 --> row5(masked loads)
zmm16 --> row6(masked loads)
zmm18 --> row7(masked loads)
zmm20 --> row8(masked loads)
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col1 col2 col3 col4 col5 col6 col7 col8
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
add(r14, rax) // a += 8*inca
cmp(imm(7), rsi)
JZ(.UPDATE7L1)
cmp(imm(6), rsi)
JZ(.UPDATE6L1)
cmp(imm(5), rsi)
JZ(.UPDATE5L1)
cmp(imm(4), rsi)
JZ(.UPDATE4L1)
cmp(imm(3), rsi)
JZ(.UPDATE3L1)
cmp(imm(2), rsi)
JZ(.UPDATE2L1)
cmp(imm(1), rsi)
JZ(.UPDATE1L1)
LABEL(.UPDATE7L1)
// Update 8x7 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
vmovups(zmm2, mem(rbx, 2*192))
vmovups(zmm6, mem(rbx, 3*192))
vmovups(zmm1, mem(rbx, 4*192))
vmovups(zmm5, mem(rbx, 5*192))
vmovups(zmm3, mem(rbx, 6*192))
jmp(.UPDATEL2)
LABEL(.UPDATE6L1)
// Update 8x6 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
vmovups(zmm2, mem(rbx, 2*192))
vmovups(zmm6, mem(rbx, 3*192))
vmovups(zmm1, mem(rbx, 4*192))
vmovups(zmm5, mem(rbx, 5*192))
jmp(.UPDATEL2)
LABEL(.UPDATE5L1)
// Update 8x5 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
vmovups(zmm2, mem(rbx, 2*192))
vmovups(zmm6, mem(rbx, 3*192))
vmovups(zmm1, mem(rbx, 4*192))
jmp(.UPDATEL2)
LABEL(.UPDATE4L1)
// Update 8x4 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
vmovups(zmm2, mem(rbx, 2*192))
vmovups(zmm6, mem(rbx, 3*192))
jmp(.UPDATEL2)
LABEL(.UPDATE3L1)
// Update 8x3 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
vmovups(zmm2, mem(rbx, 2*192))
jmp(.UPDATEL2)
LABEL(.UPDATE2L1)
// Update 8x2 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
vmovups(zmm4, mem(rbx, 1*192))
jmp(.UPDATEL2)
LABEL(.UPDATE1L1)
// Update 8x1 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192))
jmp(.UPDATEL2)
LABEL(.UPDATEL2)
/* Move the next 8xk_left block of data */
vmovups(mem(rax, 0), zmm6 MASK_KZ(2))
vmovups(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmovups(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2))
vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
vmovups(mem(rax, r8, 4, 0), zmm14 MASK_KZ(2))
vmovups(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2))
vmovups(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2))
vmovups(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2))
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row9(masked loads)
zmm8 --> row10(masked loads)
zmm10 --> row11(masked loads)
zmm12 --> row12(masked loads)
zmm14 --> row13(masked loads)
zmm16 --> row14(masked loads)
zmm18 --> row15(masked loads)
zmm20 --> row16(masked loads)
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col9 col10 col11 col12 col13 col14 col15 col16
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
add(r14, rax) // a += 8*inca
cmp(imm(7), rsi)
JZ(.UPDATE7L2)
cmp(imm(6), rsi)
JZ(.UPDATE6L2)
cmp(imm(5), rsi)
JZ(.UPDATE5L2)
cmp(imm(4), rsi)
JZ(.UPDATE4L2)
cmp(imm(3), rsi)
JZ(.UPDATE3L2)
cmp(imm(2), rsi)
JZ(.UPDATE2L2)
cmp(imm(1), rsi)
JZ(.UPDATE1L2)
LABEL(.UPDATE7L2)
// Update 8x7 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
vmovups(zmm2, mem(rbx, 2*192 + 64))
vmovups(zmm6, mem(rbx, 3*192 + 64))
vmovups(zmm1, mem(rbx, 4*192 + 64))
vmovups(zmm5, mem(rbx, 5*192 + 64))
vmovups(zmm3, mem(rbx, 6*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATE6L2)
// Update 8x6 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
vmovups(zmm2, mem(rbx, 2*192 + 64))
vmovups(zmm6, mem(rbx, 3*192 + 64))
vmovups(zmm1, mem(rbx, 4*192 + 64))
vmovups(zmm5, mem(rbx, 5*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATE5L2)
// Update 8x5 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
vmovups(zmm2, mem(rbx, 2*192 + 64))
vmovups(zmm6, mem(rbx, 3*192 + 64))
vmovups(zmm1, mem(rbx, 4*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATE4L2)
// Update 8x4 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
vmovups(zmm2, mem(rbx, 2*192 + 64))
vmovups(zmm6, mem(rbx, 3*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATE3L2)
// Update 8x3 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
vmovups(zmm2, mem(rbx, 2*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATE2L2)
// Update 8x2 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
vmovups(zmm4, mem(rbx, 1*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATE1L2)
// Update 8x1 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 64))
jmp(.UPDATEL3)
LABEL(.UPDATEL3)
/* Move the next 8xk_left block of data */
vmovups(mem(rax, 0), zmm6 MASK_KZ(2))
vmovups(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmovups(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2))
vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
vmovups(mem(rax, r8, 4, 0), zmm14 MASK_KZ(2))
vmovups(mem(rax, rcx, 1, 0), zmm16 MASK_KZ(2))
vmovups(mem(rax, r12, 2, 0), zmm18 MASK_KZ(2))
vmovups(mem(rax, rdx, 1, 0), zmm20 MASK_KZ(2))
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row17(masked loads)
zmm8 --> row18(masked loads)
zmm10 --> row19(masked loads)
zmm12 --> row20(masked loads)
zmm14 --> row21(masked loads)
zmm16 --> row22(masked loads)
zmm18 --> row23(masked loads)
zmm20 --> row24(masked loads)
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col17 col18 col19 col20 col21 col22 col23 col24
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
cmp(imm(7), rsi)
JZ(.UPDATE7L3)
cmp(imm(6), rsi)
JZ(.UPDATE6L3)
cmp(imm(5), rsi)
JZ(.UPDATE5L3)
cmp(imm(4), rsi)
JZ(.UPDATE4L3)
cmp(imm(3), rsi)
JZ(.UPDATE3L3)
cmp(imm(2), rsi)
JZ(.UPDATE2L3)
cmp(imm(1), rsi)
JZ(.UPDATE1L3)
LABEL(.UPDATE7L3)
// Update 8x7 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
vmovups(zmm2, mem(rbx, 2*192 + 128))
vmovups(zmm6, mem(rbx, 3*192 + 128))
vmovups(zmm1, mem(rbx, 4*192 + 128))
vmovups(zmm5, mem(rbx, 5*192 + 128))
vmovups(zmm3, mem(rbx, 6*192 + 128))
jmp(.CDONE)
LABEL(.UPDATE6L3)
// Update 8x6 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
vmovups(zmm2, mem(rbx, 2*192 + 128))
vmovups(zmm6, mem(rbx, 3*192 + 128))
vmovups(zmm1, mem(rbx, 4*192 + 128))
vmovups(zmm5, mem(rbx, 5*192 + 128))
jmp(.CDONE)
LABEL(.UPDATE5L3)
// Update 8x5 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
vmovups(zmm2, mem(rbx, 2*192 + 128))
vmovups(zmm6, mem(rbx, 3*192 + 128))
vmovups(zmm1, mem(rbx, 4*192 + 128))
jmp(.CDONE)
LABEL(.UPDATE4L3)
// Update 8x4 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
vmovups(zmm2, mem(rbx, 2*192 + 128))
vmovups(zmm6, mem(rbx, 3*192 + 128))
jmp(.CDONE)
LABEL(.UPDATE3L3)
// Update 8x3 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
vmovups(zmm2, mem(rbx, 2*192 + 128))
jmp(.CDONE)
LABEL(.UPDATE2L3)
// Update 8x2 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
vmovups(zmm4, mem(rbx, 1*192 + 128))
jmp(.CDONE)
LABEL(.UPDATE1L3)
// Update 8x1 tile to destination buffer
vmovups(zmm0, mem(rbx, 0*192 + 128))
jmp(.CDONE)
// -- column storage on A --------------------------------------
label(.CCOLUNIT)
mov(var(ldp), r8) // load ldp
lea(mem(, r8, 8), r8) // r8 *= sizeof(scomplex)
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.CCONKLEFTCOLU) // if i == 0, jump to code that
// contains the k_left loop.
label(.CKITERCOLU) // MAIN LOOP (k_iter)
/* Load and store 3 ZMM registers(24 elements) */
/* Unroll-1 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-2 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-3 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-4 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-5 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-6 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-7 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
/* Unroll-8 */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
dec(rsi) // i -= 1;
jne(.CKITERCOLU) // iterate again if i != 0.
label(.CCONKLEFTCOLU)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.CDONE) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.CKLEFTCOLU) // EDGE LOOP (k_left)
/* Load and store 3 ZMM registers(24 elements) */
vmovups(mem(rax), zmm6)
vmovups(mem(rax, 64), zmm8)
vmovups(mem(rax, 128), zmm10)
vmovups(zmm6, mem(rbx, 0))
vmovups(zmm8, mem(rbx, 64))
vmovups(zmm10, mem(rbx, 128))
add(r10, rax)
add(r8, rbx)
dec(rsi) // i -= 1;
jne(.CKLEFTCOLU) // iterate again if i != 0.
label(.CDONE)
end_asm(
: // output operands (none)
: // input operands
[mask] "m" (mask),
[k_iter] "m" (k_iter),
[k_left] "m" (k_left),
[a] "m" (a),
[inca] "m" (inca),
[lda] "m" (lda),
[p] "m" (p),
[ldp] "m" (ldp)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi",
"r8", "r10", "r12", "r13", "r14",
"xmm0", "xmm1", "xmm2", "xmm3",
"zmm0", "zmm1", "zmm2", "zmm3",
"zmm4", "zmm5", "zmm6", "zmm7",
"zmm8", "zmm10", "zmm12", "zmm14",
"zmm16", "zmm18", "zmm20", "zmm30",
"zmm31", "k2", "memory"
)
}
else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk )
{
PASTEMAC(cscal2m,BLIS_TAPI_EX_SUF)
(
0,
BLIS_NONUNIT_DIAG,
BLIS_DENSE,
( trans_t )conja,
cdim0,
k0,
kappa,
a, inca0, lda0,
p, 1, ldp0,
cntx,
NULL
);
if ( cdim0 < mnr )
{
// Handle zero-filling along the "long" edge of the micropanel.
const dim_t i = cdim0;
const dim_t m_edge = mnr - cdim0;
const dim_t n_edge = k0_max;
scomplex* restrict p_edge = p + (i )*1;
bli_cset0s_mxn
(
m_edge,
n_edge,
p_edge, 1, ldp
);
}
}
if ( k0 < k0_max )
{
// Handle zero-filling along the "short" (far) edge of the micropanel.
const dim_t j = k0;
const dim_t m_edge = mnr;
const dim_t n_edge = k0_max - k0;
scomplex* restrict p_edge = p + (j )*ldp;
bli_cset0s_mxn
(
m_edge,
n_edge,
p_edge, 1, ldp
);
}
}


@@ -0,0 +1,498 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
/**
* Shuffle pairs of scomplex elements (128-bit lanes) selected by imm8
* from S1 and S2, and store the results in D1 and D2 (and likewise
* from S3/S4 into D3/D4), e.g. :
* S1 : 1 9 3 11 5 13 7 15
* S2 : 2 10 4 12 6 14 8 16
* D1 : 1 9 5 13 2 10 6 14
* D2 : 3 11 7 15 4 12 8 16
*/
#define SHUFFLE_DATA(S1, S2, D1, D2, S3, S4, D3, D4) \
\
VSHUFF64X2(IMM(0x88), ZMM(S1), ZMM(S2), ZMM(D1)) \
VSHUFF64X2(IMM(0xDD), ZMM(S1), ZMM(S2), ZMM(D2)) \
VSHUFF64X2(IMM(0x88), ZMM(S3), ZMM(S4), ZMM(D3)) \
VSHUFF64X2(IMM(0xDD), ZMM(S3), ZMM(S4), ZMM(D4)) \
/**
* Unpack and interleave the low halves and high halves of each
* 128-bit lane in S1 and S2, storing the results in D1 and D2
* respectively (and likewise from S3/S4 into D3/D4).
* S1 : 1 2 3 4 5 6 7 8
* S2 : 9 10 11 12 13 14 15 16
* D1 : 1 9 3 11 5 13 7 15
* D2 : 2 10 4 12 6 14 8 16
*/
#define UNPACK_LO_HIGH(S1, S2, D1, D2, S3, S4, D3, D4) \
\
vunpcklpd(zmm(S1), zmm(S2), zmm(D1)) \
vunpckhpd(zmm(S1), zmm(S2), zmm(D2)) \
vunpcklpd(zmm(S3), zmm(S4), zmm(D3)) \
vunpckhpd(zmm(S3), zmm(S4), zmm(D4))
void bli_cpackm_zen4_asm_4xk
(
conj_t conja,
pack_t schema,
dim_t cdim0,
dim_t k0,
dim_t k0_max,
scomplex* restrict kappa,
scomplex* restrict a, inc_t inca0, inc_t lda0,
scomplex* restrict p, inc_t ldp0,
cntx_t* restrict cntx
)
{
// This is the panel dimension assumed by the packm kernel.
const dim_t mnr = 4;
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
// NOTE : k_iter is in blocks of 8, since a 512-bit register holds
// 8 scomplex elements. This way, we can still perform full AVX512
// loads and stores when the matrix is in row-major format.
const uint64_t k_iter = k0 / 8;
const uint64_t k_left = k0 % 8;
/**
* Preparing the mask for k_left, since we are computing in blocks of 8.
* For the edge cases, mask is set to load and store only the leftover elements.
*/
uint16_t one = 1;
uint16_t mask = ( one << ( 2 * k_left ) ) - one;
// NOTE: For the purposes of the comments in this packm kernel, we
// interpret inca and lda as rs_a and cs_a, respectively, and similarly
// interpret ldp as cs_p (with rs_p implicitly unit). Thus, when reading
// this packm kernel, you should think of the operation as packing an
// m x n micropanel, where m and n are tiny and large, respectively, and
// where elements of each column of the packed matrix P are contiguous.
// (This packm kernel can still be used to pack micropanels of matrix B
// in a gemm operation.)
const uint64_t inca = inca0;
const uint64_t lda = lda0;
const uint64_t ldp = ldp0;
const bool gs = ( inca0 != 1 && lda0 != 1 );
// NOTE: If/when this kernel ever supports scaling by kappa within the
// assembly region, this constraint should be lifted.
const bool unitk = bli_ceq1( *kappa );
// -------------------------------------------------------------------------
if ( cdim0 == mnr && !gs && !conja && unitk )
{
begin_asm()
mov(var(mask), rdx) // load mask
kmovw(edx, k(2)) // move mask to k2 register
mov(var(a), rax) // load address of source buffer
mov(var(a), r13) // load address of source buffer
mov(var(inca), r8) // load inca
mov(var(lda), r10) // load lda
lea(mem( , r8, 8), r8) // inca *= sizeof(scomplex)
lea(mem( , r10, 8), r10) // lda *= sizeof(scomplex)
mov(var(p), rbx) // load address of p.
lea(mem( , r8, 8), r14) // r14 = 8*inca
cmp(imm(8), r8) // set ZF if (8*inca) == 8.
jz(.CCOLUNIT) // jump to column storage case
// -- kappa unit, row storage on A -------------------------------
label(.CROWUNIT)
lea(mem(r8, r8, 2), r12) // r12 = 3*inca
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.CCONKLEFTROWU) // if i == 0, jump to code that
// contains the k_left loop.
label(.CKITERROWU) // MAIN LOOP (k_iter)
/**
* Load the first 4 rows of the matrix and set 4 additional
* registers to zero. Transpose the resulting 8x8 tile (the 4x8
* block extended with 0.0 padding) and store it back to the
* destination buffer.
*/
vmovups(mem(rax, 0), zmm6)
vmovups(mem(rax, r8, 1, 0), zmm8)
vmovups(mem(rax, r8, 2, 0), zmm10)
vmovups(mem(rax, r12, 1, 0), zmm12)
vxorps(zmm14, zmm14, zmm14)
vxorps(zmm16, zmm16, zmm16)
vxorps(zmm18, zmm18, zmm18)
vxorps(zmm20, zmm20, zmm20)
/* Transpose the 8x8 matrix onto another set of 8 registers */
/*
Input :
zmm6 --> row1
zmm8 --> row2
zmm10 --> row3
zmm12 --> row4
zmm14 --> 0.0f
zmm16 --> 0.0f
zmm18 --> 0.0f
zmm20 --> 0.0f
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col1 col2 col3 col4 col5 col6 col7 col8
Every column(register) will have the last 256-bit lane as 0.0f
Thus, we only store the YMM registers to the destination buffer.
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
/* Store the 256-bit lanes(YMM) onto the destination */
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
vmovups(ymm2, mem(rbx, 2*32))
vmovups(ymm6, mem(rbx, 3*32))
vmovups(ymm1, mem(rbx, 4*32))
vmovups(ymm5, mem(rbx, 5*32))
vmovups(ymm3, mem(rbx, 6*32))
vmovups(ymm8, mem(rbx, 7*32))
add(imm(8*8), r13)
mov(r13, rax) // a += 8*8*lda
add(imm(8*8*4), rbx) // p += 8*ldp
dec(rsi) // i -= 1;
jne(.CKITERROWU) // iterate again if i != 0.
label(.CCONKLEFTROWU)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.CDONE) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.CKLEFTROWU) // EDGE LOOP (k_left)
LABEL(.UPDATEL1)
/* Move the first 4xk_left block of data */
vmovups(mem(rax, 0), zmm6 MASK_KZ(2))
vmovups(mem(rax, r8, 1, 0), zmm8 MASK_KZ(2))
vmovups(mem(rax, r8, 2, 0), zmm10 MASK_KZ(2))
vmovups(mem(rax, r12, 1, 0), zmm12 MASK_KZ(2))
vxorps(zmm14, zmm14, zmm14)
vxorps(zmm16, zmm16, zmm16)
vxorps(zmm18, zmm18, zmm18)
vxorps(zmm20, zmm20, zmm20)
/*
Input :
zmm6 --> row1(masked loads)
zmm8 --> row2(masked loads)
zmm10 --> row3(masked loads)
zmm12 --> row4(masked loads)
zmm14 --> 0.0f
zmm16 --> 0.0f
zmm18 --> 0.0f
zmm20 --> 0.0f
Output(after transpose) :
zmm0 zmm4 zmm2 zmm6 zmm1 zmm5 zmm3 zmm8
| | | | | | | |
| | | | | | | |
V V V V V V V V
col1 col2 col3 col4 col5 col6 col7 col8
Every column(register) will have the last 256-bit lane as 0.0f
Thus, we only store the YMM registers to the destination buffer.
*/
UNPACK_LO_HIGH(8, 6, 0, 1, 12, 10, 2, 3)
SHUFFLE_DATA(2, 0, 4, 5, 3, 1, 30, 31)
UNPACK_LO_HIGH(16, 14, 0, 1, 20, 18, 2, 3)
SHUFFLE_DATA(2, 0, 6, 8, 3, 1, 10, 12)
SHUFFLE_DATA(6, 4, 0, 1, 8, 5, 2, 3)
SHUFFLE_DATA(10, 30, 4, 5, 12, 31, 6, 8)
cmp(imm(7), rsi)
JZ(.UPDATE7L1)
cmp(imm(6), rsi)
JZ(.UPDATE6L1)
cmp(imm(5), rsi)
JZ(.UPDATE5L1)
cmp(imm(4), rsi)
JZ(.UPDATE4L1)
cmp(imm(3), rsi)
JZ(.UPDATE3L1)
cmp(imm(2), rsi)
JZ(.UPDATE2L1)
cmp(imm(1), rsi)
JZ(.UPDATE1L1)
LABEL(.UPDATE7L1)
// Update 4x7 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
vmovups(ymm2, mem(rbx, 2*32))
vmovups(ymm6, mem(rbx, 3*32))
vmovups(ymm1, mem(rbx, 4*32))
vmovups(ymm5, mem(rbx, 5*32))
vmovups(ymm3, mem(rbx, 6*32))
jmp(.CDONE)
LABEL(.UPDATE6L1)
// Update 4x6 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
vmovups(ymm2, mem(rbx, 2*32))
vmovups(ymm6, mem(rbx, 3*32))
vmovups(ymm1, mem(rbx, 4*32))
vmovups(ymm5, mem(rbx, 5*32))
jmp(.CDONE)
LABEL(.UPDATE5L1)
// Update 4x5 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
vmovups(ymm2, mem(rbx, 2*32))
vmovups(ymm6, mem(rbx, 3*32))
vmovups(ymm1, mem(rbx, 4*32))
jmp(.CDONE)
LABEL(.UPDATE4L1)
// Update 4x4 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
vmovups(ymm2, mem(rbx, 2*32))
vmovups(ymm6, mem(rbx, 3*32))
jmp(.CDONE)
LABEL(.UPDATE3L1)
// Update 4x3 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
vmovups(ymm2, mem(rbx, 2*32))
jmp(.CDONE)
LABEL(.UPDATE2L1)
// Update 4x2 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
vmovups(ymm4, mem(rbx, 1*32))
jmp(.CDONE)
LABEL(.UPDATE1L1)
// Update 4x1 tile to destination buffer
vmovups(ymm0, mem(rbx, 0*32))
jmp(.CDONE)
// -- column storage on A --------------------------------------
label(.CCOLUNIT)
mov(var(ldp), r8) // load ldp
lea(mem(, r8, 8), r8) // r8 *= sizeof(scomplex)
mov(var(k_iter), rsi) // i = k_iter;
test(rsi, rsi) // check i via logical AND.
je(.CCONKLEFTCOLU) // if i == 0, jump to code that
// contains the k_left loop.
label(.CKITERCOLU) // MAIN LOOP (k_iter)
/* Load/store a column of A using YMM registers */
/* Unroll-1 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-2 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-3 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-4 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-5 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-6 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-7 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
/* Unroll-8 */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
dec(rsi) // i -= 1;
jne(.CKITERCOLU) // iterate again if i != 0.
label(.CCONKLEFTCOLU)
mov(var(k_left), rsi) // i = k_left;
test(rsi, rsi) // check i via logical AND.
je(.CDONE) // if i == 0, we're done; jump to end.
// else, we prepare to enter k_left loop.
label(.CKLEFTCOLU) // EDGE LOOP (k_left)
/* Load/store a column of A using a YMM register */
vmovups(mem(rax), ymm6)
vmovups(ymm6, mem(rbx))
add(r10, rax)
add(r8, rbx)
dec(rsi) // i -= 1;
jne(.CKLEFTCOLU) // iterate again if i != 0.
label(.CDONE)
end_asm(
: // output operands (none)
: // input operands
[mask] "m" (mask),
[k_iter] "m" (k_iter),
[k_left] "m" (k_left),
[a] "m" (a),
[inca] "m" (inca),
[lda] "m" (lda),
[p] "m" (p),
[ldp] "m" (ldp)
: // register clobber list
"rax", "rbx", "rcx", "rdx", "rsi",
"r8", "r10", "r12", "r13", "r14",
"xmm0", "xmm1", "xmm2", "xmm3",
"ymm0", "ymm1", "ymm2", "ymm3",
"ymm4", "ymm5", "ymm6", "ymm8",
"zmm0", "zmm1", "zmm2", "zmm3",
"zmm4", "zmm5", "zmm6", "zmm7",
"zmm8", "zmm10", "zmm12", "zmm14",
"zmm16", "zmm18", "zmm20", "zmm30",
"zmm31", "k2", "memory"
)
}
else // if ( cdim0 < mnr || gs || bli_does_conj( conja ) || !unitk )
{
PASTEMAC(cscal2m,BLIS_TAPI_EX_SUF)
(
0,
BLIS_NONUNIT_DIAG,
BLIS_DENSE,
( trans_t )conja,
cdim0,
k0,
kappa,
a, inca0, lda0,
p, 1, ldp0,
cntx,
NULL
);
if ( cdim0 < mnr )
{
// Handle zero-filling along the "long" edge of the micropanel.
const dim_t i = cdim0;
const dim_t m_edge = mnr - cdim0;
const dim_t n_edge = k0_max;
scomplex* restrict p_edge = p + (i )*1;
bli_cset0s_mxn
(
m_edge,
n_edge,
p_edge, 1, ldp
);
}
}
if ( k0 < k0_max )
{
// Handle zero-filling along the "short" (far) edge of the micropanel.
const dim_t j = k0;
const dim_t m_edge = mnr;
const dim_t n_edge = k0_max - k0;
scomplex* restrict p_edge = p + (j )*ldp;
bli_cset0s_mxn
(
m_edge,
n_edge,
p_edge, 1, ldp
);
}
}


@@ -0,0 +1,799 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "bli_cgemm_zen4_asm_macros.h"
#define LOOP_ALIGN ALIGN32
/* Minimum number of iterations required for prefetching C */
#define TAIL_ITER 6
// This array is used to support ADDSUB instruction.
static float fma_vec[16] __attribute__((aligned(64)))
= {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
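// Illustrative note (assumption, for exposition): this all-ones vector is
// loaded into ZMM(0) in the .POSTACCUM section and used as the unit
// multiplicand of the fused multiply-add/subtract reduction, so the FMA
// effectively performs a lane-wise add/subtract (1.0f*x +/- y) of the two
// accumulator sets.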
// This is an array used for the scatter/gather instructions.
static int64_t offsets[24] __attribute__((aligned(64))) =
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23 };
/*
Register usage :
ZMM(0) - ZMM(2) : Load A
ZMM(28) - ZMM(31) : Broadcast B
ZMM(3) - ZMM(14) : Accumulate real_scaling
ZMM(15) - ZMM(26) : Accumulate imag_scaling / Load C
Total registers used : 31
*/
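/*
 * Illustrative scalar model (for exposition only; the names acc_r/acc_i are
 * purely illustrative) of the split accumulation that the PERMUTE + FMADDSUB
 * sequence in .POSTACCUM later combines :
 *   // acc_r gathers terms scaled by Br, acc_i gathers terms scaled by Bi
 *   acc_r.real += Ar*Br;  acc_r.imag += Ai*Br;
 *   acc_i.real += Ar*Bi;  acc_i.imag += Ai*Bi;
 *   // reduction : ( Ar + i*Ai ) * ( Br + i*Bi )
 *   c.real = acc_r.real - acc_i.imag; // Ar*Br - Ai*Bi
 *   c.imag = acc_r.imag + acc_i.real; // Ai*Br + Ar*Bi
 */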
void bli_cgemm_zen4_asm_24x4(
dim_t k0,
scomplex *restrict alpha,
scomplex *restrict a,
scomplex *restrict b,
scomplex *restrict beta,
scomplex *restrict c, inc_t rs_c0, inc_t cs_c0,
auxinfo_t *data,
cntx_t *restrict cntx
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
/* Casting all the integers to the same type */
const uint64_t k = k0;
uint64_t rs_c = rs_c0 * 8; // rs_c = rs_c0 * 8 (size of scomplex datatype)
uint64_t cs_c = cs_c0 * 8; // cs_c = cs_c0 * 8 (size of scomplex datatype)
/* Storing the address of the fma_vec array, to be used in the computation */
const float *fmaPtr = &fma_vec[0];
/* Storing the address of offsets, to be used for general-stride computation */
const int64_t *offsetPtr = &offsets[0];
/* Determining the alpha and beta multiplication types */
/* This is done as part of optimizing the alpha and beta scaling */
uint64_t alpha_mul_type = BLIS_MUL_DEFAULT;
uint64_t beta_mul_type = BLIS_MUL_DEFAULT;
/* Setting alpha_mul_type and beta_mul_type, based on special cases
of alpha and beta. */
if ( alpha->imag == 0.0 )
{
if ( alpha->real == 1.0 )
alpha_mul_type = BLIS_MUL_ONE;
else if ( alpha->real == -1.0 )
alpha_mul_type = BLIS_MUL_MINUS_ONE;
}
if ( beta->imag == 0.0 )
{
if ( beta->real == 1.0 )
beta_mul_type = BLIS_MUL_ONE ;
else if ( beta->real == -1.0 )
beta_mul_type = BLIS_MUL_MINUS_ONE;
else if ( beta->real == 0.0 )
beta_mul_type = BLIS_MUL_ZERO;
}
/* Start of the assembly code-section */
BEGIN_ASM()
/* Setting the registers to zero */
SET_ZERO()
/* Loading the value of k in RSI */
MOV(VAR(k), RSI)
/* Loading the addresses of A, B and C */
MOV(VAR(a), RAX) // RAX = addr of A
MOV(VAR(b), RBX) // RBX = addr of B
MOV(VAR(c), RCX) // RCX = addr of C
/* Load R9 with address of C to be used during prefetch */
MOV(RCX, R9)
/* Loading column-stride of C, since this is a column major kernel */
/* NOTE : cs_c has already been scaled by the size of the datatype */
MOV(VAR(cs_c), R10) // R10 = cs_c = 8 * cs_c0
/* Unrolling by a factor of 4, along k-dimension */
MOV(RSI, RDI)
AND(IMM(3), RSI) // RSI = k % 4(k_fringe)
SAR(IMM(2), RDI) // RDI = k / 4(k_iter)
/* The k-loop is divided into 4 parts, to have a fixed distance for prefetching C */
/* The k-loop is divided as follows :
1. .CK_BEFORE_PREFETCH : k/4 - 4 - TAIL_ITER(before C prefetch)
2. .CK_PREFETCH : 4(prefetches C)
3. .CK_AFTER_PREFETCH : TAIL_ITER(after C prefetch(prefetch distance))
4. .CK_FRINGE
*/
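/* Worked example (illustrative) : for k0 = 64, k/4 = 16 and k_fringe = 0,
   so the loop runs 16 - 4 - TAIL_ITER = 6 iterations before prefetching C,
   4 iterations while prefetching C, and TAIL_ITER = 6 iterations after,
   keeping the prefetches a fixed distance ahead of the final stores to C. */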
LABEL(.CK_BEFORE_PREFETCH)
/* Check for entering the k-loop(before prefetch) */
SUB(IMM(4 + TAIL_ITER), RDI)
/* Jump to k-loop(prefetch) if k/4 <= 4 + TAIL_ITER */
JLE(.CK_PREFETCH)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_BP) // Unrolled iteration (B)efore (P)refetching
/* Performing rank-1 update 4 times(based on unroll) */
SUB_ITER_24x4(0, RAX, RBX)
SUB_ITER_24x4(1, RAX, RBX)
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RAX, RBX)
SUB_ITER_24x4(3, RAX, RBX)
/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RAX, 4 * 24 * 8), RAX) // RAX = RAX + 4 * MR * 8(loads)
LEA(MEM(RBX, 4 * 4 * 8), RBX) // RBX = RBX + 4 * NR * 8(broadcasts)
/* Jump if RDI != 0 */
JNZ(.CK_ITER_BP)
LABEL(.CK_PREFETCH)
/* Check for entering the k-loop(for C prefetch) */
ADD(IMM(4), RDI)
/* Jump to k-loop(after prefetch) if k/4 <= TAIL_ITER */
JLE(.CK_AFTER_PREFETCH)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_P) // Unrolled iteration with (P)refetching
/* Performing rank-1 update 4 times(based on unroll) */
/* Also prefetch C */
PREFETCHW0(MEM(R9))
SUB_ITER_24x4(0, RAX, RBX)
PREFETCHW0(MEM(R9, 64))
SUB_ITER_24x4(1, RAX, RBX)
PREFETCHW0(MEM(R9, 128))
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RAX, RBX)
SUB_ITER_24x4(3, RAX, RBX)
/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RAX, 4 * 24 * 8), RAX) // RAX = RAX + 4 * MR * 8(loads)
LEA(MEM(RBX, 4 * 4 * 8), RBX) // RBX = RBX + 4 * NR * 8(broadcasts)
LEA(MEM(R9, R10, 1), R9) // R9 = R9 + cs_c (advance C prefetch pointer)
/* Jump if RDI != 0 */
JNZ(.CK_ITER_P)
LABEL(.CK_AFTER_PREFETCH)
/* Check for entering the k-loop(for C prefetch) */
ADD(IMM(0 + TAIL_ITER), RDI)
/* Jump to k-loop(after prefetch) if k/4 <= 0 */
JLE(.CK_FRINGE)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_AP) // Unrolled iteration (A)fter (P)refetching
/* Performing rank-1 update 4 times(based on unroll) */
/* Also prefetch C */
SUB_ITER_24x4(0, RAX, RBX)
SUB_ITER_24x4(1, RAX, RBX)
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RAX, RBX)
SUB_ITER_24x4(3, RAX, RBX)
/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RAX, 4 * 24 * 8), RAX) // RAX = RAX + 4 * MR * 8(loads)
LEA(MEM(RBX, 4 * 4 * 8), RBX) // RBX = RBX + 4 * NR * 8(broadcasts)
/* Jump if RDI != 0 */
JNZ(.CK_ITER_AP)
LABEL(.CK_FRINGE)
/* Check for entering the k-loop(fringe) */
TEST(RSI, RSI)
JE(.POSTACCUM)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_FRINGE)
/* Performing rank-1 update */
SUB_ITER_24x4(0, RAX, RBX)
SUB(IMM(1), RSI) // RSI(iterator) -= 1
/* Adjusting the addresses of A and B for the next set */
/* new_addr = old_addr + ( unroll_factor * {MR or NR} * size_of_type ) */
LEA(MEM(RAX, 24 * 8), RAX) // RAX = RAX + MR * 8(loads)
LEA(MEM(RBX, 4 * 8), RBX) // RBX = RBX + NR * 8(broadcasts)
/* Jump until RSI becomes 0 */
JNZ(.CK_ITER_FRINGE)
LABEL(.POSTACCUM)
/* The registers from ZMM(15) to ZMM(26) contain the FMA ops
using imaginary components from elements in B matrix.
We should shuffle them( even and odd indices )
SRC: ZMM(15) = ( Ar0*Bi0, Ai0*Bi0, Ar1*Bi0, Ai1*Bi0, ... )
DST: ZMM(15) = ( Ai0*Bi0, Ar0*Bi0, Ai1*Bi0, Ar1*Bi0, ... )
Similarly for the other registers
*/
PERMUTE(15, 16, 17)
PERMUTE(18, 19, 20)
PERMUTE(21, 22, 23)
PERMUTE(24, 25, 26)
/* Loading ZMM(0) with 1.0f, for reduction */
MOV(VAR(fmaPtr), R14)
VMOVAPS(MEM(R14), ZMM(0))
/* Reducing the result using real/imag accumulators, for complex arithmetic
SRC: ZMM(3) = ( Ar0*Br0, Ai0*Br0, Ar1*Br0, Ai1*Br0, ... )
ZMM(15) = ( Ai0*Bi0, Ar0*Bi0, Ai1*Bi0, Ar1*Bi0, ... )
DST: ZMM(3) = ( Ar0*Br0 - Ai0*Bi0, Ai0*Br0 + Ar0*Bi0, ... )
Similarly done for the other registers
*/
FMADDSUB(3, 15, 4, 16, 5, 17)
FMADDSUB(6, 18, 7, 19, 8, 20)
FMADDSUB(9, 21, 10, 22, 11, 23)
FMADDSUB(12, 24, 13, 25, 14, 26)
/*
The result of A*B(micro-tile) is a 24x4 matrix(column-major), as follows :
Column-1 Column-2 Column-3 Column-4
Rows(1-8) ZMM(3) ZMM(6) ZMM(9) ZMM(12)
Rows(9-16) ZMM(4) ZMM(7) ZMM(10) ZMM(13)
Rows(17-24) ZMM(5) ZMM(8) ZMM(11) ZMM(14)
*/
LABEL(.ALPHA_SCALING)
/*
Check for alpha_mul_type, to jump to the required code-section
Intermediate result(IR) = alpha*(A*B)
If alpha == ( 1.0, 0.0 ) => BLIS_MUL_ONE
IR = A*B
else if, alpha != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
IR = alpha*(A*B), using complex multiplication
else => BLIS_MUL_MINUS_ONE
IR = 0.0 - A*B, using subtraction
*/
MOV(VAR(alpha_mul_type), R14)
CMP(IMM(1), R14) // Check if alpha = 1.0
/* Skip alpha scaling and jump to beta scaling */
JE(.BETA_SCALING)
CMP(IMM(2), R14) // Check if alpha != -1.0
/* Jump to the general case of alpha scaling */
JE(.ALPHA_GENERAL)
/* Alpha scaling when alpha == -1.0 */
LABEL(.ALPHA_MINUS_ONE)
/* Set ZMM(1) to 0.0f, and subtract the registers from ZMM(1) */
VXORPS(ZMM(1), ZMM(1), ZMM(1))
/* ZMM(3) = ZMM(1) - ZMM(3) = 0.0f - A*B
Similarly done for other registers */
ALPHA_MINUS_ONE(3, 1, 4, 1, 5, 1)
ALPHA_MINUS_ONE(6, 1, 7, 1, 8, 1)
ALPHA_MINUS_ONE(9, 1, 10, 1, 11, 1)
ALPHA_MINUS_ONE(12, 1, 13, 1, 14, 1)
/* Jump to beta scaling */
JMP(.BETA_SCALING)
/* Alpha scaling when alpha != 1.0 and alpha != -1.0 */
LABEL(.ALPHA_GENERAL)
/* Load alpha onto a ZMM register */
MOV(VAR(alpha), RAX)
/* Broadcast the real and imag components of alpha onto the registers */
VBROADCASTSS(MEM(RAX, 0), ZMM(1))
VBROADCASTSS(MEM(RAX, 4), ZMM(2))
/* Scale the result of A*B with alpha */
/* ZMM(15) = alphai * ZMM(3)
ZMM(3) = alphar * ZMM(3)
ZMM(3) = fmaddsub(ZMM(3), permute(ZMM(15)))
Similarly done for other pairs of registers */
ALPHA_DEFAULT(3, 15, 4, 16, 5, 17)
ALPHA_DEFAULT(6, 18, 7, 19, 8, 20)
ALPHA_DEFAULT(9, 21, 10, 22, 11, 23)
ALPHA_DEFAULT(12, 24, 13, 25, 14, 26)
/* Perform beta scaling */
LABEL(.BETA_SCALING)
/* Load the row and column strides of C */
MOV(VAR(rs_c), RDI)
MOV(VAR(cs_c), RSI)
CMP(IMM(8), RDI) // Check if C is column stored
JNE(.ROWSTORED) // Jump to row stored
LABEL(.COLSTORED)
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else if beta == ( 1.0, 0.0 ) => BLIS_MUL_ONE
C = C + IR, using addition
else if, beta != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
else => BLIS_MUL_MINUS_ONE
C = ( 0.0 - C ) + IR, using subtraction
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_COL)
CMP(IMM(1), R14) // Check if beta = 1.0
/* Jump to beta = 1.0 case */
JE(.BETA_ONE_COL)
CMP(IMM(2), R14) // Check if beta != -1.0
/* Jump to the general case of beta scaling */
JE(.BETA_DEFAULT_COL)
/* Beta scaling when beta == -1.0 */
LABEL(.BETA_MINUS_ONE_COL)
/* Perform C = alpha*A*B - C */
/* ZMM(15) = load(C)
ZMM(15) = ZMM(3) - ZMM(15) = alpha*A*B - C
store(ZMM(15))
Similarly done for other registers */
BETA_MINUS_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(12, 24, 13, 25, 14, 26)
JMP(.END)
/* Beta scaling when beta == 1.0 */
LABEL(.BETA_ONE_COL)
/* Perform C = C + alpha*A*B */
/* ZMM(15) = load(C)
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + C
store(ZMM(15))
Similarly done for other registers */
BETA_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(12, 24, 13, 25, 14, 26)
JMP(.END)
/* Beta scaling for generic case */
LABEL(.BETA_DEFAULT_COL)
/* Load beta onto a ZMM register */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
/* Perform C = beta*C + alpha*A*B */
/* ZMM(15) = load(C)
Perform beta scaling of ZMM(15)(similar to alpha scaling)
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + beta*C
store(ZMM(15))
Similarly done for other pairs of registers */
BETA_DEFAULT_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(12, 24, 13, 25, 14, 26)
JMP(.END)
LABEL(.BETA_ZERO_COL)
/* This code-section is taken if we want to skip scaling */
VMOVUPS(ZMM(3), MEM(RCX))
VMOVUPS(ZMM(4), MEM(RCX, 64))
VMOVUPS(ZMM(5), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)
VMOVUPS(ZMM(6), MEM(RCX))
VMOVUPS(ZMM(7), MEM(RCX, 64))
VMOVUPS(ZMM(8), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)
VMOVUPS(ZMM(9), MEM(RCX))
VMOVUPS(ZMM(10), MEM(RCX, 64))
VMOVUPS(ZMM(11), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)
VMOVUPS(ZMM(12), MEM(RCX))
VMOVUPS(ZMM(13), MEM(RCX, 64))
VMOVUPS(ZMM(14), MEM(RCX, 128))
JMP(.END)
LABEL(.ROWSTORED)
/* Check for general stride of C */
CMP(IMM(8), RSI) // Check if C is row stored
JNE(.GENERALSTRIDE) // Jump to general stride
/* This code-section is taken if C is row-stored */
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_ROW)
LABEL(.BETA_DEFAULT_ROW)
/* Load beta onto a ZMM register */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
/* We need to transpose the 24x4 block of alpha*A*B,
in steps of 8x4.
We use an 8x8 transpose routine with additional
registers.
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(1-8) ZMM(3) ZMM(6) ZMM(9) ZMM(12) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Scale C by beta and compute the result */
/* This is done one row at a time */
BETA_DEFAULT_SECONDARY(3, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(6, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(9, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(12, 21, 22)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RDI, 1), RCX)
/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(9-16) ZMM(4) ZMM(7) ZMM(10) ZMM(13) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Scale C by beta and compute the result */
/* This is done one row at a time */
BETA_DEFAULT_SECONDARY(4, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(7, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(10, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(13, 21, 22)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RDI, 1), RCX)
/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(17-24) ZMM(5) ZMM(8) ZMM(11) ZMM(14) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Scale C by beta and compute the result */
/* This is done one row at a time */
BETA_DEFAULT_SECONDARY(5, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(8, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(11, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(14, 21, 22)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
JMP(.END)
LABEL(.BETA_ZERO_ROW)
/* We need to transpose the 24x4 block of alpha*A*B,
in steps of 8x4.
We use an 8x8 transpose routine with additional
registers.
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(1-8) ZMM(3) ZMM(6) ZMM(9) ZMM(12) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Set 4 additional registers to 0.0 */
VXORPS(ZMM(28), ZMM(28), ZMM(28))
VXORPS(ZMM(29), ZMM(29), ZMM(29))
VXORPS(ZMM(30), ZMM(30), ZMM(30))
VXORPS(ZMM(31), ZMM(31), ZMM(31))
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(3), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(6), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(9), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(12), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(9-16) ZMM(4) ZMM(7) ZMM(10) ZMM(13) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(4), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(7), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(10), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(13), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
/*
Input for transpose:
Column-1 Column-2 Column-3 Column-4 Column-5 Column-6 Column-7 Column-8
Rows(17-24) ZMM(5) ZMM(8) ZMM(11) ZMM(14) ZMM(28) ZMM(29) ZMM(30) ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(5), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(8), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(11), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(14), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RDI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
JMP(.END)
LABEL(.GENERALSTRIDE)
/* This code-section is taken if C has general stride */
/*
In case of general strides for C, we need to load/store C
using gather/scatter instructions.
Visualizing C(8x4):
---------------------------------------------
| C00 | C10 | ... | C30 |
| | (rs_c) | | (rs_c) | ... | | (rs_c) |
| C01 | C11 | ... | C31 |
| | (rs_c) | | (rs_c) | ... | | (rs_c) |
| C02 | C12 | ... | C32 |
| . | . | ... | . |
| . | . | ... | . |
| . | . | ... | . |
| C07 | C17 | ... | C37 |
---------------------------------------------
Loading C :
Gather all elements of C column-wise onto ZMM registers
Compute with C(based on beta):
Similar to column-stored case, perform beta scaling and add to
alpha*A*B
Storing C :
Scatter the result one column at a time, using ZMM registers
*/
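/*
    A minimal illustrative sketch (not part of the kernel) of the displacements
    prepared below: assuming the offsets table referenced through offsetPtr holds
    { 0, 1, ..., 23 }, the gather/scatter byte offsets amount to

        int64_t byte_off[24];
        for ( int i = 0; i < 24; ++i )
            byte_off[i] = offsets[i] * rs_c;    // rs_c is already in bytes

    split across three ZMM registers (elements 0-7, 8-15 and 16-23); each
    scomplex element of a column is then addressed as RCX + byte_off[i].
*/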
MOV(VAR(offsetPtr), R9) // Load address of offsets
VPBROADCASTQ(RDI, ZMM(31)) // Broadcast rs_c onto a register
VPMULLQ(MEM(R9), ZMM(31), ZMM(28)) // ZMM28 = { 0*rs_c, 1*rs_c, 2*rs_c, 3*rs_c, ... }
VPMULLQ(MEM(R9, 64), ZMM(31), ZMM(29)) // ZMM29 = { 8*rs_c, 9*rs_c, 10*rs_c, 11*rs_c, ... }
VPMULLQ(MEM(R9, 128), ZMM(31), ZMM(30)) // ZMM30 = { 16*rs_c, 17*rs_c, 18*rs_c, 19*rs_c, ... }
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_GENERIC)
LABEL(.BETA_DEFAULT_GENERIC)
/* Load beta onto a ZMM register */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
/* Compute C = beta*C + alpha*A*B, and store to C */
BETA_DEFAULT_GENERAL(3, 4, 5)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_GENERAL(6, 7, 8)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_GENERAL(9, 10, 11)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_GENERAL(12, 13, 14)
JMP(.END)
LABEL(.BETA_ZERO_GENERIC)
/* Store the result onto C, one column at a time */
BETA_ZERO_GENERAL(3, 4, 5)
LEA((RCX, RSI, 1), RCX)
BETA_ZERO_GENERAL(6, 7, 8)
LEA((RCX, RSI, 1), RCX)
BETA_ZERO_GENERAL(9, 10, 11)
LEA((RCX, RSI, 1), RCX)
BETA_ZERO_GENERAL(12, 13, 14)
LABEL(.END)
VZEROUPPER()
end_asm(
: // output operands (none)
: // input operands
[k] "m"(k),
[a] "m"(a),
[b] "m"(b),
[c] "m"(c),
[rs_c] "m"(rs_c),
[cs_c] "m"(cs_c),
[fmaPtr] "m"(fmaPtr),
[offsetPtr] "m"(offsetPtr),
[alpha_mul_type] "m"(alpha_mul_type),
[beta_mul_type] "m"(beta_mul_type),
[alpha] "m"(alpha),
[beta] "m"(beta)
: // register clobber list
"rax", "rbx", "rcx", "rdi", "rsi", "r9", "r10", "r12", "r14",
"k0", "k1", "k2", "k3", "k4",
"ymm0", "ymm1", "ymm2", "ymm3",
"ymm4", "ymm5", "ymm6", "ymm7",
"ymm8", "ymm9", "ymm10", "ymm11",
"ymm12", "ymm13", "ymm14", "ymm15",
"ymm16", "ymm17", "ymm18", "ymm19",
"ymm20", "ymm21", "ymm22", "ymm28",
"ymm29", "ymm30", "ymm31",
"zmm0", "zmm1", "zmm2",
"zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8",
"zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14",
"zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20",
"zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory"
)
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
}


@@ -0,0 +1,834 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "bli_cgemm_zen4_asm_macros.h"
#define LOOP_ALIGN ALIGN32
/* Number of unrolled k-loop iterations executed after the C prefetches (the prefetch distance) */
#define TAIL_ITER 6
// This array (all lanes set to 1.0f) is used as the multiplier operand of the FMADDSUB instruction, making it behave as an ADDSUB.
static float fma_vec[16] __attribute__((aligned(64)))
= {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
// This is an array used for the scatter/gather instructions.
static int64_t offsets[24] __attribute__((aligned(64))) =
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23 };
/*
Register usage :
ZMM(0)  - ZMM(2)  : Loads from B
ZMM(28) - ZMM(31) : Broadcasts from A
ZMM(3)  - ZMM(14) : Accumulators for the real-component products
ZMM(15) - ZMM(26) : Accumulators for the imag-component products / loads of C
Total registers used : 31
*/
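/*
    For readability only, a minimal scalar sketch of the contract this
    micro-kernel implements (the function name below is illustrative and not
    part of the build; the real kernel also skips reading C when beta is zero):

    static void cgemm_4x24_ref( dim_t k, scomplex *alpha, scomplex *a,
                                scomplex *b, scomplex *beta, scomplex *c,
                                inc_t rs_c0, inc_t cs_c0 )
    {
        // C(4x24) = beta*C + alpha*(A_panel * B_panel), with A packed as
        // 4 scomplex per k-step and B packed as 24 scomplex per k-step.
        for ( dim_t i = 0; i < 4; ++i )
            for ( dim_t j = 0; j < 24; ++j )
            {
                scomplex ab = { 0.0f, 0.0f };
                for ( dim_t p = 0; p < k; ++p )
                {
                    scomplex ai = a[ p*4 + i ], bj = b[ p*24 + j ];
                    ab.real += ai.real*bj.real - ai.imag*bj.imag;
                    ab.imag += ai.real*bj.imag + ai.imag*bj.real;
                }
                scomplex *cij = c + i*rs_c0 + j*cs_c0;
                float cr = cij->real, ci = cij->imag;
                cij->real = beta->real*cr - beta->imag*ci
                          + alpha->real*ab.real - alpha->imag*ab.imag;
                cij->imag = beta->real*ci + beta->imag*cr
                          + alpha->real*ab.imag + alpha->imag*ab.real;
            }
    }
*/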
void bli_cgemm_zen4_asm_4x24(
dim_t k0,
scomplex *restrict alpha,
scomplex *restrict a,
scomplex *restrict b,
scomplex *restrict beta,
scomplex *restrict c, inc_t rs_c0, inc_t cs_c0,
auxinfo_t *data,
cntx_t *restrict cntx
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
/* Casting all the integers to the same type */
const uint64_t k = k0;
uint64_t rs_c = rs_c0 * 8; // rs_c = rs_c0 * 8(size of scomplex datatype)
uint64_t cs_c = cs_c0 * 8; // cs_c = cs_c0 * 8(size of scomplex datatype)
/* Storing the address of the fma_vec array, to be used in the computation */
const float *fmaPtr = &fma_vec[0];
/* Storing the address of offsets, to be used for general-stride computation */
const int64_t *offsetPtr = &offsets[0];
/* Determining the alpha and beta multiplication types */
/* This is done as part of optimizing the alpha and beta scaling */
uint64_t alpha_mul_type = BLIS_MUL_DEFAULT;
uint64_t beta_mul_type = BLIS_MUL_DEFAULT;
/* Setting alpha_mul_type and beta_mul_type, based on special cases
of alpha and beta. */
if ( alpha->imag == 0.0 )
{
if ( alpha->real == 1.0 )
alpha_mul_type = BLIS_MUL_ONE;
else if ( alpha->real == -1.0 )
alpha_mul_type = BLIS_MUL_MINUS_ONE;
}
if ( beta->imag == 0.0 )
{
if ( beta->real == 1.0 )
beta_mul_type = BLIS_MUL_ONE ;
else if ( beta->real == -1.0 )
beta_mul_type = BLIS_MUL_MINUS_ONE;
else if ( beta->real == 0.0 )
beta_mul_type = BLIS_MUL_ZERO;
}
/* Start of the assembly code-section */
BEGIN_ASM()
/* Setting the registers to zero */
SET_ZERO()
/* Loading the value of k in RSI */
MOV(VAR(k), RSI)
/* Loading the addresses of A, B and C */
MOV(VAR(a), RAX) // RAX = addr of A
MOV(VAR(b), RBX) // RBX = addr of B
MOV(VAR(c), RCX) // RCX = addr of C
/* Load R9 with address of C to be used during prefetch */
MOV(RCX, R9)
/* Loading row-stride of C, since this is a row major kernel */
/* NOTE : rs_c has already been scaled by the size of the datatype */
MOV(VAR(rs_c), R10) // R10 = rs_c = 8 * rs_c0
/* Unrolling by a factor of 4, along k-dimension */
MOV(RSI, RDI)
AND(IMM(3), RSI) // RSI = k % 4(k_fringe)
SAR(IMM(2), RDI) // RDI = k / 4(k_iter)
/* The k-loop is divided into 4 parts, to have a fixed distance for prefetching C */
/* The k-loop is divided as follows :
1. .CK_BEFORE_PREFETCH : k/4 - 4 - TAIL_ITER(before C prefetch)
2. .CK_PREFETCH : 4(prefetches C)
3. .CK_AFTER_PREFETCH : TAIL_ITER(after C prefetch(prefetch distance))
4. .CK_FRINGE
*/
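/* Worked example (with TAIL_ITER = 6 as defined above): for k = 100,
   k/4 = 25 unrolled iterations are split as 25 - 4 - 6 = 15 iterations in
   .CK_BEFORE_PREFETCH, 4 in .CK_PREFETCH and 6 in .CK_AFTER_PREFETCH,
   followed by k % 4 = 0 iterations in .CK_FRINGE. */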
LABEL(.CK_BEFORE_PREFETCH)
/* Check for entering the k-loop(before prefetch) */
SUB(IMM(4 + TAIL_ITER), RDI)
/* Jump to k-loop(prefetch) if k/4 <= 4 + TAIL_ITER */
JLE(.CK_PREFETCH)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_BP) // Unrolled iteration (B)efore (P)refetching
/* Performing rank-1 update 4 times(based on unroll) */
SUB_ITER_24x4(0, RBX, RAX)
SUB_ITER_24x4(1, RBX, RAX)
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RBX, RAX)
SUB_ITER_24x4(3, RBX, RAX)
/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RBX, 4 * 24 * 8), RBX) // RBX = RBX + 4 * NR * 8(loads)
LEA(MEM(RAX, 4 * 4 * 8), RAX) // RAX = RAX + 4 * MR * 8(broadcasts)
/* Jump if RDI != 0 */
JNZ(.CK_ITER_BP)
LABEL(.CK_PREFETCH)
/* Check for entering the k-loop(for C prefetch) */
ADD(IMM(4), RDI)
/* Jump to k-loop(after prefetch) if k/4 <= TAIL_ITER */
JLE(.CK_AFTER_PREFETCH)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_P) // Unrolled iteration with (P)refetching
/* Performing rank-1 update 4 times(based on unroll) */
/* Also prefetch C */
PREFETCHW0(MEM(R9))
SUB_ITER_24x4(0, RBX, RAX)
PREFETCHW0(MEM(R9, 64))
SUB_ITER_24x4(1, RBX, RAX)
PREFETCHW0(MEM(R9, 128))
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RBX, RAX)
SUB_ITER_24x4(3, RBX, RAX)
/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RBX, 4 * 24 * 8), RBX) // RBX = RBX + 4 * NR * 8(loads)
LEA(MEM(RAX, 4 * 4 * 8), RAX) // RAX = RAX + 4 * MR * 8(broadcasts)
LEA(MEM(R9, R10, 1), R9) // R9 = R9 + rs_c(next row of C to prefetch)
/* Jump if RDI != 0 */
JNZ(.CK_ITER_P)
LABEL(.CK_AFTER_PREFETCH)
/* Check for entering the k-loop(after C prefetch) */
ADD(IMM(0 + TAIL_ITER), RDI)
/* Jump to k-loop(fringe) if k/4 <= 0 */
JLE(.CK_FRINGE)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_AP) // Unrolled iteration (A)fter (P)refetching
/* Performing rank-1 update 4 times(based on unroll) */
SUB_ITER_24x4(0, RBX, RAX)
SUB_ITER_24x4(1, RBX, RAX)
SUB(IMM(1), RDI) // RDI(iterator) -= 1
SUB_ITER_24x4(2, RBX, RAX)
SUB_ITER_24x4(3, RBX, RAX)
/* Adjusting the addresses of A and B for the next set */
LEA(MEM(RBX, 4 * 24 * 8), RBX) // RBX = RBX + 4 * NR * 8(loads)
LEA(MEM(RAX, 4 * 4 * 8), RAX) // RAX = RAX + 4 * MR * 8(broadcasts)
/* Jump if RDI != 0 */
JNZ(.CK_ITER_AP)
LABEL(.CK_FRINGE)
/* Check for entering the k-loop(fringe) */
TEST(RSI, RSI)
JE(.POSTACCUM)
/* K-loop(unrolled) to perform A*B */
LOOP_ALIGN
LABEL(.CK_ITER_FRINGE)
/* Performing rank-1 update */
SUB_ITER_24x4(0, RBX, RAX)
SUB(IMM(1), RSI) // RSI(iterator) -= 1
/* Adjusting the addresses of A and B for the next set */
/* new_addr = old_addr + ( unroll_factor * {MR or NR} * size_of_type ) */
LEA(MEM(RBX, 24 * 8), RBX) // RBX = RBX + NR * 8(loads)
LEA(MEM(RAX, 4 * 8), RAX) // RAX = RAX + MR * 8(broadcasts)
/* Jump until RSI becomes 0 */
JNZ(.CK_ITER_FRINGE)
LABEL(.POSTACCUM)
/* The registers from ZMM(15) to ZMM(26) contain the FMA ops
using the imaginary components of the elements of A.
We should shuffle them( even and odd indices )
SRC: ZMM(15) = ( Br0*Ai0, Bi0*Ai0, Br1*Ai0, Bi1*Ai0, ... )
DST: ZMM(15) = ( Bi0*Ai0, Br0*Ai0, Bi1*Ai0, Br1*Ai0, ... )
Similarly for the other registers
*/
PERMUTE(15, 16, 17)
PERMUTE(18, 19, 20)
PERMUTE(21, 22, 23)
PERMUTE(24, 25, 26)
/* Loading ZMM(0) with 1.0f, for reduction */
MOV(VAR(fmaPtr), R14)
VMOVAPS(MEM(R14), ZMM(0))
/* Reducing the result using real/imag accumulators, for complex arithmetic
SRC: ZMM(3) = ( Br0*Ar0, Bi0*Ar0, Br1*Ar0, Bi1*Ar0, ... )
ZMM(15) = ( Bi0*Ai0, Br0*Ai0, Bi1*Ai0, Br1*Ai0, ... )
DST: ZMM(3) = ( Br0*Ar0 - Bi0*Ai0, Bi0*Ar0 + Br0*Ai0, ... )
Similarly done for the other registers
*/
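/* In scalar terms, this reduction is just the complex product
       (Ar + i.Ai)*(Br + i.Bi) = (Ar.Br - Ai.Bi) + i.(Ar.Bi + Ai.Br)
   The real accumulator holds { Br.Ar, Bi.Ar }, the permuted imag accumulator
   holds { Bi.Ai, Br.Ai }, and VFMADDSUB132PS (with the 1.0f multiplier in
   ZMM(0)) subtracts in the even lanes and adds in the odd lanes, producing
   { Br.Ar - Bi.Ai, Bi.Ar + Br.Ai } in a single instruction. */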
FMADDSUB(3, 15, 4, 16, 5, 17)
FMADDSUB(6, 18, 7, 19, 8, 20)
FMADDSUB(9, 21, 10, 22, 11, 23)
FMADDSUB(12, 24, 13, 25, 14, 26)
/*
The result of A*B(micro-tile) is a 4x24 matrix(row-major), as follows :
Cols(1-8) Cols(9-16) Cols(17-24)
Row-1 ZMM(3) ZMM(4) ZMM(5)
Row-2 ZMM(6) ZMM(7) ZMM(8)
Row-3 ZMM(9) ZMM(10) ZMM(11)
Row-4 ZMM(12) ZMM(13) ZMM(14)
*/
LABEL(.ALPHA_SCALING)
/*
Check for alpha_mul_type, to jump to the required code-section
Intermediate result(IR) = alpha*(A*B)
If alpha == ( 1.0, 0.0 ) => BLIS_MUL_ONE
IR = A*B
else if alpha != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
IR = alpha*(A*B), using complex multiplication
else => BLIS_MUL_MINUS_ONE
IR = 0.0 - A*B, using subtraction
*/
MOV(VAR(alpha_mul_type), R14)
CMP(IMM(1), R14) // Check if alpha = 1.0
/* Skip alpha scaling and jump to beta scaling */
JE(.BETA_SCALING)
CMP(IMM(2), R14) // Check if alpha is generic(neither 1.0 nor -1.0)
/* Jump to the general case of alpha scaling */
JE(.ALPHA_GENERAL)
/* Alpha scaling when alpha == -1.0 */
LABEL(.ALPHA_MINUS_ONE)
/* Set ZMM(1) to 0.0f, and subtract the registers from ZMM(1) */
VXORPS(ZMM(1), ZMM(1), ZMM(1))
/* ZMM(3) = ZMM(1) - ZMM(3) = 0.0f - A*B
Similarly done for other registers */
ALPHA_MINUS_ONE(3, 1, 4, 1, 5, 1)
ALPHA_MINUS_ONE(6, 1, 7, 1, 8, 1)
ALPHA_MINUS_ONE(9, 1, 10, 1, 11, 1)
ALPHA_MINUS_ONE(12, 1, 13, 1, 14, 1)
/* Jump to beta scaling */
JMP(.BETA_SCALING)
/* Alpha scaling when alpha != 1.0 and alpha != -1.0 */
LABEL(.ALPHA_GENERAL)
/* Load alpha onto a ZMM register */
MOV(VAR(alpha), RAX)
/* Broadcast the real and imag components of alpha onto the registers */
VBROADCASTSS(MEM(RAX, 0), ZMM(1))
VBROADCASTSS(MEM(RAX, 4), ZMM(2))
/* Scale the result of A*B with alpha */
/* ZMM(15) = alphai * ZMM(3)
ZMM(3) = alphar * ZMM(3)
ZMM(3) = fmaddsub(ZMM(3), permute(ZMM(15)))
Similarly done for other pairs of registers */
ALPHA_DEFAULT(3, 15, 4, 16, 5, 17)
ALPHA_DEFAULT(6, 18, 7, 19, 8, 20)
ALPHA_DEFAULT(9, 21, 10, 22, 11, 23)
ALPHA_DEFAULT(12, 24, 13, 25, 14, 26)
/* Perform beta scaling */
LABEL(.BETA_SCALING)
/* Load the row and column strides of C */
MOV(VAR(rs_c), RDI)
MOV(VAR(cs_c), RSI)
CMP(IMM(8), RSI) // Check if C is row stored
JNE(.COLSTORED) // Jump to the column-stored/general-stride path otherwise
LABEL(.ROWSTORED)
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else if beta == ( 1.0, 0.0 ) => BLIS_MUL_ONE
C = C + IR, using addition
else if beta != ( -1.0, 0.0 ) => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
else => BLIS_MUL_MINUS_ONE
C = ( 0.0 - C ) + IR, using subtraction
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_ROW)
CMP(IMM(1), R14) // Check if beta = 1.0
/* Jump to beta = 1.0 case */
JE(.BETA_ONE_ROW)
CMP(IMM(2), R14) // Check if beta is generic(neither 1.0 nor -1.0)
/* Jump to the general case of beta scaling */
JE(.BETA_DEFAULT_ROW)
/* Beta scaling when beta == -1.0 */
LABEL(.BETA_MINUS_ONE_ROW)
/* Perform C = alpha*A*B - C */
/* ZMM(15) = load(C)
ZMM(15) = ZMM(3) - ZMM(15) = alpha*A*B - C
store(ZMM(15))
Similarly done for other registers */
BETA_MINUS_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_MINUS_ONE_PRIMARY(12, 24, 13, 25, 14, 26)
JMP(.END)
/* Beta scaling when beta == 1.0 */
LABEL(.BETA_ONE_ROW)
/* Perform C = C + alpha*A*B */
/* ZMM(15) = load(C)
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + C
store(ZMM(15))
Similarly done for other registers */
BETA_ONE_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_ONE_PRIMARY(12, 24, 13, 25, 14, 26)
JMP(.END)
/* Beta scaling for generic case */
LABEL(.BETA_DEFAULT_ROW)
/* Load beta onto a ZMM register */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
/* Perform C = beta*C + alpha*A*B */
/* ZMM(15) = load(C)
Perform beta scaling of ZMM(15)(similar to alpha scaling)
ZMM(15) = ZMM(3) + ZMM(15) = alpha*A*B + beta*C
store(ZMM(15))
Similarly done for other pairs of registers */
BETA_DEFAULT_PRIMARY(3, 15, 4, 16, 5, 17)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(6, 18, 7, 19, 8, 20)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(9, 21, 10, 22, 11, 23)
LEA((RCX, R10, 1), RCX)
BETA_DEFAULT_PRIMARY(12, 24, 13, 25, 14, 26)
JMP(.END)
LABEL(.BETA_ZERO_ROW)
/* This code-section is taken when beta == 0.0; alpha*A*B is stored to C directly */
VMOVUPS(ZMM(3), MEM(RCX))
VMOVUPS(ZMM(4), MEM(RCX, 64))
VMOVUPS(ZMM(5), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)
VMOVUPS(ZMM(6), MEM(RCX))
VMOVUPS(ZMM(7), MEM(RCX, 64))
VMOVUPS(ZMM(8), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)
VMOVUPS(ZMM(9), MEM(RCX))
VMOVUPS(ZMM(10), MEM(RCX, 64))
VMOVUPS(ZMM(11), MEM(RCX, 128))
LEA((RCX, R10, 1), RCX)
VMOVUPS(ZMM(12), MEM(RCX))
VMOVUPS(ZMM(13), MEM(RCX, 64))
VMOVUPS(ZMM(14), MEM(RCX, 128))
JMP(.END)
LABEL(.COLSTORED)
/* Check for general stride of C */
CMP(IMM(8), RDI) // Check if C is col stored
JNE(.GENERALSTRIDE) // Jump to general stride
/* This code-section is taken if C is col-stored */
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_COL)
LABEL(.BETA_DEFAULT_COL)
/* Load beta onto a ZMM register */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
/* We need to transpose the 4x24 block of alpha*A*B,
   in steps of 4x8.
   We use an 8x8 transpose routine with additional
   registers.
   Input for transpose:
              Columns(1-8)
   Row-1      ZMM(3)
   Row-2      ZMM(6)
   Row-3      ZMM(9)
   Row-4      ZMM(12)
   Row-5      ZMM(28)
   Row-6      ZMM(29)
   Row-7      ZMM(30)
   Row-8      ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Scale C by beta and compute the result */
/* This is done one column at a time */
BETA_DEFAULT_SECONDARY(3, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(6, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(9, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(12, 21, 22)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RSI, 1), RCX)
/*
Input for transpose:
              Columns(9-16)
   Row-1      ZMM(4)
   Row-2      ZMM(7)
   Row-3      ZMM(10)
   Row-4      ZMM(13)
   Row-5      ZMM(28)
   Row-6      ZMM(29)
   Row-7      ZMM(30)
   Row-8      ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Scale C by beta and compute the result */
/* This is done one col at a time */
BETA_DEFAULT_SECONDARY(4, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(7, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(10, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(13, 21, 22)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
LEA((RCX, RSI, 1), RCX)
/*
Input for transpose:
              Columns(17-24)
   Row-1      ZMM(5)
   Row-2      ZMM(8)
   Row-3      ZMM(11)
   Row-4      ZMM(14)
   Row-5      ZMM(28)
   Row-6      ZMM(29)
   Row-7      ZMM(30)
   Row-8      ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Scale C by beta and compute the result */
/* This is done one col at a time */
BETA_DEFAULT_SECONDARY(5, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(8, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(11, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(14, 21, 22)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(28, 15, 16)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(29, 17, 18)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(30, 19, 20)
LEA((RCX, RSI, 1), RCX)
BETA_DEFAULT_SECONDARY(31, 21, 22)
JMP(.END)
LABEL(.BETA_ZERO_COL)
/* We need to transpose the 4x24 block of alpha*A*B,
   in steps of 4x8.
   We use an 8x8 transpose routine with additional
   registers.
   Input for transpose:
              Columns(1-8)
   Row-1      ZMM(3)
   Row-2      ZMM(6)
   Row-3      ZMM(9)
   Row-4      ZMM(12)
   Row-5      ZMM(28)
   Row-6      ZMM(29)
   Row-7      ZMM(30)
   Row-8      ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(3, 6, 9, 12, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(3), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(6), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(9), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(12), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
/*
Input for transpose:
              Columns(9-16)
   Row-1      ZMM(4)
   Row-2      ZMM(7)
   Row-3      ZMM(10)
   Row-4      ZMM(13)
   Row-5      ZMM(28)
   Row-6      ZMM(29)
   Row-7      ZMM(30)
   Row-8      ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(4, 7, 10, 13, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(4), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(7), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(10), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(13), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
/*
Input for transpose:
              Columns(17-24)
   Row-1      ZMM(5)
   Row-2      ZMM(8)
   Row-3      ZMM(11)
   Row-4      ZMM(14)
   Row-5      ZMM(28)
   Row-6      ZMM(29)
   Row-7      ZMM(30)
   Row-8      ZMM(31)
*/
/* Transpose the 8x8 block of alpha*A*B */
/* ZMM(15) to ZMM(22) are used as temporary registers
for transpose operation */
TRANSPOSE_8X8(5, 8, 11, 14, 28, 29, 30, 31,
15, 16, 17, 18, 19, 20, 21, 22)
/* Store the result back to C */
/* We need to store only the first 256-bit lane of the
registers post transpose */
VMOVUPS(YMM(5), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(8), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(11), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(14), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(28), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(29), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(30), MEM(RCX))
LEA((RCX, RSI, 1), RCX)
VMOVUPS(YMM(31), MEM(RCX))
JMP(.END)
LABEL(.GENERALSTRIDE)
/* This code-section is taken if C has general stride */
/*
In case of general strides for C, we need to load/store C
using gather/scatter instructions.
Visualizing C(4x24) (only the first 8 columns shown):
---------------------------------------------
| C00 --(cs_c)-- C10 --(cs_c)-- C20 ... C70 |
| ------------------------------------------|
| C01 --(cs_c)-- C11 --(cs_c)-- C21 ... C71 |
| ------------------------------------------|
| C02 --(cs_c)-- C12 --(cs_c)-- C22 ... C72 |
| ------------------------------------------|
| C03 --(cs_c)-- C13 --(cs_c)-- C23 ... C73 |
---------------------------------------------
Loading C :
Gather all elements of C row-wise onto ZMM registers
Compute with C(based on beta):
Similar to row-stored case, perform beta scaling and add to
alpha*A*B
Storing C :
Scatter the result one row at a time, using ZMM registers
*/
MOV(VAR(offsetPtr), R9) // Load address of offsets
VPBROADCASTQ(RSI, ZMM(31)) // Broadcast cs_c onto a register
VPMULLQ(MEM(R9), ZMM(31), ZMM(28)) // ZMM28 = { 0*cs_c, 1*cs_c, 2*cs_c, 3*cs_c, ... }
VPMULLQ(MEM(R9, 64), ZMM(31), ZMM(29)) // ZMM29 = { 8*cs_c, 9*cs_c, 10*cs_c, 11*cs_c, ... }
VPMULLQ(MEM(R9, 128), ZMM(31), ZMM(30)) // ZMM30 = { 16*cs_c, 17*cs_c, 18*cs_c, 19*cs_c, ... }
/*
Check for beta_mul_type, to jump to the required code-section
Intermediate C = beta*C + IR, where IR = alpha*A*B
If beta == ( 0.0, 0.0 ) => BLIS_MUL_ZERO
C = IR, skip beta-scaling
else => BLIS_MUL_DEFAULT
C = beta*C + IR, using complex multiplication
*/
MOV(VAR(beta_mul_type), R14)
CMP(IMM(0), R14) // Check if beta = 0.0
/* Skip beta scaling and jump to store */
JE(.BETA_ZERO_GENERIC)
LABEL(.BETA_DEFAULT_GENERIC)
/* Load beta onto a ZMM register */
MOV(VAR(beta), RBX)
/* Broadcast the real and imag components of beta onto the registers */
VBROADCASTSS(MEM(RBX, 0), ZMM(1))
VBROADCASTSS(MEM(RBX, 4), ZMM(2))
/* Compute C = beta*C + alpha*A*B, and store to C */
BETA_DEFAULT_GENERAL(3, 4, 5)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_GENERAL(6, 7, 8)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_GENERAL(9, 10, 11)
LEA((RCX, RDI, 1), RCX)
BETA_DEFAULT_GENERAL(12, 13, 14)
JMP(.END)
LABEL(.BETA_ZERO_GENERIC)
/* Store the result onto C, one row at a time */
BETA_ZERO_GENERAL(3, 4, 5)
LEA((RCX, RDI, 1), RCX)
BETA_ZERO_GENERAL(6, 7, 8)
LEA((RCX, RDI, 1), RCX)
BETA_ZERO_GENERAL(9, 10, 11)
LEA((RCX, RDI, 1), RCX)
BETA_ZERO_GENERAL(12, 13, 14)
LABEL(.END)
VZEROUPPER()
end_asm(
: // output operands (none)
: // input operands
[k] "m"(k),
[a] "m"(a),
[b] "m"(b),
[c] "m"(c),
[rs_c] "m"(rs_c),
[cs_c] "m"(cs_c),
[fmaPtr] "m"(fmaPtr),
[offsetPtr] "m"(offsetPtr),
[alpha_mul_type] "m"(alpha_mul_type),
[beta_mul_type] "m"(beta_mul_type),
[alpha] "m"(alpha),
[beta] "m"(beta)
: // register clobber list
"rax", "rbx", "rcx", "rdi", "rsi", "r9", "r10", "r12", "r14",
"k0", "k1", "k2", "k3", "k4",
"ymm0", "ymm1", "ymm2", "ymm3",
"ymm4", "ymm5", "ymm6", "ymm7",
"ymm8", "ymm9", "ymm10", "ymm11",
"ymm12", "ymm13", "ymm14", "ymm15",
"ymm16", "ymm17", "ymm18", "ymm19",
"ymm20", "ymm21", "ymm22", "ymm28",
"ymm29", "ymm30", "ymm31",
"zmm0", "zmm1", "zmm2",
"zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8",
"zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14",
"zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20",
"zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26",
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory"
)
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
}


@@ -0,0 +1,446 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"
/* NOTE : This file contains the macros used to implement CGEMM native
kernels of the following (MR, NR) pairs :
Column-major : (24, 4)
Row-major : (4, 24)
The macro for the micro-tile computation(SUB_ITER_24x4) accepts the load
and broadcast address registers as parameters. Thus, the same macro can be
used in both the 24x4 column-major kernel and the 4x24 row-major kernel,
by passing the appropriate load and broadcast address registers.
*/
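/* For example, the 4x24 row-major kernel in this commit invokes the macro as
       SUB_ITER_24x4(n, RBX, RAX)   // load from B, broadcast from A
   and the 24x4 column-major kernel would presumably invoke it with the load
   and broadcast registers swapped, so that A is loaded and B is broadcast. */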
/* Macro to set all the registers to zero */
#define SET_ZERO() \
VXORPS(ZMM(0), ZMM(0), ZMM(0)) \
VXORPS(ZMM(1), ZMM(1), ZMM(1)) \
VXORPS(ZMM(2), ZMM(2), ZMM(2)) \
VXORPS(ZMM(3), ZMM(3), ZMM(3)) \
VXORPS(ZMM(4), ZMM(4), ZMM(4)) \
VXORPS(ZMM(5), ZMM(5), ZMM(5)) \
VXORPS(ZMM(6), ZMM(6), ZMM(6)) \
VXORPS(ZMM(7), ZMM(7), ZMM(7)) \
VXORPS(ZMM(8), ZMM(8), ZMM(8)) \
VXORPS(ZMM(9), ZMM(9), ZMM(9)) \
VXORPS(ZMM(10), ZMM(10), ZMM(10)) \
VXORPS(ZMM(11), ZMM(11), ZMM(11)) \
VXORPS(ZMM(12), ZMM(12), ZMM(12)) \
VXORPS(ZMM(13), ZMM(13), ZMM(13)) \
VXORPS(ZMM(14), ZMM(14), ZMM(14)) \
VXORPS(ZMM(15), ZMM(15), ZMM(15)) \
VXORPS(ZMM(16), ZMM(16), ZMM(16)) \
VXORPS(ZMM(17), ZMM(17), ZMM(17)) \
VXORPS(ZMM(18), ZMM(18), ZMM(18)) \
VXORPS(ZMM(19), ZMM(19), ZMM(19)) \
VXORPS(ZMM(20), ZMM(20), ZMM(20)) \
VXORPS(ZMM(21), ZMM(21), ZMM(21)) \
VXORPS(ZMM(22), ZMM(22), ZMM(22)) \
VXORPS(ZMM(23), ZMM(23), ZMM(23)) \
VXORPS(ZMM(24), ZMM(24), ZMM(24)) \
VXORPS(ZMM(25), ZMM(25), ZMM(25)) \
VXORPS(ZMM(26), ZMM(26), ZMM(26)) \
VXORPS(ZMM(27), ZMM(27), ZMM(27)) \
VXORPS(ZMM(28), ZMM(28), ZMM(28)) \
VXORPS(ZMM(29), ZMM(29), ZMM(29)) \
VXORPS(ZMM(30), ZMM(30), ZMM(30)) \
VXORPS(ZMM(31), ZMM(31), ZMM(31)) \
/* Macro to perform the rank-1 update using loads and broadcasts from the matrices(24x4) */
/* RL represents the register that has load address, RB represents broadcast address */
/* For the sake of comments, let's assume RL = A, and RB = B(column-major kernel) */
#define SUB_ITER_24x4(n, RL, RB) \
VMOVAPS(MEM(RL, (24 * n + 0) * 8), ZMM(0)) /* ZMM(0) = A[0:7][n] */ \
VMOVAPS(MEM(RL, (24 * n + 8) * 8), ZMM(1)) /* ZMM(1) = A[8:15][n] */ \
VMOVAPS(MEM(RL, (24 * n + 16) * 8), ZMM(2)) /* ZMM(2) = A[16:23][n] */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 0) * 4), ZMM(28)) /* ZMM(28) = Real(B[n][0]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 1) * 4), ZMM(29)) /* ZMM(29) = Imag(B[n][0]) */ \
VFMADD231PS(ZMM(0), ZMM(28), ZMM(3))   /* ZMM(3) += A[0:7][n] * Real(B[n][0]) */ \
VFMADD231PS(ZMM(1), ZMM(28), ZMM(4))   /* ZMM(4) += A[8:15][n] * Real(B[n][0]) */ \
VFMADD231PS(ZMM(2), ZMM(28), ZMM(5))   /* ZMM(5) += A[16:23][n] * Real(B[n][0]) */ \
VFMADD231PS(ZMM(0), ZMM(29), ZMM(15))  /* ZMM(15) += A[0:7][n] * Imag(B[n][0]) */ \
VFMADD231PS(ZMM(1), ZMM(29), ZMM(16))  /* ZMM(16) += A[8:15][n] * Imag(B[n][0]) */ \
VFMADD231PS(ZMM(2), ZMM(29), ZMM(17))  /* ZMM(17) += A[16:23][n] * Imag(B[n][0]) */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 2) * 4), ZMM(30)) /* ZMM(30) = Real(B[n][1]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 3) * 4), ZMM(31)) /* ZMM(31) = Imag(B[n][1]) */ \
VFMADD231PS(ZMM(0), ZMM(30), ZMM(6))   /* ZMM(6) += A[0:7][n] * Real(B[n][1]) */ \
VFMADD231PS(ZMM(1), ZMM(30), ZMM(7))   /* ZMM(7) += A[8:15][n] * Real(B[n][1]) */ \
VFMADD231PS(ZMM(2), ZMM(30), ZMM(8))   /* ZMM(8) += A[16:23][n] * Real(B[n][1]) */ \
VFMADD231PS(ZMM(0), ZMM(31), ZMM(18))  /* ZMM(18) += A[0:7][n] * Imag(B[n][1]) */ \
VFMADD231PS(ZMM(1), ZMM(31), ZMM(19))  /* ZMM(19) += A[8:15][n] * Imag(B[n][1]) */ \
VFMADD231PS(ZMM(2), ZMM(31), ZMM(20))  /* ZMM(20) += A[16:23][n] * Imag(B[n][1]) */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 4) * 4), ZMM(28)) /* ZMM(28) = Real(B[n][2]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 5) * 4), ZMM(29)) /* ZMM(29) = Imag(B[n][2]) */ \
VFMADD231PS(ZMM(0), ZMM(28), ZMM(9))   /* ZMM(9) += A[0:7][n] * Real(B[n][2]) */ \
VFMADD231PS(ZMM(1), ZMM(28), ZMM(10))  /* ZMM(10) += A[8:15][n] * Real(B[n][2]) */ \
VFMADD231PS(ZMM(2), ZMM(28), ZMM(11))  /* ZMM(11) += A[16:23][n] * Real(B[n][2]) */ \
VFMADD231PS(ZMM(0), ZMM(29), ZMM(21))  /* ZMM(21) += A[0:7][n] * Imag(B[n][2]) */ \
VFMADD231PS(ZMM(1), ZMM(29), ZMM(22))  /* ZMM(22) += A[8:15][n] * Imag(B[n][2]) */ \
VFMADD231PS(ZMM(2), ZMM(29), ZMM(23))  /* ZMM(23) += A[16:23][n] * Imag(B[n][2]) */ \
\
VBROADCASTSS(MEM(RB, (8 * n + 6) * 4), ZMM(30)) /* ZMM(30) = Real(B[n][3]) */ \
VBROADCASTSS(MEM(RB, (8 * n + 7) * 4), ZMM(31)) /* ZMM(31) = Imag(B[n][3]) */ \
VFMADD231PS(ZMM(0), ZMM(30), ZMM(12))  /* ZMM(12) += A[0:7][n] * Real(B[n][3]) */ \
VFMADD231PS(ZMM(1), ZMM(30), ZMM(13))  /* ZMM(13) += A[8:15][n] * Real(B[n][3]) */ \
VFMADD231PS(ZMM(2), ZMM(30), ZMM(14))  /* ZMM(14) += A[16:23][n] * Real(B[n][3]) */ \
VFMADD231PS(ZMM(0), ZMM(31), ZMM(24))  /* ZMM(24) += A[0:7][n] * Imag(B[n][3]) */ \
VFMADD231PS(ZMM(1), ZMM(31), ZMM(25))  /* ZMM(25) += A[8:15][n] * Imag(B[n][3]) */ \
VFMADD231PS(ZMM(2), ZMM(31), ZMM(26))  /* ZMM(26) += A[16:23][n] * Imag(B[n][3]) */ \
/* Macro to scale the registers */
/* 'B' represents the broadcasted register, to be used for scaling */
#define SCALE(B, R1, O1, R2, O2, R3, O3) \
VMULPS(ZMM(B), ZMM(R1), ZMM(O1)) /* ZMM(O1) = ZMM(B) * ZMM(R1) */ \
VMULPS(ZMM(B), ZMM(R2), ZMM(O2)) /* ZMM(O2) = ZMM(B) * ZMM(R2) */ \
VMULPS(ZMM(B), ZMM(R3), ZMM(O3)) /* ZMM(O3) = ZMM(B) * ZMM(R3) */ \
/* Macro to shuffle even and odd indexed elements in a ZMM register */
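/* A note on the immediate: 0xB1 (binary 10 11 00 01) selects elements
   (1, 0, 3, 2) within every 128-bit lane, i.e. it swaps each adjacent
   (even, odd) pair of floats, exchanging the real and imaginary slots of
   every scomplex value. */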
#define PERMUTE(I1, I2, I3) \
/* For col major kernel: ZMM(I1) = { Ai0.Bi0, Ar0.Bi0, Ai1.Bi0, Ar1.Bi0, ... }
   For row major kernel: ZMM(I1) = { Bi0.Ai0, Br0.Ai0, Bi1.Ai0, Br1.Ai0, ... }.
Similarly done for the other registers */ \
VPERMILPS(IMM(0xB1), ZMM(I1), ZMM(I1)) \
VPERMILPS(IMM(0xB1), ZMM(I2), ZMM(I2)) \
VPERMILPS(IMM(0xB1), ZMM(I3), ZMM(I3)) \
/* Macro to reduce a real and imag accumulator pair, as per complex arithmetic */
/* Macro assumes that ZMM(0) has 1.0f broadcasted in it */
#define FMADDSUB(R1, I1, R2, I2, R3, I3) \
/* ZMM(R1) = ZMM(R1) * 1.0f -/+ ZMM(I1)(subtract in even lanes, add in odd lanes)
For col major kernel: ZMM(R1) = { Ar0.Br0 - Ai0.Bi0, Ai0.Br0 + Ar0.Bi0 }
For row major kernel: ZMM(R1) = { Br0.Ar0 - Bi0.Ai0, Bi0.Ar0 + Br0.Ai0 }
Similarly done for the other pairs */ \
VFMADDSUB132PS(ZMM(0), ZMM(I1), ZMM(R1)) \
VFMADDSUB132PS(ZMM(0), ZMM(I2), ZMM(R2)) \
VFMADDSUB132PS(ZMM(0), ZMM(I3), ZMM(R3)) \
/* Macro to handle alpha scaling when alpha is -1.0f */
/* Computes ZMM(R1) = ZMM(S1) - ZMM(R1), and similarly for the other pairs */
/* For the alpha-scaling case, ZMM(S1) = ZMM(S2) = ZMM(S3) = 0.0f */
#define ALPHA_MINUS_ONE(R1, S1, R2, S2, R3, S3) \
VSUBPS(ZMM(R1), ZMM(S1), ZMM(R1)) /* ZMM(R1) = 0.0f - ZMM(R1) */ \
VSUBPS(ZMM(R2), ZMM(S2), ZMM(R2)) /* ZMM(R2) = 0.0f - ZMM(R2) */ \
VSUBPS(ZMM(R3), ZMM(S3), ZMM(R3)) /* ZMM(R3) = 0.0f - ZMM(R3) */ \
/* Macro to handle alpha scaling in the generic case */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of alpha already broadcasted */
#define ALPHA_DEFAULT(R1, I1, R2, I2, R3, I3) \
/* Scale with real and imag components of alpha(or beta, when this macro is reused for beta scaling) */ \
/* Assume ZMM(R1) = { Ar0, Ai0, ... } */ \
/* ZMM(I1) = { Ar0.alphai, Ai0.alphai, ... }
Similarly done for other registers */ \
SCALE(2, R1, I1, R2, I2, R3, I3) \
/* ZMM(R1) = { Ar0.alphar, Ai0.alphar, ... }
Similarly done for other registers */ \
SCALE(1, R1, R1, R2, R2, R3, R3) \
\
/* Shuffle the imag accumulators for reduction */ \
/* ZMM(I1) = { Ai0.alphai, Ar0.alphai, ... }
Similarly done for other registers */ \
PERMUTE(I1, I2, I3) \
\
/* Reduce using fmaddsub instruction */ \
/* ZMM(R1) = { Ar0.alphar - Ai0.alphai, Ai0.alphar + Ar0.alphai, ... }
Similarly done for other registers */ \
FMADDSUB(R1, I1, R2, I2, R3, I3) \
/* Macro to handle beta scaling when beta is 1.0f, with primary storage */
#define BETA_ONE_PRIMARY(R1, C1, R2, C2, R3, C3) \
/* Load C onto the registers*/ \
VMOVUPS(MEM(RCX), ZMM(C1)) \
VMOVUPS(MEM(RCX, 64), ZMM(C2)) \
VMOVUPS(MEM(RCX, 128), ZMM(C3)) \
\
/* Add C to the result of A*B */ \
VADDPS(ZMM(C1), ZMM(R1), ZMM(C1)) /* ZMM(C1) = ZMM(R1) + ZMM(C1) */ \
VADDPS(ZMM(C2), ZMM(R2), ZMM(C2)) /* ZMM(C2) = ZMM(R2) + ZMM(C2) */ \
VADDPS(ZMM(C3), ZMM(R3), ZMM(C3)) /* ZMM(C3) = ZMM(R3) + ZMM(C3) */ \
\
/* Store the results onto C */ \
VMOVUPS(ZMM(C1), MEM(RCX)) \
VMOVUPS(ZMM(C2), MEM(RCX, 64)) \
VMOVUPS(ZMM(C3), MEM(RCX, 128)) \
/* Macro to handle beta scaling when beta is -1.0f, with primary storage */
#define BETA_MINUS_ONE_PRIMARY(R1, C1, R2, C2, R3, C3) \
/* Load C onto the registers*/ \
VMOVUPS(MEM(RCX), ZMM(C1)) \
VMOVUPS(MEM(RCX, 64), ZMM(C2)) \
VMOVUPS(MEM(RCX, 128), ZMM(C3)) \
\
/* Subtract C from the result of alpha*A*B(use pre-existing macro) */ \
/* ZMM(C1) = ZMM(R1) - ZMM(C1)
Similarly done for other registers */ \
ALPHA_MINUS_ONE(C1, R1, C2, R2, C3, R3) \
\
/* Store the results onto C */ \
VMOVUPS(ZMM(C1), MEM(RCX)) \
VMOVUPS(ZMM(C2), MEM(RCX, 64)) \
VMOVUPS(ZMM(C3), MEM(RCX, 128)) \
/* Macro to scale a set of 3 registers with beta, with primary storage */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of beta already broadcasted */
/* Macro uses C1...C3 to load C, and R1...R3 contains the result of alpha*A*B */
/* Macro uses ZMM(28) - ZMM(30) for beta*C computation */
#define BETA_DEFAULT_PRIMARY(R1, C1, R2, C2, R3, C3) \
/* Load C onto the registers*/ \
VMOVUPS(MEM(RCX), ZMM(C1)) \
VMOVUPS(MEM(RCX, 64), ZMM(C2)) \
VMOVUPS(MEM(RCX, 128), ZMM(C3)) \
\
/* Scale C by beta(using pre-existing macro) */ \
/* Assume ZMM(C1) = { Cr0, Ci0, ... } */ \
/* ZMM(C1) = { Cr0.betar - Ci0.betai, Ci0.betar + Cr0.betai, ... }
Similarly done for other registers */ \
ALPHA_DEFAULT(C1, 28, C2, 29, C3, 30) \
\
/* Add beta*C to the result of alpha*A*B */ \
VADDPS(ZMM(C1), ZMM(R1), ZMM(C1)) /* ZMM(C1) = ZMM(R1) + ZMM(C1) */ \
VADDPS(ZMM(C2), ZMM(R2), ZMM(C2)) /* ZMM(C2) = ZMM(R2) + ZMM(C2) */ \
VADDPS(ZMM(C3), ZMM(R3), ZMM(C3)) /* ZMM(C3) = ZMM(R3) + ZMM(C3) */ \
\
/* Store the results onto C */ \
VMOVUPS(ZMM(C1), MEM(RCX)) \
VMOVUPS(ZMM(C2), MEM(RCX, 64)) \
VMOVUPS(ZMM(C3), MEM(RCX, 128)) \
/* Macro to perform 8x8 transpose of 64-bit elements */
/* Transpose is in-place(R0-R7), T0-T7 are temporary registers */
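/* Reference semantics (an illustrative scalar sketch, not used by the build):
   treating each ZMM register as one row of eight 64-bit elements (here, eight
   scomplex values), the macro below is equivalent to

       static void transpose_8x8_ref( scomplex m[8][8] )
       {
           for ( int r = 0; r < 8; ++r )
               for ( int c = r + 1; c < 8; ++c )
               {
                   scomplex t = m[r][c];   // swap element (r, c) with (c, r)
                   m[r][c]    = m[c][r];
                   m[c][r]    = t;
               }
       }
*/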
#define TRANSPOSE_8X8(R0, R1, R2, R3, R4, R5, R6, R7, \
T0, T1, T2, T3, T4, T5, T6, T7) \
/*
Let's consider the following case:
ZMM(R0) = { 0, 1, 2, 3, 4, 5, 6, 7 }
ZMM(R1) = { 8, 9, 10, 11, 12, 13, 14, 15 }
.
.
.
ZMM(R7) = { 56, 57, 58, 59, 60, 61, 62, 63 }
Expected output:
ZMM(R0) = { 0, 8, 16, 24, 32, 40, 48, 56 }
ZMM(R1) = { 1, 9, 17, 25, 33, 41, 49, 57 }
.
.
.
ZMM(R7) = { 7, 15, 23, 31, 39, 47, 55, 63 }.
*/ \
/* Inputs : ZMM(R0) = { 0, 1, 2, 3, 4, 5, 6, 7 }
ZMM(R1) = { 8, 9, 10, 11, 12, 13, 14, 15 }
ZMM(R2) = { 16, 17, 18, 19, 20, 21, 22, 23 }
ZMM(R3) = { 24, 25, 26, 27, 28, 29, 30, 31 }
...
Outputs: ZMM(T0) = { 0, 8, 2, 10, 4, 12, 6, 14 }
ZMM(R1) = { 1, 9, 3, 11, 5, 13, 7, 15 }
ZMM(T1) = { 16, 24, 18, 26, 20, 28, 22, 30 }
ZMM(R3) = { 17, 25, 19, 27, 21, 29, 23, 31 }
... */ \
VUNPCKLPD(ZMM(R1), ZMM(R0), ZMM(T0)) \
VUNPCKHPD(ZMM(R1), ZMM(R0), ZMM(R1)) \
VUNPCKLPD(ZMM(R3), ZMM(R2), ZMM(T1)) \
VUNPCKHPD(ZMM(R3), ZMM(R2), ZMM(R3)) \
VUNPCKLPD(ZMM(R5), ZMM(R4), ZMM(T2)) \
VUNPCKHPD(ZMM(R5), ZMM(R4), ZMM(R5)) \
VUNPCKLPD(ZMM(R7), ZMM(R6), ZMM(T3)) \
VUNPCKHPD(ZMM(R7), ZMM(R6), ZMM(R7)) \
\
/* Moving the contents of temporary registers
to input registers for reuse */ \
/* Output: ZMM(R0) = { 0, 8, 2, 10, 4, 12, 6, 14 }
ZMM(R2) = { 16, 24, 18, 26, 20, 28, 22, 30 }
ZMM(R4) = { 32, 40, 34, 42, 36, 44, 38, 46 }
ZMM(R6) = { 48, 56, 50, 58, 52, 60, 54, 62 } */ \
VMOVAPD(ZMM(T0), ZMM(R0)) \
VMOVAPD(ZMM(T1), ZMM(R2)) \
VMOVAPD(ZMM(T2), ZMM(R4)) \
VMOVAPD(ZMM(T3), ZMM(R6)) \
\
/* Inputs : ZMM(R0) = { 0, 8, 2, 10, 4, 12, 6, 14 }
ZMM(R2) = { 16, 24, 18, 26, 20, 28, 22, 30 }
ZMM(R4) = { 32, 40, 34, 42, 36, 44, 38, 46 }
ZMM(R6) = { 48, 56, 50, 58, 52, 60, 54, 62 }
Outputs : ZMM(T0) = { 0, 8, 4, 12, 16, 24, 20, 28 }
ZMM(T1) = { 32, 40, 36, 44, 48, 56, 52, 60 }
ZMM(T2) = { 2, 10, 6, 14, 18, 26, 22, 30 }
ZMM(T3) = { 34, 42, 38, 46, 50, 58, 54, 62 } */ \
VSHUFF64X2(IMM(0x88), ZMM(R2), ZMM(R0), ZMM(T0)) \
VSHUFF64X2(IMM(0x88), ZMM(R6), ZMM(R4), ZMM(T1)) \
VSHUFF64X2(IMM(0xDD), ZMM(R2), ZMM(R0), ZMM(T2)) \
VSHUFF64X2(IMM(0xDD), ZMM(R6), ZMM(R4), ZMM(T3)) \
\
/* Inputs : ZMM(R1) = { 1, 9, 3, 11, 5, 13, 7, 15 }
ZMM(R3) = { 17, 25, 19, 27, 21, 29, 23, 31 }
ZMM(R5) = { 33, 41, 35, 43, 37, 45, 39, 47 }
ZMM(R7) = { 49, 57, 51, 59, 53, 61, 55, 63 }
Outputs : ZMM(T4) = { 1, 9, 5, 13, 17, 25, 21, 29 }
ZMM(T5) = { 33, 41, 37, 45, 49, 57, 53, 61 }
ZMM(T6) = { 3, 11, 7, 15, 19, 27, 23, 31 }
ZMM(T7) = { 35, 43, 39, 47, 51, 59, 55, 63 } */ \
VSHUFF64X2(IMM(0x88), ZMM(R3), ZMM(R1), ZMM(T4)) \
VSHUFF64X2(IMM(0x88), ZMM(R7), ZMM(R5), ZMM(T5)) \
VSHUFF64X2(IMM(0xDD), ZMM(R3), ZMM(R1), ZMM(T6)) \
VSHUFF64X2(IMM(0xDD), ZMM(R7), ZMM(R5), ZMM(T7)) \
\
/* Inputs : ZMM(T0) = { 0, 8, 4, 12, 16, 24, 20, 28 }
ZMM(T1) = { 32, 40, 36, 44, 48, 56, 52, 60 }
ZMM(T2) = { 2, 10, 6, 14, 18, 26, 22, 30 }
ZMM(T3) = { 34, 42, 38, 46, 50, 58, 54, 62 }
Outputs : ZMM(R0) = { 0, 8, 16, 24, 32, 40, 48, 56 }
ZMM(R2) = { 2, 10, 18, 26, 34, 42, 50, 58 }
ZMM(R4) = { 4, 12, 20, 28, 36, 44, 52, 60 }
ZMM(R6) = { 6, 14, 22, 30, 38, 46, 54, 62 } */ \
VSHUFF64X2(IMM(0x88), ZMM(T1), ZMM(T0), ZMM(R0)) \
VSHUFF64X2(IMM(0x88), ZMM(T3), ZMM(T2), ZMM(R2)) \
VSHUFF64X2(IMM(0xDD), ZMM(T1), ZMM(T0), ZMM(R4)) \
VSHUFF64X2(IMM(0xDD), ZMM(T3), ZMM(T2), ZMM(R6)) \
\
/* Inputs : ZMM(T4) = { 1, 9, 5, 13, 17, 25, 21, 29 }
ZMM(T5) = { 33, 41, 37, 45, 49, 57, 53, 61 }
ZMM(T6) = { 3, 11, 7, 15, 19, 27, 23, 31 }
ZMM(T7) = { 35, 43, 39, 47, 51, 59, 55, 63 }
Outputs : ZMM(R1) = { 1, 9, 17, 25, 33, 41, 49, 57 }
ZMM(R3) = { 3, 11, 19, 27, 35, 43, 51, 59 }
ZMM(R5) = { 5, 13, 21, 29, 37, 45, 53, 61 }
ZMM(R7) = { 7, 15, 23, 31, 39, 47, 55, 63 } */ \
VSHUFF64X2(IMM(0x88), ZMM(T5), ZMM(T4), ZMM(R1)) \
VSHUFF64X2(IMM(0x88), ZMM(T7), ZMM(T6), ZMM(R3)) \
VSHUFF64X2(IMM(0xDD), ZMM(T5), ZMM(T4), ZMM(R5)) \
VSHUFF64X2(IMM(0xDD), ZMM(T7), ZMM(T6), ZMM(R7)) \
/* Macro to scale a register with beta, with secondary storage */
/* Macro uses C1 to load C, and R1 contains the result of alpha*A*B */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of beta already broadcasted */
#define BETA_DEFAULT_SECONDARY(R1, C1, T1) \
/* Load C onto the register*/ \
VMOVUPS(MEM(RCX), YMM(C1)) \
\
/* Scale C by beta */ \
/* Assume YMM(C1) = { Cr0, Ci0, ... } */ \
/* YMM(C1) = { Cr0.betar - Ci0.betai, Ci0.betar + Cr0.betai, ... } */ \
VMULPS(YMM(2), YMM(C1), YMM(T1)) \
VMULPS(YMM(1), YMM(C1), YMM(C1)) \
VPERMILPS(IMM(0xB1), YMM(T1), YMM(T1)) \
VFMADDSUB132PS(YMM(0), YMM(T1), YMM(C1)) \
\
/* Add beta*C to the result of alpha*A*B */ \
VADDPS(YMM(C1), YMM(R1), YMM(C1)) /* YMM(C1) = YMM(R1) + YMM(C1) */ \
\
/* Store the result onto C */ \
VMOVUPS(YMM(C1), MEM(RCX)) \
/* Macro to scale a register with beta, with general strides */
/* Macro gets alpha*A*B in R1, R2, R3 for a column */
/* Macro uses ZMM(15) - ZMM(17) to gather C */
/* Macro uses ZMM(18) - ZMM(20) as temporary registers */
/* Macro assumes that ZMM(28) - ZMM(30) have the addresses of C for
gather/scatter(one column at a time) */
/* Macro assumes that ZMM(1) and ZMM(2) have real and imag components
of beta already broadcasted */
#define BETA_DEFAULT_GENERAL(R1, R2, R3) \
/* Set the masks to 1 */ \
KXNORW(K(0), K(0), K(1)) \
KXNORW(K(0), K(0), K(2)) \
KXNORW(K(0), K(0), K(3)) \
\
/* Gather elements from C, one column at a time */ \
VGATHERQPD(MEM(RCX, ZMM(28), 1), ZMM(15) MASK_K(1)) \
VGATHERQPD(MEM(RCX, ZMM(29), 1), ZMM(16) MASK_K(2)) \
VGATHERQPD(MEM(RCX, ZMM(30), 1), ZMM(17) MASK_K(3)) \
\
/* Scale C by beta */ \
/* Assume ZMM(15) = { Cr0, Ci0, ... } */ \
/* ZMM(15) = { Cr0.betar - Ci0.betai, Ci0.betar + Cr0.betai, ... } */ \
VMULPS(ZMM(2), ZMM(15), ZMM(18)) \
VMULPS(ZMM(1), ZMM(15), ZMM(15)) \
VPERMILPS(IMM(0xB1), ZMM(18), ZMM(18)) \
VFMADDSUB132PS(ZMM(0), ZMM(18), ZMM(15)) \
\
/* Scale C by beta */ \
/* Assume ZMM(16) = { Cr1, Ci1, ... } */ \
/* ZMM(16) = { Cr1.betar - Ci1.betai, Ci1.betar + Cr1.betai, ... } */ \
VMULPS(ZMM(2), ZMM(16), ZMM(19)) \
VMULPS(ZMM(1), ZMM(16), ZMM(16)) \
VPERMILPS(IMM(0xB1), ZMM(19), ZMM(19)) \
VFMADDSUB132PS(ZMM(0), ZMM(19), ZMM(16)) \
\
/* Scale C by beta */ \
/* Assume ZMM(17) = { Cr2, Ci2, ... } */ \
/* ZMM(17) = { Cr2.betar - Ci2.betai, Ci2.betar + Cr2.betai, ... } */ \
VMULPS(ZMM(2), ZMM(17), ZMM(20)) \
VMULPS(ZMM(1), ZMM(17), ZMM(17)) \
VPERMILPS(IMM(0xB1), ZMM(20), ZMM(20)) \
VFMADDSUB132PS(ZMM(0), ZMM(20), ZMM(17)) \
\
/* Add beta*C to the result of alpha*A*B */ \
VADDPS(ZMM(15), ZMM(R1), ZMM(15)) /* ZMM(15) = ZMM(R1) + ZMM(15) */ \
VADDPS(ZMM(16), ZMM(R2), ZMM(16)) /* ZMM(16) = ZMM(R2) + ZMM(16) */ \
VADDPS(ZMM(17), ZMM(R3), ZMM(17)) /* ZMM(17) = ZMM(R3) + ZMM(17) */ \
\
/* Reset the masks to 1 */ \
KXNORW(K(0), K(0), K(1)) \
KXNORW(K(0), K(0), K(2)) \
KXNORW(K(0), K(0), K(3)) \
\
/* Scatter the result to C, one column at a time */ \
VSCATTERQPD(ZMM(15), MEM(RCX, ZMM(28), 1) MASK_K(1)) \
VSCATTERQPD(ZMM(16), MEM(RCX, ZMM(29), 1) MASK_K(2)) \
VSCATTERQPD(ZMM(17), MEM(RCX, ZMM(30), 1) MASK_K(3)) \
/* Macro to store alpha*A*B onto C, with general strides */
/* Macro gets alpha*A*B in R1, R2, R3 for a column */
/* Macro assumes that ZMM(28) - ZMM(30) have the addresses of C for
scatter(one column at a time) */
#define BETA_ZERO_GENERAL(R1, R2, R3) \
/* Set the masks to 1 */ \
KXNORW(K(0), K(0), K(1)) \
KXNORW(K(0), K(0), K(2)) \
KXNORW(K(0), K(0), K(3)) \
\
/* Scatter the result to C, one column at a time */ \
VSCATTERQPD(ZMM(R1), MEM(RCX, ZMM(28), 1) MASK_K(1)) \
VSCATTERQPD(ZMM(R2), MEM(RCX, ZMM(29), 1) MASK_K(2)) \
VSCATTERQPD(ZMM(R3), MEM(RCX, ZMM(30), 1) MASK_K(3))


@@ -168,6 +168,8 @@ PACKM_KER_PROT( double, d, packm_zen4_asm_8xk )
PACKM_KER_PROT( double, d, packm_zen4_asm_24xk )
PACKM_KER_PROT( double, d, packm_zen4_asm_32xk )
PACKM_KER_PROT( double, d, packm_32xk_zen4_ref )
PACKM_KER_PROT( scomplex, c, packm_zen4_asm_24xk )
PACKM_KER_PROT( scomplex, c, packm_zen4_asm_4xk )
PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_12xk )
PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_4xk )
@@ -176,6 +178,8 @@ GEMM_UKR_PROT( double, d, gemm_avx512_asm_8x24 )
GEMM_UKR_PROT( double, d, gemm_zen4_asm_32x6 )
GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_12x4 )
GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_4x12 )
GEMM_UKR_PROT( scomplex, c, gemm_zen4_asm_24x4 )
GEMM_UKR_PROT( scomplex, c, gemm_zen4_asm_4x24 )
// dgemm native macro kernel
void bli_dgemm_avx512_asm_8x24_macro_kernel