Derive TRSM ref kernels from TRSM blocksizes instead of GEMM blocksizes (#148)

- Currently TRSM reference kernels are derived from GEMM blocksizes and GEMM_UKR.
- This does not allow the flexibility to use a different GEMM_UKR for GEMM and for TRSM when an optimized TRSM_UKR is not available.
- Made changes so that ref TRSM kernels are derived from TRSM blocksizes.
- Changed ZEN4 and ZEN5 cntx to use AVX2 kernels for CTRSM.

AMD-Internal: [SWLCSG-3702]
This commit is contained in:
Sharma, Shubham
2025-08-21 11:25:45 +05:30
committed by GitHub
parent e39cf64708
commit b5c8124d3d
5 changed files with 76 additions and 30 deletions

View File

@@ -52,15 +52,25 @@ void PASTEMAC3(ch,opname,arch,suf) \
{ \
const num_t dt = PASTEMAC(ch,type); \
\
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
/* Use trsm blocksizes if they are available else use general blocksizes. */ \
inc_t packnr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_NR, cntx ); \
if ( packnr == 0 ) \
{ \
packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
} \
\
const inc_t rs_b = packnr; \
const inc_t cs_b = 1; \
\
ctype* minus_one = PASTEMAC(ch,m1); \
\
/* Use GEMM_FOR_TRSM_UKR if it is defined; otherwise use GEMM_UKR. */ \
PASTECH(ch,gemm_ukr_ft) \
gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_FOR_TRSM_UKR, cntx ); \
if ( gemm_ukr == NULL ) \
{ \
gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
} \
PASTECH(ch,trsm_ukr_ft) \
trsm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, trsmkerid, cntx ); \
\

View File

@@ -60,11 +60,20 @@ void PASTEMAC3(ch,opname,arch,suf) \
{ \
const num_t dt = PASTEMAC(ch,type); \
\
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
/* Use trsm blocksizes if they are available else use general blocksizes. */ \
dim_t mr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_MR, cntx ); \
dim_t nr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_NR, cntx ); \
\
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
inc_t packmr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_MR, cntx ); \
inc_t packnr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_NR, cntx ); \
\
if ( mr == 0 || nr == 0 ) \
{ \
mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
} \
\
const dim_t m = mr; \
const dim_t n = nr; \
@@ -146,11 +155,20 @@ void PASTEMAC3(ch,opname,arch,suf) \
{ \
const num_t dt = PASTEMAC(ch,type); \
\
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
/* Use trsm blocksizes if they are available else use general blocksizes. */ \
dim_t mr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_MR, cntx ); \
dim_t nr = bli_cntx_get_trsm_blksz_def_dt( dt, BLIS_NR, cntx ); \
\
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
inc_t packmr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_MR, cntx ); \
inc_t packnr = bli_cntx_get_trsm_blksz_max_dt( dt, BLIS_NR, cntx ); \
\
if ( mr == 0 || nr == 0 ) \
{ \
mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
} \
\
const dim_t m = mr; \
const dim_t n = nr; \