mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Integrated 32x6 DGEMM kernel for zen4 and its related changes are added.
Details: - Now AOCL BLIS uses AX512 - 32x6 DGEMM kernel for native code path. Thanks to Moore, Branden <Branden.Moore@amd.com> for suggesting and implementing these optimizations. - In the initial version of 32x6 DGEMM kernel, to broadcast elements of B packed we perform load into xmm (2 elements), broadcast into zmm from xmmm and then to get the next element, we do vpermilpd(xmm). This logic is replaced with direct broadcast from memory, since the elements of Bpack are stored contiguously, the first broadcast fetches the cacheline and then subsequent broadcasts happen faster. We use two registers for broadcast and interleave broadcast operation with FMAs to hide any memory latencies. - Native dTRSM uses 16x14 dgemm - therefore we need to override the default blkszs (MR,NR,..) when executing trsm. we call bli_zen4_override_trsm_blkszs(cntx_local) on a local cntx_t object for double data-type as well in the function bli_trsm_front(), bli_trsm_xx_ker_var2, xx = {ll,lu,rl,ru}. Renamed "BLIS_GEMM_AVX2_UKR" to "BLIS_GEMM_FOR_TRSM_UKR" and in the bli_cntx_init_zen4() we replaced dgemm kernel for TRSM with 16x14 dgemm kernel. - New packm kernels - 16xk, 24xk and 32xk are added. - New 32xk packm reference kernel is added in bli_packm_cxk_ref.c and it is enabled for zen4 config (bli_dpackm_32xk_zen4_ref() ) - Copyright year updated for modified files. - cleaned up code for "zen" config - removed unused packm kernels declaration in kernels/zen/bli_kernels.h - [SWLCSG-1374], [CPUPL-2918] Change-Id: I576282382504b72072a6db068eabd164c8943627
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -1720,3 +1721,250 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
|
||||
INSERT_GENTFUNC_BASIC3( packm_24xk, 24, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
|
||||
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC( ctype, ch, opname, mnr, arch, suf ) \
|
||||
\
|
||||
void PASTEMAC3(ch,opname,arch,suf) \
|
||||
( \
|
||||
conj_t conja, \
|
||||
pack_t schema, \
|
||||
dim_t cdim, \
|
||||
dim_t n, \
|
||||
dim_t n_max, \
|
||||
ctype* restrict kappa, \
|
||||
ctype* restrict a, inc_t inca, inc_t lda, \
|
||||
ctype* restrict p, inc_t ldp, \
|
||||
cntx_t* restrict cntx \
|
||||
) \
|
||||
{ \
|
||||
ctype* restrict kappa_cast = kappa; \
|
||||
ctype* restrict alpha1 = a; \
|
||||
ctype* restrict pi1 = p; \
|
||||
\
|
||||
if ( cdim == mnr ) \
|
||||
{ \
|
||||
if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \
|
||||
{ \
|
||||
if ( bli_is_conj( conja ) ) \
|
||||
{ \
|
||||
for ( dim_t k = n; k != 0; --k ) \
|
||||
{ \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 6*inca), *(pi1 + 6) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 7*inca), *(pi1 + 7) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 8*inca), *(pi1 + 8) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 + 9*inca), *(pi1 + 9) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +10*inca), *(pi1 +10) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +11*inca), *(pi1 +11) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +12*inca), *(pi1 +12) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +13*inca), *(pi1 +13) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +14*inca), *(pi1 +14) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +15*inca), *(pi1 +15) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +16*inca), *(pi1 +16) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +17*inca), *(pi1 +17) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +18*inca), *(pi1 +18) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +19*inca), *(pi1 +19) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +20*inca), *(pi1 +20) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +21*inca), *(pi1 +21) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +22*inca), *(pi1 +22) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +23*inca), *(pi1 +23) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +24*inca), *(pi1 +24) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +25*inca), *(pi1 +25) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +26*inca), *(pi1 +26) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +27*inca), *(pi1 +27) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +28*inca), *(pi1 +28) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +29*inca), *(pi1 +29) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +30*inca), *(pi1 +30) ); \
|
||||
PASTEMAC(ch,copyjs)( *(alpha1 +31*inca), *(pi1 +31) ); \
|
||||
\
|
||||
alpha1 += lda; \
|
||||
pi1 += ldp; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( dim_t k = n; k != 0; --k ) \
|
||||
{ \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 6*inca), *(pi1 + 6) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 7*inca), *(pi1 + 7) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 8*inca), *(pi1 + 8) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 + 9*inca), *(pi1 + 9) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +10*inca), *(pi1 +10) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +11*inca), *(pi1 +11) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +12*inca), *(pi1 +12) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +13*inca), *(pi1 +13) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +14*inca), *(pi1 +14) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +15*inca), *(pi1 +15) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +16*inca), *(pi1 +16) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +17*inca), *(pi1 +17) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +18*inca), *(pi1 +18) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +19*inca), *(pi1 +19) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +20*inca), *(pi1 +20) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +21*inca), *(pi1 +21) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +22*inca), *(pi1 +22) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +23*inca), *(pi1 +23) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +24*inca), *(pi1 +24) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +25*inca), *(pi1 +25) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +26*inca), *(pi1 +26) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +27*inca), *(pi1 +27) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +28*inca), *(pi1 +28) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +29*inca), *(pi1 +29) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +30*inca), *(pi1 +30) ); \
|
||||
PASTEMAC(ch,copys)( *(alpha1 +31*inca), *(pi1 +31) ); \
|
||||
\
|
||||
alpha1 += lda; \
|
||||
pi1 += ldp; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if ( bli_is_conj( conja ) ) \
|
||||
{ \
|
||||
for ( dim_t k = n; k != 0; --k ) \
|
||||
{ \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +30*inca), *(pi1 +30) ); \
|
||||
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +31*inca), *(pi1 +31) ); \
|
||||
\
|
||||
alpha1 += lda; \
|
||||
pi1 += ldp; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( dim_t k = n; k != 0; --k ) \
|
||||
{ \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +30*inca), *(pi1 +30) ); \
|
||||
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +31*inca), *(pi1 +31) ); \
|
||||
\
|
||||
alpha1 += lda; \
|
||||
pi1 += ldp; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
else /* if ( cdim < mnr ) */ \
|
||||
{ \
|
||||
PASTEMAC2(ch,scal2m,BLIS_TAPI_EX_SUF) \
|
||||
( \
|
||||
0, \
|
||||
BLIS_NONUNIT_DIAG, \
|
||||
BLIS_DENSE, \
|
||||
( trans_t )conja, \
|
||||
cdim, \
|
||||
n, \
|
||||
kappa, \
|
||||
a, inca, lda, \
|
||||
p, 1, ldp, \
|
||||
cntx, \
|
||||
NULL \
|
||||
); \
|
||||
\
|
||||
/* if ( cdim < mnr ) */ \
|
||||
{ \
|
||||
const dim_t i = cdim; \
|
||||
const dim_t m_edge = mnr - cdim; \
|
||||
const dim_t n_edge = n_max; \
|
||||
ctype* restrict p_cast = p; \
|
||||
ctype* restrict p_edge = p_cast + (i )*1; \
|
||||
\
|
||||
PASTEMAC(ch,set0s_mxn) \
|
||||
( \
|
||||
m_edge, \
|
||||
n_edge, \
|
||||
p_edge, 1, ldp \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
if ( n < n_max ) \
|
||||
{ \
|
||||
const dim_t j = n; \
|
||||
const dim_t m_edge = mnr; \
|
||||
const dim_t n_edge = n_max - n; \
|
||||
ctype* restrict p_cast = p; \
|
||||
ctype* restrict p_edge = p_cast + (j )*ldp; \
|
||||
\
|
||||
PASTEMAC(ch,set0s_mxn) \
|
||||
( \
|
||||
m_edge, \
|
||||
n_edge, \
|
||||
p_edge, 1, ldp \
|
||||
); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSERT_GENTFUNC_BASIC3( packm_32xk, 32, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user