Integrated 32x6 DGEMM kernel for zen4 and its related changes are added.

Details:
- Now AOCL BLIS uses AX512 - 32x6 DGEMM kernel for native code path.
  Thanks to Moore, Branden <Branden.Moore@amd.com> for suggesting and
  implementing these optimizations.
- In the initial version of 32x6 DGEMM kernel, to broadcast elements of B packed
  we perform load into xmm (2 elements), broadcast into zmm from xmmm and then to get the
  next element, we do vpermilpd(xmm). This logic is replaced with direct broadcast from
  memory, since the elements of Bpack are stored contiguously, the first broadcast fetches
  the cacheline and then subsequent broadcasts happen faster. We use two registers for broadcast
  and interleave broadcast operation with FMAs to hide any memory latencies.
- Native dTRSM uses 16x14 dgemm - therefore we need to override the default blkszs (MR,NR,..)
  when executing trsm. we call bli_zen4_override_trsm_blkszs(cntx_local) on a local cntx_t object
  for double data-type as well in the function bli_trsm_front(), bli_trsm_xx_ker_var2, xx = {ll,lu,rl,ru}.
  Renamed "BLIS_GEMM_AVX2_UKR" to "BLIS_GEMM_FOR_TRSM_UKR" and in the bli_cntx_init_zen4() we replaced
  dgemm kernel for TRSM with 16x14 dgemm kernel.
- New packm kernels - 16xk, 24xk and 32xk are added.
- New 32xk packm reference kernel is added in bli_packm_cxk_ref.c and it is
  enabled for zen4 config (bli_dpackm_32xk_zen4_ref() )
- Copyright year updated for modified files.
- cleaned up code for "zen" config - removed unused packm kernels declaration in kernels/zen/bli_kernels.h
- [SWLCSG-1374], [CPUPL-2918]

Change-Id: I576282382504b72072a6db068eabd164c8943627
This commit is contained in:
Kiran Varaganti
2023-01-10 12:04:55 +05:30
parent 0a699c45f0
commit 201db7883c
22 changed files with 1558 additions and 47 deletions

View File

@@ -5,6 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -1720,3 +1721,250 @@ void PASTEMAC3(ch,opname,arch,suf) \
INSERT_GENTFUNC_BASIC3( packm_24xk, 24, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
#undef GENTFUNC
#define GENTFUNC( ctype, ch, opname, mnr, arch, suf ) \
\
void PASTEMAC3(ch,opname,arch,suf) \
( \
conj_t conja, \
pack_t schema, \
dim_t cdim, \
dim_t n, \
dim_t n_max, \
ctype* restrict kappa, \
ctype* restrict a, inc_t inca, inc_t lda, \
ctype* restrict p, inc_t ldp, \
cntx_t* restrict cntx \
) \
{ \
ctype* restrict kappa_cast = kappa; \
ctype* restrict alpha1 = a; \
ctype* restrict pi1 = p; \
\
if ( cdim == mnr ) \
{ \
if ( PASTEMAC(ch,eq1)( *kappa_cast ) ) \
{ \
if ( bli_is_conj( conja ) ) \
{ \
for ( dim_t k = n; k != 0; --k ) \
{ \
PASTEMAC(ch,copyjs)( *(alpha1 + 0*inca), *(pi1 + 0) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 1*inca), *(pi1 + 1) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 2*inca), *(pi1 + 2) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 3*inca), *(pi1 + 3) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 4*inca), *(pi1 + 4) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 5*inca), *(pi1 + 5) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 6*inca), *(pi1 + 6) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 7*inca), *(pi1 + 7) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 8*inca), *(pi1 + 8) ); \
PASTEMAC(ch,copyjs)( *(alpha1 + 9*inca), *(pi1 + 9) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +10*inca), *(pi1 +10) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +11*inca), *(pi1 +11) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +12*inca), *(pi1 +12) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +13*inca), *(pi1 +13) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +14*inca), *(pi1 +14) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +15*inca), *(pi1 +15) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +16*inca), *(pi1 +16) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +17*inca), *(pi1 +17) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +18*inca), *(pi1 +18) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +19*inca), *(pi1 +19) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +20*inca), *(pi1 +20) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +21*inca), *(pi1 +21) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +22*inca), *(pi1 +22) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +23*inca), *(pi1 +23) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +24*inca), *(pi1 +24) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +25*inca), *(pi1 +25) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +26*inca), *(pi1 +26) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +27*inca), *(pi1 +27) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +28*inca), *(pi1 +28) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +29*inca), *(pi1 +29) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +30*inca), *(pi1 +30) ); \
PASTEMAC(ch,copyjs)( *(alpha1 +31*inca), *(pi1 +31) ); \
\
alpha1 += lda; \
pi1 += ldp; \
} \
} \
else \
{ \
for ( dim_t k = n; k != 0; --k ) \
{ \
PASTEMAC(ch,copys)( *(alpha1 + 0*inca), *(pi1 + 0) ); \
PASTEMAC(ch,copys)( *(alpha1 + 1*inca), *(pi1 + 1) ); \
PASTEMAC(ch,copys)( *(alpha1 + 2*inca), *(pi1 + 2) ); \
PASTEMAC(ch,copys)( *(alpha1 + 3*inca), *(pi1 + 3) ); \
PASTEMAC(ch,copys)( *(alpha1 + 4*inca), *(pi1 + 4) ); \
PASTEMAC(ch,copys)( *(alpha1 + 5*inca), *(pi1 + 5) ); \
PASTEMAC(ch,copys)( *(alpha1 + 6*inca), *(pi1 + 6) ); \
PASTEMAC(ch,copys)( *(alpha1 + 7*inca), *(pi1 + 7) ); \
PASTEMAC(ch,copys)( *(alpha1 + 8*inca), *(pi1 + 8) ); \
PASTEMAC(ch,copys)( *(alpha1 + 9*inca), *(pi1 + 9) ); \
PASTEMAC(ch,copys)( *(alpha1 +10*inca), *(pi1 +10) ); \
PASTEMAC(ch,copys)( *(alpha1 +11*inca), *(pi1 +11) ); \
PASTEMAC(ch,copys)( *(alpha1 +12*inca), *(pi1 +12) ); \
PASTEMAC(ch,copys)( *(alpha1 +13*inca), *(pi1 +13) ); \
PASTEMAC(ch,copys)( *(alpha1 +14*inca), *(pi1 +14) ); \
PASTEMAC(ch,copys)( *(alpha1 +15*inca), *(pi1 +15) ); \
PASTEMAC(ch,copys)( *(alpha1 +16*inca), *(pi1 +16) ); \
PASTEMAC(ch,copys)( *(alpha1 +17*inca), *(pi1 +17) ); \
PASTEMAC(ch,copys)( *(alpha1 +18*inca), *(pi1 +18) ); \
PASTEMAC(ch,copys)( *(alpha1 +19*inca), *(pi1 +19) ); \
PASTEMAC(ch,copys)( *(alpha1 +20*inca), *(pi1 +20) ); \
PASTEMAC(ch,copys)( *(alpha1 +21*inca), *(pi1 +21) ); \
PASTEMAC(ch,copys)( *(alpha1 +22*inca), *(pi1 +22) ); \
PASTEMAC(ch,copys)( *(alpha1 +23*inca), *(pi1 +23) ); \
PASTEMAC(ch,copys)( *(alpha1 +24*inca), *(pi1 +24) ); \
PASTEMAC(ch,copys)( *(alpha1 +25*inca), *(pi1 +25) ); \
PASTEMAC(ch,copys)( *(alpha1 +26*inca), *(pi1 +26) ); \
PASTEMAC(ch,copys)( *(alpha1 +27*inca), *(pi1 +27) ); \
PASTEMAC(ch,copys)( *(alpha1 +28*inca), *(pi1 +28) ); \
PASTEMAC(ch,copys)( *(alpha1 +29*inca), *(pi1 +29) ); \
PASTEMAC(ch,copys)( *(alpha1 +30*inca), *(pi1 +30) ); \
PASTEMAC(ch,copys)( *(alpha1 +31*inca), *(pi1 +31) ); \
\
alpha1 += lda; \
pi1 += ldp; \
} \
} \
} \
else \
{ \
if ( bli_is_conj( conja ) ) \
{ \
for ( dim_t k = n; k != 0; --k ) \
{ \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +30*inca), *(pi1 +30) ); \
PASTEMAC(ch,scal2js)( *kappa_cast, *(alpha1 +31*inca), *(pi1 +31) ); \
\
alpha1 += lda; \
pi1 += ldp; \
} \
} \
else \
{ \
for ( dim_t k = n; k != 0; --k ) \
{ \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 0*inca), *(pi1 + 0) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 1*inca), *(pi1 + 1) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 2*inca), *(pi1 + 2) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 3*inca), *(pi1 + 3) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 4*inca), *(pi1 + 4) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 5*inca), *(pi1 + 5) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 6*inca), *(pi1 + 6) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 7*inca), *(pi1 + 7) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 8*inca), *(pi1 + 8) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 + 9*inca), *(pi1 + 9) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +10*inca), *(pi1 +10) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +11*inca), *(pi1 +11) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +12*inca), *(pi1 +12) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +13*inca), *(pi1 +13) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +14*inca), *(pi1 +14) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +15*inca), *(pi1 +15) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +16*inca), *(pi1 +16) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +17*inca), *(pi1 +17) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +18*inca), *(pi1 +18) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +19*inca), *(pi1 +19) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +20*inca), *(pi1 +20) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +21*inca), *(pi1 +21) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +22*inca), *(pi1 +22) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +23*inca), *(pi1 +23) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +24*inca), *(pi1 +24) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +25*inca), *(pi1 +25) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +26*inca), *(pi1 +26) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +27*inca), *(pi1 +27) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +28*inca), *(pi1 +28) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +29*inca), *(pi1 +29) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +30*inca), *(pi1 +30) ); \
PASTEMAC(ch,scal2s)( *kappa_cast, *(alpha1 +31*inca), *(pi1 +31) ); \
\
alpha1 += lda; \
pi1 += ldp; \
} \
} \
} \
} \
else /* if ( cdim < mnr ) */ \
{ \
PASTEMAC2(ch,scal2m,BLIS_TAPI_EX_SUF) \
( \
0, \
BLIS_NONUNIT_DIAG, \
BLIS_DENSE, \
( trans_t )conja, \
cdim, \
n, \
kappa, \
a, inca, lda, \
p, 1, ldp, \
cntx, \
NULL \
); \
\
/* if ( cdim < mnr ) */ \
{ \
const dim_t i = cdim; \
const dim_t m_edge = mnr - cdim; \
const dim_t n_edge = n_max; \
ctype* restrict p_cast = p; \
ctype* restrict p_edge = p_cast + (i )*1; \
\
PASTEMAC(ch,set0s_mxn) \
( \
m_edge, \
n_edge, \
p_edge, 1, ldp \
); \
} \
} \
\
if ( n < n_max ) \
{ \
const dim_t j = n; \
const dim_t m_edge = mnr; \
const dim_t n_edge = n_max - n; \
ctype* restrict p_cast = p; \
ctype* restrict p_edge = p_cast + (j )*ldp; \
\
PASTEMAC(ch,set0s_mxn) \
( \
m_edge, \
n_edge, \
p_edge, 1, ldp \
); \
} \
}
INSERT_GENTFUNC_BASIC3( packm_32xk, 32, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )