Enabled DGEMM row-major kernel for ZEN4

- Merged the ZEN4 and ZEN5 DGEMM 8x24 kernels.
- Replaced the 32x6 kernel with the 8x24 kernel, so the
  same kernel is now used for both ZEN4 and ZEN5.
- Block sizes have been tuned for Genoa only.
- Replaced the DGEMM kernel in the DTRSM native code path
  with the 8x24 kernel.
- Enabled alpha scaling during packing for ZEN4.
- Removed the ZEN4-specific 8x24 kernel.

AMD-Internal: [CPUPL-5912]
Change-Id: I89a16a7e3355af037d21d453aabf53c5ecccb754
Author: Shubham Sharma
Date:   2024-11-29 08:18:48 +00:00
Parent: 082081658f
Commit: f2320a1fef
10 changed files with 37 additions and 754 deletions


@@ -41,22 +41,22 @@
 #define BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs) \
 /* s d c z */ \
-bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \
-bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \
-bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 128, 144, 60 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 8, 3, 12 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 24, 8, 4 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 120, 144, 60 ); \
 bli_blksz_init_easy( &blkszs[ BLIS_KC ], 480, 512, 256, 512 ); \
-bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4002, 4080, 2004 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 2016, 4080, 2004 ); \
 \
 bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \
 bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 );
 #define BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs) \
 /* s d c z */ \
-bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \
-bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \
-bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 64, 144, 60 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 8, 3, 12 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 24, 8, 4 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 120, 144, 60 ); \
 bli_blksz_init_easy( &blkszs[ BLIS_KC ], 480, 512, 256, 512 ); \
-bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 3600, 4080, 2004 ); \
+bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 2016, 4080, 2004 ); \
 \
 bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \
 bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 );
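
For orientation, the tuned values map onto the usual blocked GEMM loop nest: NC, KC and MC carve C, B and A into cache-resident panels, and MR x NR (now 8 x 24 for double) is the register tile computed by the micro-kernel. Below is a minimal schematic in C of that loop structure, not the BLIS implementation; it assumes row-major operands and dimensions that divide evenly by the block sizes.

    /* Schematic five-loop blocked DGEMM; the real code packs A/B panels and
       handles edge cases.  MC/NC/KC/MR/NR mirror the double-precision values
       set above. */
    void blocked_dgemm_sketch( int m, int n, int k,
                               const double *a, const double *b, double *c )
    {
        enum { MC = 120, NC = 2016, KC = 512, MR = 8, NR = 24 };
        for ( int jc = 0; jc < n; jc += NC )        /* NC: columns of C/B   */
        for ( int pc = 0; pc < k; pc += KC )        /* KC: depth of A/B     */
        for ( int ic = 0; ic < m; ic += MC )        /* MC: rows of C/A      */
        for ( int jr = jc; jr < jc + NC; jr += NR ) /* NR: micro-tile cols  */
        for ( int ir = ic; ir < ic + MC; ir += MR ) /* MR: micro-tile rows  */
            for ( int p = pc; p < pc + KC; p++ )    /* micro-kernel k-loop  */
            for ( int i = 0; i < MR; i++ )
            for ( int j = 0; j < NR; j++ )
                c[ (ir+i)*n + (jr+j) ] += a[ (ir+i)*k + p ] * b[ p*n + (jr+j) ];
    }
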
@@ -78,14 +78,14 @@ void bli_cntx_init_zen4( cntx_t* cntx )
 13,
 // gemm
 BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_skx_asm_32x12_l2, FALSE,
-BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_32x6, FALSE,
+BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_avx512_asm_8x24, TRUE,
 BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE,
 /*bli_zgemm_zen4_asm_12x4 is a column preferred kernel*/
 BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_12x4, FALSE,
 // Different GEMM kernels are used for TRSM for zen4 architecture
 BLIS_GEMM_FOR_TRSM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE,
-BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_8x24, TRUE,
+BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_avx512_asm_8x24, TRUE,
 BLIS_GEMM_FOR_TRSM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_4x12, TRUE,
 // gemmtrsm_l
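
The trailing boolean in each registration records the micro-kernel's output storage preference (TRUE = prefers row-stored C), which is why swapping the 32x6 kernel (FALSE) for the 8x24 kernel (TRUE) is what enables the row-major path named in the commit title. For reference only, a scalar sketch of what an MR x NR = 8 x 24 DGEMM micro-kernel computes for one packed A/B micro-panel pair; this is not the assembly kernel, and the simplified signature only loosely mirrors the BLIS gemm ukr interface shown in the deleted file below.

    /* Reference-only scalar 8x24 DGEMM micro-kernel:
       C(8x24) := alpha * A(8xk) * B(kx24) + beta * C(8x24), with A packed in
       8-element columns, B packed in 24-element rows, and C addressed through
       the strides rs_c/cs_c. */
    static void dgemm_8x24_ref( long k, const double *alpha,
                                const double *a, const double *b,
                                const double *beta, double *c,
                                long rs_c, long cs_c )
    {
        enum { MR = 8, NR = 24 };
        double ab[ MR * NR ] = { 0.0 };
        for ( long p = 0; p < k; p++ )                /* rank-1 update per k */
            for ( int i = 0; i < MR; i++ )
                for ( int j = 0; j < NR; j++ )
                    ab[ i*NR + j ] += a[ p*MR + i ] * b[ p*NR + j ];
        for ( int i = 0; i < MR; i++ )
            for ( int j = 0; j < NR; j++ )
                c[ i*rs_c + j*cs_c ] = (*beta) * c[ i*rs_c + j*cs_c ]
                                     + (*alpha) * ab[ i*NR + j ];
    }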


@@ -87,7 +87,7 @@ void bli_cntx_init_zen5( cntx_t* cntx )
 // Different GEMM kernels are used for TRSM for zen4 architecture
 BLIS_GEMM_FOR_TRSM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE,
-BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_8x24, TRUE,
+BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_avx512_asm_8x24, TRUE,
 BLIS_GEMM_FOR_TRSM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_4x12, TRUE,
 // gemmtrsm_l


@@ -183,12 +183,13 @@ void bli_packm_blk_var1
 buf_kappa = bli_obj_buffer_for_1x1( dt_p, kappa_p );
 }
-#ifdef BLIS_KERNELS_ZEN5
-// For DGEMM in ZEN5, scale by alpha during packing
+#ifdef BLIS_KERNELS_ZEN4
+// For DGEMM in AVX512, scale by alpha during packing
 if
 (
 ( bli_obj_dt( p ) == BLIS_DOUBLE ) &&
-( bli_arch_query_id() == BLIS_ARCH_ZEN5 )
+( ( bli_arch_query_id() == BLIS_ARCH_ZEN5 ) ||
+( bli_arch_query_id() == BLIS_ARCH_ZEN4 ) )
 )
 {
 bli_obj_scalar_detach( p, &kappa );
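
The effect of this branch is to fold alpha into kappa, the scalar applied while packing, so the packed panel already carries alpha and the downstream DGEMM kernel can run with a unit alpha. A minimal sketch of the idea follows; pack_panel_scaled is a hypothetical helper for illustration, not the BLIS packing routine.

    /* Sketch: scale each element by kappa while copying A into the packed
       buffer.  When kappa == alpha, the macro/micro kernel no longer needs
       to apply alpha itself. */
    static void pack_panel_scaled( int mr, long k, double kappa,
                                   const double *a, long rs_a, long cs_a,
                                   double *a_packed )
    {
        for ( long p = 0; p < k; p++ )          /* one column of the panel  */
            for ( int i = 0; i < mr; i++ )      /* mr rows per micro-panel  */
                a_packed[ p*mr + i ] = kappa * a[ i*rs_a + p*cs_a ];
    }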


@@ -171,15 +171,16 @@ void bli_gemm_ker_var2
 // function pointer.
 f = ftypes[dt_exec];
-#ifdef BLIS_KERNELS_ZEN5
+#ifdef BLIS_KERNELS_ZEN4
 // Optimized macro kernel is available for DGEMM
-// for ZEN5. Only row major stored C is supported.
+// for AVX512. Only row major stored C is supported.
 // TODO: Add macro kernel function pointer in cntx
 if
 (
 ( bli_obj_dt( c ) == BLIS_DOUBLE ) &&
-( bli_arch_query_id() == BLIS_ARCH_ZEN5 ) &&
+( ( bli_arch_query_id() == BLIS_ARCH_ZEN5 ) ||
+( bli_arch_query_id() == BLIS_ARCH_ZEN4 ) ) &&
 ( cs_c == 1 ) && // use this kernel only for row major C
 // use generic macro kernel for mixed precision
 ( bli_obj_elem_size( a ) == 8 ) && // check if elem_sizeof(a) == sizeof(double)
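
When this condition holds, the generic macro kernel is bypassed in favor of bli_dgemm_avx512_asm_8x24_macro_kernel, whose prototype appears later in this diff. Conceptually, a macro kernel walks the current block of row-major C in 8 x 24 tiles and invokes the micro-kernel on each; the sketch below shows only that role, with a hypothetical ukr function pointer and no edge-case or threading logic, and its parameter order does not claim to match the actual prototype.

    /* Sketch of a macro kernel over an m x n block of row-major C
       (ldc = row stride).  a_packed/b_packed hold MR-tall and NR-wide
       micro-panels; full tiles are assumed. */
    typedef void (*dgemm_ukr_t)( long k, const double *alpha,
                                 const double *a, const double *b,
                                 const double *beta, double *c,
                                 long rs_c, long cs_c );

    static void macro_kernel_sketch( long m, long n, long k, double *c,
                                     const double *a_packed,
                                     const double *b_packed,
                                     long ldc, const double *beta,
                                     const double *alpha, dgemm_ukr_t ukr )
    {
        enum { MR = 8, NR = 24 };
        for ( long j = 0; j < n; j += NR )          /* NR-wide panels of B   */
            for ( long i = 0; i < m; i += MR )      /* MR-tall panels of A   */
                ukr( k, alpha,
                     a_packed + i*k,                /* i-th micro-panel of A */
                     b_packed + j*k,                /* j-th micro-panel of B */
                     beta,
                     c + i*ldc + j, ldc, 1 );       /* row-major C tile      */
    }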


@@ -143,7 +143,7 @@
 #define K_bli_ddotv_zen_int_avx512 1
 #define K_bli_dgemm_haswell_asm_6x8 1
 #define K_bli_dgemm_zen4_asm_32x6 1
-#define K_bli_dgemm_zen4_asm_8x24 1
+#define K_bli_dgemm_avx512_asm_8x24 1
 #define K_bli_dgemmsup_rd_haswell_asm_6x8m 1
 #define K_bli_dgemmsup_rd_haswell_asm_6x8n 1
 #define K_bli_dgemmsup_rv_haswell_asm_6x8m 1


@@ -466,9 +466,9 @@ INSTANTIATE_TEST_SUITE_P(
 );
 #endif
-#ifdef K_bli_dgemm_zen4_asm_8x24
+#ifdef K_bli_dgemm_avx512_asm_8x24
 INSTANTIATE_TEST_SUITE_P(
-bli_dgemm_zen4_asm_8x24,
+bli_dgemm_avx512_asm_8x24,
 dgemmGenericNat,
 ::testing::Combine(
 ::testing::Range(gtint_t(0), gtint_t(17), 1), // values of k
@@ -477,7 +477,7 @@ INSTANTIATE_TEST_SUITE_P(
 ::testing::Values('r', 'c'), // storage
 ::testing::Values(8), // values of m
 ::testing::Values(24), // values of n
-::testing::Values(bli_dgemm_zen4_asm_8x24),
+::testing::Values(bli_dgemm_avx512_asm_8x24),
 ::testing::Values(true, false) // memory test
 ),
 ::dgemmGenericNatPrint()


@@ -1,717 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#include "bli_x86_asm_macros.h"
// BLIS_ASM_SYNTAX_INTEL syntax is followed
#define TAIL_NITER 5
#define LOOP_ALIGN ALIGN32
// Update C when C is general stored
#define UPDATE_C_SCATTERED(R1,R2,R3) \
\
KXNORW(K(1), K(0), K(0)) \
KXNORW(K(2), K(0), K(0)) \
KXNORW(K(3), K(0), K(0)) \
VGATHERQPD(ZMM(0) MASK_K(1), MEM(RCX,ZMM(2),1)) \
/*scale by beta*/ \
VFMADD231PD(ZMM(R1), ZMM(0), ZMM(1)) \
VGATHERQPD(ZMM(0) MASK_K(2), MEM(RCX,ZMM(3),1)) \
VFMADD231PD(ZMM(R2), ZMM(0), ZMM(1)) \
VGATHERQPD(ZMM(0) MASK_K(3), MEM(RCX,ZMM(4),1)) \
VFMADD231PD(ZMM(R3), ZMM(0), ZMM(1)) \
KXNORW(K(1), K(0), K(0)) \
KXNORW(K(2), K(0), K(0)) \
KXNORW(K(3), K(0), K(0)) \
/*store c*/ \
VSCATTERQPD(MEM(RCX,ZMM(2),1) MASK_K(1), ZMM(R1)) \
VSCATTERQPD(MEM(RCX,ZMM(3),1) MASK_K(2), ZMM(R2)) \
VSCATTERQPD(MEM(RCX,ZMM(4),1) MASK_K(3), ZMM(R3)) \
LEA(RCX, MEM(RCX,R12,1))
// Update C when C is general stored and beta = 0
#define UPDATE_C_SCATTERED_BZ(R1,R2,R3) \
\
KXNORW(K(1), K(0), K(0)) \
KXNORW(K(2), K(0), K(0)) \
KXNORW(K(3), K(0), K(0)) \
VSCATTERQPD(MEM(RCX,ZMM(2),1) MASK_K(1), ZMM(R1)) \
VSCATTERQPD(MEM(RCX,ZMM(3),1) MASK_K(2), ZMM(R2)) \
VSCATTERQPD(MEM(RCX,ZMM(4),1) MASK_K(3), ZMM(R3)) \
LEA(RCX, MEM(RCX,R12,1))
// 8x8 in register transpose, used for column stored C
#define TRANSPOSE_8X8(R0, R1, R2, R3, R4, R5, R6, R7) \
\
VUNPCKLPD(ZMM(6), ZMM(R0), ZMM(R1)) \
VUNPCKLPD(ZMM(7), ZMM(R2), ZMM(R3)) \
VUNPCKLPD(ZMM(2), ZMM(R4), ZMM(R5)) \
VUNPCKLPD(ZMM(3), ZMM(R6), ZMM(R7)) \
VMOVUPD(ZMM(0), ZMM(R0)) \
VMOVUPD(ZMM(1), ZMM(R4)) \
/*Stage2*/ \
VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0x88)) \
VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0x88)) \
/*Stage3 1,5*/ \
VSHUFF64X2(ZMM(R0), ZMM(4), ZMM(5), IMM(0x88)) \
VSHUFF64X2(ZMM(R4), ZMM(4), ZMM(5), IMM(0xDD)) \
/*Stage2*/ \
VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0xDD)) \
VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0xDD)) \
/*Stage3 3,7*/ \
VUNPCKHPD(ZMM(6), ZMM(0 ), ZMM(R1)) \
VUNPCKHPD(ZMM(7), ZMM(R2), ZMM(R3)) \
VUNPCKHPD(ZMM(2), ZMM(1 ), ZMM(R5)) \
VUNPCKHPD(ZMM(3), ZMM(R6), ZMM(R7)) \
VSHUFF64X2(ZMM(R2), ZMM(4), ZMM(5), IMM(0x88)) \
VSHUFF64X2(ZMM(R6), ZMM(4), ZMM(5), IMM(0xDD)) \
\
/*Stage2*/ \
VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0x88)) \
VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0x88)) \
/*Stage3 2,6*/ \
VSHUFF64X2(ZMM(R1), ZMM(4), ZMM(5), IMM(0x88)) \
VSHUFF64X2(ZMM(R5), ZMM(4), ZMM(5), IMM(0xDD)) \
/*Stage2*/ \
VSHUFF64X2(ZMM(4), ZMM(6), ZMM(7), IMM(0xDD)) \
VSHUFF64X2(ZMM(5), ZMM(2), ZMM(3), IMM(0xDD)) \
/*Stage3 4,8*/ \
VSHUFF64X2(ZMM(R3), ZMM(4), ZMM(5), IMM(0x88)) \
VSHUFF64X2(ZMM(R7), ZMM(4), ZMM(5), IMM(0xDD)) \
// Update C when C is column stored
#define UPDATE_C_COL_STORE(R0, R1, R2, R3, R4, R5, R6, R7) \
\
/* scale by alpha */\
VMULPD(ZMM(R0), ZMM(R0), ZMM(0)) \
VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
VMULPD(ZMM(R5), ZMM(R5), ZMM(0)) \
VMULPD(ZMM(R6), ZMM(R6), ZMM(0)) \
VMULPD(ZMM(R7), ZMM(R7), ZMM(0)) \
/*scale by beta*/\
VFMADD231PD(ZMM(R0), ZMM(1), MEM(RCX)) \
/*store c*/ \
VMOVUPD(MEM(RCX), ZMM(R0)) \
VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX, R10, 1)) \
VMOVUPD(MEM(RCX, R10, 1), ZMM(R1)) \
VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX, R10, 2)) \
VMOVUPD(MEM(RCX, R10, 2), ZMM(R2)) \
VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX, R11, 1)) \
VMOVUPD(MEM(RCX, R11, 1), ZMM(R3)) \
VFMADD231PD(ZMM(R4), ZMM(1), MEM(RCX, R10, 4)) \
VMOVUPD(MEM(RCX, R10, 4), ZMM(R4)) \
VFMADD231PD(ZMM(R5), ZMM(1), MEM(RCX, R12, 1)) \
VMOVUPD(MEM(RCX, R12, 1), ZMM(R5)) \
VFMADD231PD(ZMM(R6), ZMM(1), MEM(RCX, R11, 2)) \
VMOVUPD(MEM(RCX, R11, 2), ZMM(R6)) \
VFMADD231PD(ZMM(R7), ZMM(1), MEM(RCX, R13, 1)) \
VMOVUPD(MEM(RCX, R13, 1), ZMM(R7)) \
LEA(RCX, MEM(RCX,R10,8))
// Update C when C is column stored and beta = 0
#define UPDATE_C_COL_STORE_BZ(R0, R1, R2, R3, R4, R5, R6, R7) \
/* scale by alpha */\
VMULPD(ZMM(R0), ZMM(R0), ZMM(0)) \
VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
VMULPD(ZMM(R5), ZMM(R5), ZMM(0)) \
VMULPD(ZMM(R6), ZMM(R6), ZMM(0)) \
VMULPD(ZMM(R7), ZMM(R7), ZMM(0)) \
/*store c*/ \
VMOVUPD(MEM(RCX), ZMM(R0)) \
VMOVUPD(MEM(RCX, R10, 1), ZMM(R1)) /*R10 = cs_c*/ \
VMOVUPD(MEM(RCX, R10, 2), ZMM(R2)) \
VMOVUPD(MEM(RCX, R11, 1), ZMM(R3)) /*R11 = 3*cs_c*/\
VMOVUPD(MEM(RCX, R10, 4), ZMM(R4)) \
VMOVUPD(MEM(RCX, R12, 1), ZMM(R5)) /*R12 = 5*cs_c*/\
VMOVUPD(MEM(RCX, R11, 2), ZMM(R6)) \
VMOVUPD(MEM(RCX, R13, 1), ZMM(R7)) /*R13 = 7*cs_c*/\
LEA(RCX, MEM(RCX,R10,8))
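// SUBITER(n): one rank-1 update (one k iteration) of the 8x24 tile.
// ZMM0-ZMM2 hold the current 24 elements of B; eight elements of the
// current A column are broadcast into ZMM6/ZMM7 and multiply-accumulated
// into the 24 accumulators ZMM8-ZMM31, after which the next 24 elements
// of B are loaded for the following iteration.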
#define SUBITER(n) \
\
VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 0)*8)) \
VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 1)*8)) \
VFMADD231PD(ZMM( 8), ZMM(0), ZMM(6)) \
VFMADD231PD(ZMM( 9), ZMM(1), ZMM(6)) \
VFMADD231PD(ZMM(10), ZMM(2), ZMM(6)) \
\
VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 2)*8)) \
VFMADD231PD(ZMM(11), ZMM(0), ZMM(7)) \
VFMADD231PD(ZMM(12), ZMM(1), ZMM(7)) \
VFMADD231PD(ZMM(13), ZMM(2), ZMM(7)) \
\
VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 3)*8)) \
VFMADD231PD(ZMM(14), ZMM(0), ZMM(6)) \
VFMADD231PD(ZMM(15), ZMM(1), ZMM(6)) \
VFMADD231PD(ZMM(16), ZMM(2), ZMM(6)) \
\
VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 4)*8)) \
VFMADD231PD(ZMM(17), ZMM(0), ZMM(7)) \
VFMADD231PD(ZMM(18), ZMM(1), ZMM(7)) \
VFMADD231PD(ZMM(19), ZMM(2), ZMM(7)) \
\
VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 5)*8)) \
VFMADD231PD(ZMM(20), ZMM(0), ZMM(6)) \
VFMADD231PD(ZMM(21), ZMM(1), ZMM(6)) \
VFMADD231PD(ZMM(22), ZMM(2), ZMM(6)) \
\
VBROADCASTSD(ZMM(6), MEM(RAX,(8*n+ 6)*8)) \
VFMADD231PD(ZMM(23), ZMM(0), ZMM(7)) \
VFMADD231PD(ZMM(24), ZMM(1), ZMM(7)) \
VFMADD231PD(ZMM(25), ZMM(2), ZMM(7)) \
\
VBROADCASTSD(ZMM(7), MEM(RAX,(8*n+ 7)*8)) \
VFMADD231PD(ZMM(26), ZMM(0), ZMM(6)) \
VFMADD231PD(ZMM(27), ZMM(1), ZMM(6)) \
VFMADD231PD(ZMM(28), ZMM(2), ZMM(6)) \
\
VFMADD231PD(ZMM(29), ZMM(0), ZMM(7)) \
VFMADD231PD(ZMM(30), ZMM(1), ZMM(7)) \
VFMADD231PD(ZMM(31), ZMM(2), ZMM(7)) \
\
VMOVAPD(ZMM(0), MEM(RBX,(24*n+0)*8)) \
VMOVAPD(ZMM(1), MEM(RBX,(24*n+8)*8)) \
VMOVAPD(ZMM(2), MEM(RBX,(24*n+16)*8)) \
//This is an array used for the scatter/gather instructions.
static int64_t offsets[24] __attribute__((aligned(64))) =
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18,19,20,21,22,23};
/*
* number of accumulation registers = 24/8 * 8 = 24 zmm8 to zmm31
* number of registers used for load B = 24/8 = 3 zmm0 to zmm2
* number of registers used for broadcast A = 2 zmm6 and zmm7
*/
void bli_dgemm_zen4_asm_8x24(
dim_t k_,
double* restrict alpha,
double* restrict a,
double* restrict b,
double* restrict beta,
double* restrict c, inc_t rs_c_, inc_t cs_c_,
auxinfo_t* data,
cntx_t* restrict cntx
)
{
(void)data;
(void)cntx;
const int64_t* offsetPtr = &offsets[0];
const int64_t k = k_;
const int64_t rs_c = rs_c_*8; //convert strides to bytes
const int64_t cs_c = cs_c_*8; //convert strides to bytes
BEGIN_ASM()
VXORPD(YMM(8) , YMM(8), YMM(8))
VXORPD(YMM(9) , YMM(9), YMM(9))
VXORPD(YMM(10), YMM(10), YMM(10)) //clear out registers
VXORPD(YMM(11), YMM(11), YMM(11)) //clear out registers
VMOVAPD(YMM(12), YMM(8))
VMOVAPD(YMM(13), YMM(8))
VMOVAPD(YMM(14), YMM(8))
VMOVAPD(YMM(15), YMM(8))
VXORPD(YMM(16), YMM(16), YMM(16)) //clear out registers
VXORPD(YMM(17), YMM(17), YMM(17)) //clear out registers
VMOVAPD(YMM(18), YMM(8))
VMOVAPD(YMM(19), YMM(8))
VMOVAPD(YMM(20), YMM(8))
VMOVAPD(YMM(21), YMM(8))
VXORPD(YMM(22), YMM(22), YMM(22)) //clear out registers
VXORPD(YMM(23), YMM(23), YMM(23)) //clear out registers
VMOVAPD(YMM(24), YMM(8))
VMOVAPD(YMM(25), YMM(8))
VMOVAPD(YMM(26), YMM(8))
VMOVAPD(YMM(27), YMM(8))
VXORPD(YMM(28), YMM(28), YMM(28)) //clear out registers
VXORPD(YMM(29), YMM(29), YMM(29)) //clear out registers
VMOVAPD(YMM(30), YMM(8))
VMOVAPD(YMM(31), YMM(8))
MOV(RSI, VAR(k)) //loop index
MOV(RAX, VAR(a)) //load address of a
MOV(RBX, VAR(b)) //load address of b
MOV(RCX, VAR(c)) //load address of c
LEA(R9, MEM(RCX,63)) // c for prefetching
VMOVAPD(ZMM(0), MEM(RBX, 0*8)) //pre-load b
VMOVAPD(ZMM(1), MEM(RBX, 8*8)) //pre-load b
VMOVAPD(ZMM(2), MEM(RBX,16*8)) //pre-load b
LEA(RBX, MEM(RBX,24*8)) //adjust b for pre-load
MOV(R12, VAR(rs_c))
MOV(R10, VAR(cs_c))
MOV(R11, IMM(8)) // prefetch loop count
// r11 = NR for row store
// r11 = MR for col store
MOV(R8, R12) // prefetch loop increment
// r8 = cs_c for row store
// r8 = rs_c for col store
MOV(R13, IMM(64)) // r13 = 0 for row store
// r13 = 64 for col store
CMP(R10, IMM(8)) // jmp if c row stored
JZ(POST_STRIDE)
MOV(R8 , R10) // r8 = cs_c - prefetch loop increment
MOV(R11, IMM(24)) // r11 = 24 - prefetch loop count
MOV(R13, IMM(0)) // r13 = 0
LABEL(POST_STRIDE)
MOV(RDI, RSI) // RDI = k
AND(RSI, IMM(3)) // RSI = k & 3, RSI = k % 4
SAR(RDI, IMM(2)) // RDI = k >> 2, RDI = k / 4
SUB(RDI, R11) // subtract prefetch loop count
SUB(RDI, IMM(0+TAIL_NITER)) // '0+' needed for preprocessor
JLE(K_LE_80)
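// k-loop structure: LOOP1 runs the bulk of the k/4 unrolled iterations,
// LOOP2 runs the next R11 iterations while prefetching C, LOOP3 runs the
// final TAIL_NITER unrolled iterations, and TAIL_LOOP handles the k%4
// leftover iterations one SUBITER at a time.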
LOOP_ALIGN
LABEL(LOOP1)
SUBITER(0)
SUBITER(1)
SUB(RDI, IMM(1))
SUBITER(2)
SUBITER(3)
LEA(RAX, MEM(RAX,4*8*8))
LEA(RBX, MEM(RBX,4*24*8))
JNZ(LOOP1)
LABEL(K_LE_80)
ADD(RDI, R11) // add prefetch loop count
JLE(K_LE_26)
LOOP_ALIGN
LABEL(LOOP2)
PREFETCH(0, MEM(R9))
SUBITER(0)
PREFETCH(0, MEM(R9,R13, 1)) // prefetch R9+64 if col store,
// prefetch R9+0 if row store
SUBITER(1)
SUB(RDI, IMM(1))
PREFETCH(0, MEM(R9,R13, 2)) // prefetch R9+128 if col store,
// prefetch R9+0 if row store
SUBITER(2)
SUBITER(3)
LEA(RAX, MEM(RAX,4*8*8))
LEA(RBX, MEM(RBX,4*24*8))
LEA(R9, MEM(R9,R8,1)) // r9 += rs_c if col store,
// r9 += cs_c if row store
JNZ(LOOP2)
LABEL(K_LE_26)
ADD(RDI, IMM(0+TAIL_NITER))
JLE(TAIL)
LOOP_ALIGN
LABEL(LOOP3)
SUBITER(0)
SUBITER(1)
SUB(RDI, IMM(1))
SUBITER(2)
SUBITER(3)
LEA(RAX, MEM(RAX,4*8*8))
LEA(RBX, MEM(RBX,4*24*8))
JNZ(LOOP3)
LABEL(TAIL)
TEST(RSI, RSI)
JZ(POSTACCUM)
LOOP_ALIGN
LABEL(TAIL_LOOP)
SUB(RSI, IMM(1))
SUBITER(0)
LEA(RAX, MEM(RAX,8*8))
LEA(RBX, MEM(RBX,24*8))
JNZ(TAIL_LOOP)
LABEL(POSTACCUM)
MOV(RAX, VAR(alpha))
MOV(RBX, VAR(beta))
VBROADCASTSD(ZMM(0), MEM(RAX)) // ZMM(0) = alpha
VBROADCASTSD(ZMM(1), MEM(RBX)) // zmm(1) = beta
VXORPD(YMM(2), YMM(2), YMM(2))
MOV(RAX, R12) // rs_c
MOV(RBX, R10) // cs_c
// Check if C is column stride.
CMP(RAX, IMM(8))
JE(COLUPDATE)
CMP(RBX, IMM(8))
JE(ROWUPDATE)
LABEL(SCATTERUPDATE)
// if C is general stride
VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) //scale by alpha
VMULPD(ZMM( 9), ZMM( 9), ZMM(0))
VMULPD(ZMM(10), ZMM(10), ZMM(0))
VMULPD(ZMM(11), ZMM(11), ZMM(0))
VMULPD(ZMM(12), ZMM(12), ZMM(0))
VMULPD(ZMM(13), ZMM(13), ZMM(0))
VMULPD(ZMM(14), ZMM(14), ZMM(0))
VMULPD(ZMM(15), ZMM(15), ZMM(0))
VMULPD(ZMM(16), ZMM(16), ZMM(0))
VMULPD(ZMM(17), ZMM(17), ZMM(0))
VMULPD(ZMM(18), ZMM(18), ZMM(0))
VMULPD(ZMM(19), ZMM(19), ZMM(0))
VMULPD(ZMM(20), ZMM(20), ZMM(0))
VMULPD(ZMM(21), ZMM(21), ZMM(0))
VMULPD(ZMM(22), ZMM(22), ZMM(0))
VMULPD(ZMM(23), ZMM(23), ZMM(0))
VMULPD(ZMM(24), ZMM(24), ZMM(0))
VMULPD(ZMM(25), ZMM(25), ZMM(0))
VMULPD(ZMM(26), ZMM(26), ZMM(0))
VMULPD(ZMM(27), ZMM(27), ZMM(0))
VMULPD(ZMM(28), ZMM(28), ZMM(0))
VMULPD(ZMM(29), ZMM(29), ZMM(0))
VMULPD(ZMM(30), ZMM(30), ZMM(0))
VMULPD(ZMM(31), ZMM(31), ZMM(0))
MOV(RDI, VAR(offsetPtr))
VPBROADCASTQ(ZMM(0), R10)
VPMULLQ(ZMM(2), ZMM(0), MEM(RDI))
VPMULLQ(ZMM(3), ZMM(0), MEM(RDI, 8*8))
VPMULLQ(ZMM(4), ZMM(0), MEM(RDI,16*8))
VCOMISD(XMM(1), XMM(2))
JE(GENSTORBZ)
UPDATE_C_SCATTERED( 8, 9, 10) // scale by beta and store
UPDATE_C_SCATTERED(11, 12, 13)
UPDATE_C_SCATTERED(14, 15, 16)
UPDATE_C_SCATTERED(17, 18, 19)
UPDATE_C_SCATTERED(20, 21, 22)
UPDATE_C_SCATTERED(23, 24, 25)
UPDATE_C_SCATTERED(26, 27, 28)
UPDATE_C_SCATTERED(29, 30, 31)
JMP(END)
LABEL(GENSTORBZ)
UPDATE_C_SCATTERED_BZ( 8, 9, 10)
UPDATE_C_SCATTERED_BZ(11, 12, 13)
UPDATE_C_SCATTERED_BZ(14, 15, 16)
UPDATE_C_SCATTERED_BZ(17, 18, 19)
UPDATE_C_SCATTERED_BZ(20, 21, 22)
UPDATE_C_SCATTERED_BZ(23, 24, 25)
UPDATE_C_SCATTERED_BZ(26, 27, 28)
UPDATE_C_SCATTERED_BZ(29, 30, 31)
JMP(END)
LABEL(ROWUPDATE)
// if C is row stride
// R12 = rs_c
LEA(R11, MEM(R12, R12, 2)) // R11 = rs_c * 3, R11 = rs_c + rs_c * 2
LEA(R13, MEM(R12, R11, 2)) // R13 = rs_c * 7, R13 = rs_c + R11 * 2
LEA(R12, MEM(R12, R12, 4)) // R12 = rs_c * 5, R12 = rs_c + rs_c * 4
VCOMISD(XMM(1), XMM(2)) // XMM(1) = beta, XMM(2) = 0
JE(ROWSTORBZ)
// beta != 0
// row0
VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) // scale by alpha
VMULPD(ZMM( 9), ZMM( 9), ZMM(0))
VMULPD(ZMM(10), ZMM(10), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM( 8), ZMM(1), MEM(RCX)) //zmm8 = zmm1*C + zmm8, zmm8 = beta*C + zmm8
VFMADD231PD(ZMM( 9), ZMM(1), MEM(RCX,64)) //zmm9 = beta*C + zmm9
VFMADD231PD(ZMM(10), ZMM(1), MEM(RCX,128)) //zmm10 = beta*C + zmm10
/*store c*/
VMOVUPD(MEM(RCX ), ZMM( 8))
VMOVUPD(MEM(RCX, 64), ZMM( 9))
VMOVUPD(MEM(RCX,128), ZMM(10))
// row1
VMULPD(ZMM(11), ZMM(11), ZMM(0))
VMULPD(ZMM(12), ZMM(12), ZMM(0))
VMULPD(ZMM(13), ZMM(13), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(11), ZMM(1), MEM(RCX, RAX, 1 ))
VFMADD231PD(ZMM(12), ZMM(1), MEM(RCX, RAX, 1, 64 ))
VFMADD231PD(ZMM(13), ZMM(1), MEM(RCX, RAX, 1, 128))
/*store c*/
VMOVUPD(MEM(RCX, RAX, 1 ), ZMM(11))
VMOVUPD(MEM(RCX, RAX, 1, 64 ), ZMM(12))
VMOVUPD(MEM(RCX, RAX, 1, 128), ZMM(13))
// row2
VMULPD(ZMM(14), ZMM(14), ZMM(0))
VMULPD(ZMM(15), ZMM(15), ZMM(0))
VMULPD(ZMM(16), ZMM(16), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(14), ZMM(1), MEM(RCX, RAX, 2 ))
VFMADD231PD(ZMM(15), ZMM(1), MEM(RCX, RAX, 2, 64 ))
VFMADD231PD(ZMM(16), ZMM(1), MEM(RCX, RAX, 2, 128))
/*store c*/
VMOVUPD(MEM(RCX, RAX, 2 ), ZMM(14))
VMOVUPD(MEM(RCX, RAX, 2, 64 ), ZMM(15))
VMOVUPD(MEM(RCX, RAX, 2, 128), ZMM(16))
// row3
VMULPD(ZMM(17), ZMM(17), ZMM(0))
VMULPD(ZMM(18), ZMM(18), ZMM(0))
VMULPD(ZMM(19), ZMM(19), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(17), ZMM(1), MEM(RCX, R11, 1 ))
VFMADD231PD(ZMM(18), ZMM(1), MEM(RCX, R11, 1, 64 ))
VFMADD231PD(ZMM(19), ZMM(1), MEM(RCX, R11, 1, 128))
/*store c*/
VMOVUPD(MEM(RCX, R11, 1 ), ZMM(17))
VMOVUPD(MEM(RCX, R11, 1, 64 ), ZMM(18))
VMOVUPD(MEM(RCX, R11, 1, 128), ZMM(19))
// row4
VMULPD(ZMM(20), ZMM(20), ZMM(0))
VMULPD(ZMM(21), ZMM(21), ZMM(0))
VMULPD(ZMM(22), ZMM(22), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(20), ZMM(1), MEM(RCX, RAX, 4 ))
VFMADD231PD(ZMM(21), ZMM(1), MEM(RCX, RAX, 4, 64 ))
VFMADD231PD(ZMM(22), ZMM(1), MEM(RCX, RAX, 4, 128))
/*store c*/
VMOVUPD(MEM(RCX, RAX, 4 ), ZMM(20))
VMOVUPD(MEM(RCX, RAX, 4, 64 ), ZMM(21))
VMOVUPD(MEM(RCX, RAX, 4, 128), ZMM(22))
// row5
VMULPD(ZMM(23), ZMM(23), ZMM(0))
VMULPD(ZMM(24), ZMM(24), ZMM(0))
VMULPD(ZMM(25), ZMM(25), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(23), ZMM(1), MEM(RCX, R12, 1 ))
VFMADD231PD(ZMM(24), ZMM(1), MEM(RCX, R12, 1, 64 ))
VFMADD231PD(ZMM(25), ZMM(1), MEM(RCX, R12, 1, 128))
/*store c*/
VMOVUPD(MEM(RCX, R12, 1 ), ZMM(23))
VMOVUPD(MEM(RCX, R12, 1, 64 ), ZMM(24))
VMOVUPD(MEM(RCX, R12, 1, 128), ZMM(25))
// row6
VMULPD(ZMM(26), ZMM(26), ZMM(0))
VMULPD(ZMM(27), ZMM(27), ZMM(0))
VMULPD(ZMM(28), ZMM(28), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(26), ZMM(1), MEM(RCX, R11, 2 ))
VFMADD231PD(ZMM(27), ZMM(1), MEM(RCX, R11, 2, 64 ))
VFMADD231PD(ZMM(28), ZMM(1), MEM(RCX, R11, 2, 128))
/*store c*/
VMOVUPD(MEM(RCX, R11, 2 ), ZMM(26))
VMOVUPD(MEM(RCX, R11, 2, 64 ), ZMM(27))
VMOVUPD(MEM(RCX, R11, 2, 128), ZMM(28))
// row7
VMULPD(ZMM(29), ZMM(29), ZMM(0))
VMULPD(ZMM(30), ZMM(30), ZMM(0))
VMULPD(ZMM(31), ZMM(31), ZMM(0))
/*scale by beta*/
VFMADD231PD(ZMM(29), ZMM(1), MEM(RCX, R13, 1 ))
VFMADD231PD(ZMM(30), ZMM(1), MEM(RCX, R13, 1, 64 ))
VFMADD231PD(ZMM(31), ZMM(1), MEM(RCX, R13, 1, 128))
/*store c*/
VMOVUPD(MEM(RCX, R13, 1 ), ZMM(29))
VMOVUPD(MEM(RCX, R13, 1, 64 ), ZMM(30))
VMOVUPD(MEM(RCX, R13, 1, 128), ZMM(31))
JMP(END)
LABEL(ROWSTORBZ)
// beta == 0
// row0
VMULPD(ZMM( 8), ZMM( 8), ZMM(0))
VMULPD(ZMM( 9), ZMM( 9), ZMM(0))
VMULPD(ZMM(10), ZMM(10), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX ), ZMM( 8))
VMOVUPD(MEM(RCX, 64), ZMM( 9))
VMOVUPD(MEM(RCX,128), ZMM(10))
// row1
VMULPD(ZMM(11), ZMM(11), ZMM(0))
VMULPD(ZMM(12), ZMM(12), ZMM(0))
VMULPD(ZMM(13), ZMM(13), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, RAX, 1 ), ZMM(11))
VMOVUPD(MEM(RCX, RAX, 1, 64 ), ZMM(12))
VMOVUPD(MEM(RCX, RAX, 1, 128), ZMM(13))
// row2
VMULPD(ZMM(14), ZMM(14), ZMM(0))
VMULPD(ZMM(15), ZMM(15), ZMM(0))
VMULPD(ZMM(16), ZMM(16), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, RAX, 2 ), ZMM(14))
VMOVUPD(MEM(RCX, RAX, 2, 64 ), ZMM(15))
VMOVUPD(MEM(RCX, RAX, 2, 128), ZMM(16))
// row3
VMULPD(ZMM(17), ZMM(17), ZMM(0))
VMULPD(ZMM(18), ZMM(18), ZMM(0))
VMULPD(ZMM(19), ZMM(19), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, R11, 1 ), ZMM(17))
VMOVUPD(MEM(RCX, R11, 1, 64 ), ZMM(18))
VMOVUPD(MEM(RCX, R11, 1, 128), ZMM(19))
// row4
VMULPD(ZMM(20), ZMM(20), ZMM(0))
VMULPD(ZMM(21), ZMM(21), ZMM(0))
VMULPD(ZMM(22), ZMM(22), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, RAX, 4 ), ZMM(20))
VMOVUPD(MEM(RCX, RAX, 4, 64 ), ZMM(21))
VMOVUPD(MEM(RCX, RAX, 4, 128), ZMM(22))
// row5
VMULPD(ZMM(23), ZMM(23), ZMM(0))
VMULPD(ZMM(24), ZMM(24), ZMM(0))
VMULPD(ZMM(25), ZMM(25), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, R12, 1 ), ZMM(23))
VMOVUPD(MEM(RCX, R12, 1, 64 ), ZMM(24))
VMOVUPD(MEM(RCX, R12, 1, 128), ZMM(25))
// row6
VMULPD(ZMM(26), ZMM(26), ZMM(0))
VMULPD(ZMM(27), ZMM(27), ZMM(0))
VMULPD(ZMM(28), ZMM(28), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, R11, 2 ), ZMM(26))
VMOVUPD(MEM(RCX, R11, 2, 64 ), ZMM(27))
VMOVUPD(MEM(RCX, R11, 2, 128), ZMM(28))
// row7
VMULPD(ZMM(29), ZMM(29), ZMM(0))
VMULPD(ZMM(30), ZMM(30), ZMM(0))
VMULPD(ZMM(31), ZMM(31), ZMM(0))
/*store c*/
VMOVUPD(MEM(RCX, R13, 1 ), ZMM(29))
VMOVUPD(MEM(RCX, R13, 1, 64 ), ZMM(30))
VMOVUPD(MEM(RCX, R13, 1, 128), ZMM(31))
JMP(END)
LABEL(COLUPDATE)
// if C is col stride
// R10 = cs_c
LEA(R11, MEM(R10, R10, 2)) // R11 = cs_c * 3
LEA(R12, MEM(R10, R10, 4)) // R12 = cs_c * 5
LEA(R13, MEM(R10, R11, 2)) // R13 = cs_c * 7
VCOMISD(XMM(1), XMM(2))
JE(COLSTORBZ)
// beta != 0
TRANSPOSE_8X8( 8, 11, 14, 17, 20, 23, 26, 29)
TRANSPOSE_8X8( 9, 12, 15, 18, 21, 24, 27, 30)
TRANSPOSE_8X8(10, 13, 16, 19, 22, 25, 28, 31)
MOV(RAX, VAR(alpha))
MOV(RBX, VAR(beta))
VBROADCASTSD(ZMM(0), MEM(RAX))
VBROADCASTSD(ZMM(1), MEM(RBX))
UPDATE_C_COL_STORE( 8, 11, 14, 17, 20, 23, 26, 29)
UPDATE_C_COL_STORE( 9, 12, 15, 18, 21, 24, 27, 30)
UPDATE_C_COL_STORE(10, 13, 16, 19, 22, 25, 28, 31)
JMP(END)
LABEL(COLSTORBZ)
// beta == 0
TRANSPOSE_8X8( 8, 11, 14, 17, 20, 23, 26, 29)
TRANSPOSE_8X8( 9, 12, 15, 18, 21, 24, 27, 30)
TRANSPOSE_8X8(10, 13, 16, 19, 22, 25, 28, 31)
MOV(RAX, VAR(alpha))
VBROADCASTSD(ZMM(0), MEM(RAX))
UPDATE_C_COL_STORE_BZ( 8, 11, 14, 17, 20, 23, 26, 29)
UPDATE_C_COL_STORE_BZ( 9, 12, 15, 18, 21, 24, 27, 30)
UPDATE_C_COL_STORE_BZ(10, 13, 16, 19, 22, 25, 28, 31)
LABEL(END)
VZEROUPPER()
END_ASM
(
: // output operands
: // input operands
[k] "m" (k),
[a] "m" (a),
[b] "m" (b),
[alpha] "m" (alpha),
[beta] "m" (beta),
[c] "m" (c),
[rs_c] "m" (rs_c),
[cs_c] "m" (cs_c),
[offsetPtr] "m" (offsetPtr)
: // register clobber list
"rax", "rbx", "rcx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
"r13", "k0", "k1", "k2", "k3", "xmm1", "xmm2",
"ymm2", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13",
"ymm14", "ymm15", "ymm16", "ymm17", "ymm18", "ymm19", "ymm20",
"ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27",
"ymm28", "ymm29", "ymm30", "ymm31", "zmm0", "zmm1", "zmm2", "zmm3",
"zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11",
"zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19",
"zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27",
"zmm28", "zmm29", "zmm30", "zmm31", "memory"
)
}


@@ -126,11 +126,24 @@ PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_12xk )
 PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_4xk )
 // native dgemm kernel
+GEMM_UKR_PROT( double, d, gemm_avx512_asm_8x24 )
-GEMM_UKR_PROT( double, d, gemm_zen4_asm_32x6 )
-GEMM_UKR_PROT( double, d, gemm_zen4_asm_8x24 )
 GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_12x4 )
 GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_4x12 )
+// dgemm native macro kernel
+void bli_dgemm_avx512_asm_8x24_macro_kernel
+(
+dim_t n,
+dim_t m,
+dim_t k,
+double* c,
+double* a,
+double* b,
+dim_t ldc,
+double* beta
+);
 //sgemm rv sup
 GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x64m_avx512 )


@@ -32,9 +32,6 @@
 */
-// native dgemm kernel
-GEMM_UKR_PROT( double, d, gemm_avx512_asm_8x24 )
 // Dgemm sup RV kernels
 GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x8m)
 GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x7m)
@@ -45,18 +42,6 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x3m)
 GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x2m)
 GEMMSUP_KER_PROT( double, d, gemmsup_rv_zen5_asm_24x1m)
-void bli_dgemm_avx512_asm_8x24_macro_kernel
-(
-dim_t n,
-dim_t m,
-dim_t k,
-double* c,
-double* a,
-double* b,
-dim_t ldc,
-double* beta
-);
 // threshold functions
 bool bli_cntx_gemmsup_thresh_is_met_zen5
 (