mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Derive TRSM ref kernels from TRSM blkzsz instead of GEMM blszs (#148)
- Currently TRSM reference kernels are derived from GEMM blocksizes and GEMM_UKR. - This does not allow the flexibility to use different GEMM_UKR for GEMM and TRSM if optimized TRSM_UKR are not available. - Made changes so that ref TRSM kernels are derived from TRSM blocksizes. - Changed ZEN4 and ZEN5 cntx to use AVX2 kernels for CTRSM. AMD-Internal: [SWLCSG-3702]
This commit is contained in:
@@ -52,15 +52,25 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
{ \
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
/* Use trsm blocksizes if they are available else use general blocksizes. */ \
|
||||
inc_t packnr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
if ( packnr == 0 ) \
|
||||
{ \
|
||||
packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
} \
|
||||
\
|
||||
const inc_t rs_b = packnr; \
|
||||
const inc_t cs_b = 1; \
|
||||
\
|
||||
ctype* minus_one = PASTEMAC(ch,m1); \
|
||||
\
|
||||
/* Use GEMM_FOR_TRSM_UKR if it is define else use GEMM_UKR. */ \
|
||||
PASTECH(ch,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_FOR_TRSM_UKR, cntx ); \
|
||||
if ( gemm_ukr == NULL ) \
|
||||
{ \
|
||||
gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
} \
|
||||
PASTECH(ch,trsm_ukr_ft) \
|
||||
trsm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, trsmkerid, cntx ); \
|
||||
\
|
||||
|
||||
@@ -60,11 +60,20 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
{ \
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
/* Use trsm blocksizes if they are available else use general blocksizes. */ \
|
||||
dim_t mr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
dim_t nr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
inc_t packmr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
inc_t packnr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
if ( mr == 0 || nr == 0 ) \
|
||||
{ \
|
||||
mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
} \
|
||||
\
|
||||
const dim_t m = mr; \
|
||||
const dim_t n = nr; \
|
||||
@@ -146,11 +155,20 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
{ \
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
/* Use trsm blocksizes if they are available else use general blocksizes. */ \
|
||||
dim_t mr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
dim_t nr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
inc_t packmr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
inc_t packnr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
if ( mr == 0 || nr == 0 ) \
|
||||
{ \
|
||||
mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
} \
|
||||
\
|
||||
const dim_t m = mr; \
|
||||
const dim_t n = nr; \
|
||||
|
||||
Reference in New Issue
Block a user