Optimized SUP code for GEMMT

Details:
- Eliminated the IR loop in ref_var2m functions.
- Handled the rectangular and triangular portions of C matrix
  separately.
- Added a condition to check and eliminate zero regions inside IC loop.
- modified kc selection logic to choose optimal KC in SUP
- Updated thresholds to choose between SUP and native.

Change-Id: I21908eaa6bc3a8f37bdea29f7bfca7e6fcfee724
This commit is contained in:
Meghana Vankadari
2021-06-18 17:32:25 +05:30
committed by Dipal M Zambare
parent faeb79f2b9
commit 10ca8710f0
3 changed files with 242 additions and 148 deletions

View File

@@ -1497,7 +1497,7 @@ void PASTEMACT(ch,opname,uplo,varname) \
function pointer type. */ \
PASTECH(ch,gemmsup_ker_ft) \
gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \
ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
\
/* storage-scheme of ct should be same as that of C.
Since update routines only support row-major order,
@@ -1580,6 +1580,7 @@ void PASTEMACT(ch,opname,uplo,varname) \
dim_t m_off = 0; \
dim_t n_off = 0; \
doff_t diagoffc; \
dim_t i, ip; \
\
/* Loop over the n dimension (NC rows/columns at a time). */ \
/*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \
@@ -1690,7 +1691,8 @@ void PASTEMACT(ch,opname,uplo,varname) \
for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \
{ \
/* Calculate the thread's current IC block dimension. */ \
const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
dim_t nc_pruned = nc_cur; \
\
ctype* restrict a_ic = a_pc + ii * icstep_a; \
ctype* restrict c_ic = c_jc + ii * icstep_c; \
@@ -1699,7 +1701,24 @@ void PASTEMACT(ch,opname,uplo,varname) \
\
if(bli_gemmt_is_strictly_above_diag( m_off, n_off, mc_cur, nc_cur ) ) continue; \
\
PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \
diagoffc = m_off - n_off; \
\
if( diagoffc < 0 ) \
{ \
ip = -diagoffc / MR; \
i = ip * MR; \
mc_cur = mc_cur - i; \
diagoffc = -diagoffc % MR; \
m_off += i; \
c_ic = c_ic + ( i ) * rs_c; \
a_ic = a_ic + ( i ) * rs_a; \
} \
\
if( ( diagoffc + mc_cur ) < nc_cur ) \
{ \
nc_pruned = diagoffc + mc_cur; \
} \
\
ctype* a_use; \
inc_t rs_a_use, cs_a_use, ps_a_use; \
\
@@ -1755,8 +1774,8 @@ void PASTEMACT(ch,opname,uplo,varname) \
bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \
\
/* Compute number of primary and leftover components of the JR loop. */ \
dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \
dim_t jr_left = nc_cur % NR; \
dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \
dim_t jr_left = nc_pruned % NR; \
\
/* Compute the JR loop thread range for the current thread. */ \
dim_t jr_start, jr_end; \
@@ -1785,76 +1804,92 @@ void PASTEMACT(ch,opname,uplo,varname) \
ctype* restrict b_jr = b_pc_use + j * ps_b_use; \
ctype* restrict c_jr = c_ic + j * jrstep_c; \
\
const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \
const dim_t ir_left = mc_cur % MR; \
dim_t i; \
dim_t m_zero = 0; \
dim_t n_iter_zero = 0; \
\
/* Loop over the m dimension (MR rows at a time). */ \
for(dim_t i = 0; i < ir_iter; i += 1 ) \
m_off_cblock = m_off; \
n_off_cblock = n_off + j * NR; \
\
if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mc_cur, nc_cur)) \
{ \
const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \
m_zero = 0; \
} \
else \
{ \
/* compute number of rows that are filled with zeroes and can be ignored */ \
n_iter_zero = (n_off_cblock < m_off_cblock)? 0 : (n_off_cblock - m_off)/MR; \
m_zero = n_iter_zero * MR; \
} \
\
m_off_cblock = m_off + i * MR; \
n_off_cblock = n_off + j * NR; \
if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \
ctype* restrict a_ir = a_ic_use + i * ps_a_use; \
ctype* restrict c_ir = c_jr + i * irstep_c; \
if( bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mr_cur, nr_cur ) ) \
ctype* restrict a_ir = a_ic_use + n_iter_zero * ps_a_use; \
ctype* restrict c_ir = c_jr + n_iter_zero * irstep_c; \
\
/* Ignore the zero region */ \
m_off_cblock += m_zero; \
\
/* Compute the triangular part */ \
for( i = m_zero; (i < mc_cur) && ( m_off_cblock < n_off_cblock + nr_cur); i += MR ) \
{ \
const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \
\
/* Invoke the gemmsup millikernel. */ \
gemmsup_ker \
( \
conja, \
conjb, \
mr_cur, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
zero, \
ct, rs_ct, cs_ct, \
&aux, \
cntx \
); \
/* Scale the bottom edge of C and add the result from above. */ \
/* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \
if( col_pref ) \
{ \
/* Invoke the gemmsup millikernel. */ \
gemmsup_ker \
( \
conja, \
conjb, \
mr_cur, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
beta_use, \
c_ir, rs_c, cs_c, \
&aux, \
cntx \
); \
PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \
nr_cur, mr_cur, \
ct, cs_ct, rs_ct, \
beta_use, \
c_ir, cs_c, rs_c ); \
} \
else \
{ \
/* Invoke the gemmsup millikernel. */ \
gemmsup_ker \
( \
conja, \
conjb, \
mr_cur, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
zero, \
ct, rs_ct, cs_ct, \
&aux, \
cntx \
); \
/* Scale the bottom edge of C and add the result from above. */ \
/* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \
if( col_pref ) \
{ \
PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \
nr_cur, mr_cur, \
ct, cs_ct, rs_ct, \
beta_use, \
c_ir, cs_c, rs_c ); \
} \
else \
{ \
PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \
mr_cur, nr_cur, \
ct, rs_ct, cs_ct, \
beta_use, \
c_ir, rs_c, cs_c ); \
} \
PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \
mr_cur, nr_cur, \
ct, rs_ct, cs_ct, \
beta_use, \
c_ir, rs_c, cs_c ); \
} \
\
a_ir += ps_a_use; \
c_ir += irstep_c; \
m_off_cblock += mr_cur; \
} \
\
/* Invoke the gemmsup millikerneli for remaining rectangular part. */ \
gemmsup_ker \
( \
conja, \
conjb, \
(i > mc_cur)? 0: mc_cur - i, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
beta_use, \
c_ir, rs_c, cs_c, \
&aux, \
cntx \
); \
\
} \
} \
\
@@ -1889,8 +1924,6 @@ PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c
INSERT_GENTFUNC_L( gemmtsup, ref_var2m )
#undef GENTFUNC
#define GENTFUNC( ctype, ch, opname, uplo, varname ) \
\
@@ -1978,6 +2011,13 @@ void PASTEMACT(ch,opname,uplo,varname) \
stor_id == BLIS_CCC ) KC = KC0; \
else if ( stor_id == BLIS_RRC || \
stor_id == BLIS_CRC ) KC = KC0; \
else if ( stor_id == BLIS_RCR ) \
{ \
if ( m <= 4*MR ) KC = KC0; \
else if ( m <= 36*MR ) KC = KC0 / 2; \
else if ( m <= 56*MR ) KC = (( KC0 / 3 ) / 4 ) * 4; \
else KC = KC0 / 4; \
} \
else if ( m <= MR && n <= NR ) KC = KC0; \
else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \
else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \
@@ -2026,8 +2066,6 @@ void PASTEMACT(ch,opname,uplo,varname) \
\
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
\
PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \
\
ctype* restrict a_00 = a; \
ctype* restrict b_00 = b; \
@@ -2097,6 +2135,7 @@ void PASTEMACT(ch,opname,uplo,varname) \
dim_t n_off = 0; \
doff_t diagoffc; \
dim_t m_off_cblock, n_off_cblock; \
dim_t jp, j; \
\
/* Compute number of primary and leftover components of the JC loop. */ \
/*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \
@@ -2211,14 +2250,37 @@ void PASTEMACT(ch,opname,uplo,varname) \
for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \
{ \
/* Calculate the thread's current IC block dimension. */ \
const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
\
dim_t nc_pruned = nc_cur; \
\
m_off = ii; \
n_off = jj; \
\
if(bli_gemmt_is_strictly_below_diag(m_off, n_off, mc_cur, nc_cur)) continue; \
\
ctype* restrict a_ic = a_pc + ii * icstep_a; \
ctype* restrict c_ic = c_jc + ii * icstep_c; \
\
doff_t diagoffc = m_off - n_off; \
\
ctype* restrict b_pc_pruned = b_pc_use; \
\
if(diagoffc > 0 ) \
{ \
jp = diagoffc / NR; \
j = jp * NR; \
nc_pruned = nc_cur - j; \
n_off += j; \
diagoffc = diagoffc % NR; \
c_ic = c_ic + ( j ) * cs_c; \
b_pc_pruned = b_pc_use + ( jp ) * ps_b_use; \
} \
\
if( ( ( -diagoffc ) + nc_pruned ) < mc_cur ) \
{ \
mc_cur = -diagoffc + nc_pruned; \
} \
\
ctype* a_use; \
inc_t rs_a_use, cs_a_use, ps_a_use; \
@@ -2275,8 +2337,8 @@ void PASTEMACT(ch,opname,uplo,varname) \
bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \
\
/* Compute number of primary and leftover components of the JR loop. */ \
dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \
dim_t jr_left = nc_cur % NR; \
dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \
dim_t jr_left = nc_pruned % NR; \
\
/* Compute the JR loop thread range for the current thread. */ \
dim_t jr_start, jr_end; \
@@ -2302,77 +2364,89 @@ void PASTEMACT(ch,opname,uplo,varname) \
/*
ctype* restrict b_jr = b_pc_use + j * jrstep_b; \
*/ \
ctype* restrict b_jr = b_pc_use + j * ps_b_use; \
ctype* restrict b_jr = b_pc_pruned + j * ps_b_use; \
ctype* restrict c_jr = c_ic + j * jrstep_c; \
dim_t m_rect = 0; \
dim_t n_iter_rect = 0; \
\
const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \
const dim_t ir_left = mc_cur % MR; \
m_off_cblock = m_off; \
n_off_cblock = n_off + j * NR; \
\
/* Loop over the m dimension (MR rows at a time). */ \
for(dim_t i = 0; i < ir_iter; i += 1 ) \
if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, mc_cur, nr_cur)) \
{ \
const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \
m_off_cblock = m_off + i * MR; \
n_off_cblock = n_off + j * NR; \
if( bli_gemmt_is_strictly_below_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \
ctype* restrict a_ir = a_ic_use + i * ps_a_use; \
ctype* restrict c_ir = c_jr + i * irstep_c; \
if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) \
m_rect = mc_cur; \
} \
else \
{ \
/* calculate the number of rows in rectangular region of the block */ \
n_iter_rect = n_off_cblock < m_off_cblock ? 0: (n_off_cblock - m_off_cblock) / MR; \
m_rect = n_iter_rect * MR; \
} \
\
/* Compute the rectangular part */ \
gemmsup_ker \
( \
conja, \
conjb, \
m_rect, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ic_use, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
beta_use, \
c_jr, rs_c, cs_c, \
&aux, \
cntx \
); \
\
m_off_cblock = m_off + m_rect; \
\
ctype* restrict a_ir = a_ic_use + n_iter_rect * ps_a_use; \
ctype* restrict c_ir = c_jr + n_iter_rect * irstep_c; \
\
/* compute the remaining triangular part */ \
for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) \
{ \
const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \
\
/* Invoke the gemmsup millikernel. */ \
gemmsup_ker \
( \
conja, \
conjb, \
mr_cur, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
zero, \
ct, rs_ct, cs_ct, \
&aux, \
cntx \
); \
\
if( col_pref ) \
{ \
/* Invoke the gemmsup millikernel. */ \
gemmsup_ker \
( \
conja, \
conjb, \
mr_cur, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
beta_use, \
c_ir, rs_c, cs_c, \
&aux, \
cntx \
); \
PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \
nr_cur, mr_cur, \
ct, cs_ct, rs_ct, \
beta_use, \
c_ir, cs_c, rs_c ); \
} \
else \
{ \
/* Invoke the gemmsup millikernel. */ \
gemmsup_ker \
( \
conja, \
conjb, \
mr_cur, \
nr_cur, \
kc_cur, \
alpha_cast, \
a_ir, rs_a_use, cs_a_use, \
b_jr, rs_b_use, cs_b_use, \
zero, \
ct, rs_ct, cs_ct, \
&aux, \
cntx \
); \
\
/* If c and ct are col-major, induce transpose and call update for lower-triangle of C */ \
if( col_pref ) \
{ \
PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \
nr_cur, mr_cur, \
ct, cs_ct, rs_ct, \
beta_use, \
c_ir, cs_c, rs_c ); \
} \
else \
{ \
PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \
mr_cur, nr_cur, \
ct, rs_ct, cs_ct, \
beta_use, \
c_ir, rs_c, cs_c ); \
} \
PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \
mr_cur, nr_cur, \
ct, rs_ct, cs_ct, \
beta_use, \
c_ir, rs_c, cs_c ); \
} \
a_ir += ps_a_use; \
c_ir += irstep_c; \
m_off_cblock += mr_cur; \
\
} \
} \
} \

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -61,25 +61,35 @@ void PASTEMAC(ch, varname) \
start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \
end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \
\
if( beta_val != 0.0 ) \
if ( beta_val == 1.0 ) \
{ \
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
for(n = 0; n <= diag-n_off; n++) \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
c[m*rs_c + n] += ct[m*rs_ct + n]; \
\
for(; m < m_cur; m++) \
for(n = 0; n < n_cur; n++) \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
c[m*rs_c + n] += ct[m*rs_ct + n]; \
} \
else if( beta_val == 0.0 )\
{ \
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
for(n = 0; n <= diag-n_off; n++) \
c[m*rs_c + n] = ct[m*rs_ct + n]; \
\
for(; m < m_cur; m++) \
for(n = 0; n < n_cur; n++) \
c[m*rs_c + n] = ct[m*rs_ct + n]; \
} \
else \
{ \
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
for(n = 0; n <= diag-n_off; n++) \
c[m*rs_c + n] = ct[m*rs_ct + n]; \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
\
for(; m < m_cur; m++) \
for(n = 0; n < n_cur; n++) \
c[m*rs_c + n] = ct[m*rs_ct + n]; \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
} \
\
return; \
@@ -109,17 +119,17 @@ void PASTEMAC(ch, varname) \
start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \
end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \
\
if( beta_val != 0.0 ) \
if( beta_val == 1.0 ) \
{ \
for(m = 0; m < start-m_off; m++) \
for(n = 0; n < n_cur; n++) \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
c[m*rs_c + n] += ct[m*rs_ct + n]; \
\
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
for(n = diag-n_off; n < n_cur; n++) \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
for(n = diag-n_off; n < n_cur; n++) \
c[m*rs_c + n] += ct[m*rs_ct + n]; \
} \
else \
else if ( beta_val == 0.0 )\
{ \
for(m = 0; m < start-m_off; m++) \
for(n = 0; n < n_cur; n++) \
@@ -129,6 +139,16 @@ void PASTEMAC(ch, varname) \
for(n = diag-n_off; n < n_cur; n++) \
c[m*rs_c + n] = ct[m*rs_ct + n]; \
} \
else \
{ \
for(m = 0; m < start-m_off; m++) \
for(n = 0; n < n_cur; n++) \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
\
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
for(n = diag-n_off; n < n_cur; n++) \
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
} \
\
return; \
}

View File

@@ -79,7 +79,7 @@ bool bli_cntx_syrksup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* c
}
else
{
if( n < 150 ) return TRUE;
if( n <= 432 ) return TRUE;
else return FALSE;
}
}