diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index ff46d1f52..382ca6f67 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -1497,7 +1497,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ function pointer type. */ \ PASTECH(ch,gemmsup_ker_ft) \ gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ \ /* storage-scheme of ct should be same as that of C. Since update routines only support row-major order, @@ -1580,6 +1580,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t m_off = 0; \ dim_t n_off = 0; \ doff_t diagoffc; \ + dim_t i, ip; \ \ /* Loop over the n dimension (NC rows/columns at a time). */ \ /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \ @@ -1690,7 +1691,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ { \ /* Calculate the thread's current IC block dimension. */ \ - const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ + dim_t mc_cur = ( MC <= ic_end - ii ? 
MC : ic_left ); \ + dim_t nc_pruned = nc_cur; \ \ ctype* restrict a_ic = a_pc + ii * icstep_a; \ ctype* restrict c_ic = c_jc + ii * icstep_c; \ @@ -1699,7 +1701,24 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ if(bli_gemmt_is_strictly_above_diag( m_off, n_off, mc_cur, nc_cur ) ) continue; \ \ - PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \ + diagoffc = m_off - n_off; \ +\ + if( diagoffc < 0 ) \ + { \ + ip = -diagoffc / MR; \ + i = ip * MR; \ + mc_cur = mc_cur - i; \ + diagoffc = -diagoffc % MR; \ + m_off += i; \ + c_ic = c_ic + ( i ) * rs_c; \ + a_ic = a_ic + ( i ) * rs_a; \ + } \ +\ + if( ( diagoffc + mc_cur ) < nc_cur ) \ + { \ + nc_pruned = diagoffc + mc_cur; \ + } \ +\ ctype* a_use; \ inc_t rs_a_use, cs_a_use, ps_a_use; \ \ @@ -1755,8 +1774,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ \ /* Compute number of primary and leftover components of the JR loop. */ \ - dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ - dim_t jr_left = nc_cur % NR; \ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \ + dim_t jr_left = nc_pruned % NR; \ \ /* Compute the JR loop thread range for the current thread. */ \ dim_t jr_start, jr_end; \ @@ -1785,76 +1804,92 @@ void PASTEMACT(ch,opname,uplo,varname) \ ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ ctype* restrict c_jr = c_ic + j * jrstep_c; \ \ - const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ - const dim_t ir_left = mc_cur % MR; \ + dim_t i; \ + dim_t m_zero = 0; \ + dim_t n_iter_zero = 0; \ \ - /* Loop over the m dimension (MR rows at a time). */ \ - for(dim_t i = 0; i < ir_iter; i += 1 ) \ + m_off_cblock = m_off; \ + n_off_cblock = n_off + j * NR; \ +\ + if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mc_cur, nc_cur)) \ { \ - const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? 
MR : ir_left ); \ + m_zero = 0; \ + } \ + else \ + { \ + /* compute number of rows that are filled with zeroes and can be ignored */ \ + n_iter_zero = (n_off_cblock < m_off_cblock)? 0 : (n_off_cblock - m_off)/MR; \ + m_zero = n_iter_zero * MR; \ + } \ \ - m_off_cblock = m_off + i * MR; \ - n_off_cblock = n_off + j * NR; \ - if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \ - ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ - ctype* restrict c_ir = c_jr + i * irstep_c; \ - if( bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mr_cur, nr_cur ) ) \ + ctype* restrict a_ir = a_ic_use + n_iter_zero * ps_a_use; \ + ctype* restrict c_ir = c_jr + n_iter_zero * irstep_c; \ +\ + /* Ignore the zero region */ \ + m_off_cblock += m_zero; \ +\ + /* Compute the triangular part */ \ + for( i = m_zero; (i < mc_cur) && ( m_off_cblock < n_off_cblock + nr_cur); i += MR ) \ + { \ + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ +\ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + /* Scale the bottom edge of C and add the result from above. */ \ + /* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \ + if( col_pref ) \ { \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ + PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ } \ else \ { \ - /* Invoke the gemmsup millikernel. 
*/ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ - /* Scale the bottom edge of C and add the result from above. */ \ - /* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ - } \ - else \ - { \ - PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ - } \ + PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ } \ +\ + a_ir += ps_a_use; \ + c_ir += irstep_c; \ + m_off_cblock += mr_cur; \ } \ +\ + /* Invoke the gemmsup millikernel for remaining rectangular part. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + (i > mc_cur)? 
0: mc_cur - i, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ +\ } \ } \ \ @@ -1889,8 +1924,6 @@ PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c INSERT_GENTFUNC_L( gemmtsup, ref_var2m ) - - #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, uplo, varname ) \ \ @@ -1978,6 +2011,13 @@ void PASTEMACT(ch,opname,uplo,varname) \ stor_id == BLIS_CCC ) KC = KC0; \ else if ( stor_id == BLIS_RRC || \ stor_id == BLIS_CRC ) KC = KC0; \ + else if ( stor_id == BLIS_RCR ) \ + { \ + if ( m <= 4*MR ) KC = KC0; \ + else if ( m <= 36*MR ) KC = KC0 / 2; \ + else if ( m <= 56*MR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else KC = KC0 / 4; \ + } \ else if ( m <= MR && n <= NR ) KC = KC0; \ else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ @@ -2026,8 +2066,6 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \ \ ctype* restrict a_00 = a; \ ctype* restrict b_00 = b; \ @@ -2097,6 +2135,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t n_off = 0; \ doff_t diagoffc; \ dim_t m_off_cblock, n_off_cblock; \ + dim_t jp, j; \ \ /* Compute number of primary and leftover components of the JC loop. */ \ /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ @@ -2211,14 +2250,37 @@ void PASTEMACT(ch,opname,uplo,varname) \ for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ { \ /* Calculate the thread's current IC block dimension. */ \ - const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ + dim_t mc_cur = ( MC <= ic_end - ii ? 
MC : ic_left ); \ +\ + dim_t nc_pruned = nc_cur; \ \ m_off = ii; \ + n_off = jj; \ \ if(bli_gemmt_is_strictly_below_diag(m_off, n_off, mc_cur, nc_cur)) continue; \ \ ctype* restrict a_ic = a_pc + ii * icstep_a; \ ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + doff_t diagoffc = m_off - n_off; \ +\ + ctype* restrict b_pc_pruned = b_pc_use; \ +\ + if(diagoffc > 0 ) \ + { \ + jp = diagoffc / NR; \ + j = jp * NR; \ + nc_pruned = nc_cur - j; \ + n_off += j; \ + diagoffc = diagoffc % NR; \ + c_ic = c_ic + ( j ) * cs_c; \ + b_pc_pruned = b_pc_use + ( jp ) * ps_b_use; \ + } \ +\ + if( ( ( -diagoffc ) + nc_pruned ) < mc_cur ) \ + { \ + mc_cur = -diagoffc + nc_pruned; \ + } \ \ ctype* a_use; \ inc_t rs_a_use, cs_a_use, ps_a_use; \ @@ -2275,8 +2337,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ \ /* Compute number of primary and leftover components of the JR loop. */ \ - dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ - dim_t jr_left = nc_cur % NR; \ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \ + dim_t jr_left = nc_pruned % NR; \ \ /* Compute the JR loop thread range for the current thread. */ \ dim_t jr_start, jr_end; \ @@ -2302,77 +2364,89 @@ void PASTEMACT(ch,opname,uplo,varname) \ /* ctype* restrict b_jr = b_pc_use + j * jrstep_b; \ */ \ - ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict b_jr = b_pc_pruned + j * ps_b_use; \ ctype* restrict c_jr = c_ic + j * jrstep_c; \ + dim_t m_rect = 0; \ + dim_t n_iter_rect = 0; \ \ - const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ - const dim_t ir_left = mc_cur % MR; \ + m_off_cblock = m_off; \ + n_off_cblock = n_off + j * NR; \ \ - /* Loop over the m dimension (MR rows at a time). */ \ - for(dim_t i = 0; i < ir_iter; i += 1 ) \ + if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, mc_cur, nr_cur)) \ { \ - const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? 
MR : ir_left ); \ - m_off_cblock = m_off + i * MR; \ - n_off_cblock = n_off + j * NR; \ - if( bli_gemmt_is_strictly_below_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \ - ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ - ctype* restrict c_ir = c_jr + i * irstep_c; \ - if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) \ + m_rect = mc_cur; \ + } \ + else \ + { \ + /* calculate the number of rows in rectangular region of the block */ \ + n_iter_rect = n_off_cblock < m_off_cblock ? 0: (n_off_cblock - m_off_cblock) / MR; \ + m_rect = n_iter_rect * MR; \ + } \ +\ + /* Compute the rectangular part */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + m_rect, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ic_use, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ +\ + m_off_cblock = m_off + m_rect; \ +\ + ctype* restrict a_ir = a_ic_use + n_iter_rect * ps_a_use; \ + ctype* restrict c_ir = c_jr + n_iter_rect * irstep_c; \ +\ + /* compute the remaining triangular part */ \ + for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) \ + { \ + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ +\ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + if( col_pref ) \ { \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ + PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ } \ else \ { \ - /* Invoke the gemmsup millikernel. 
*/ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* If c and ct are col-major, induce transpose and call update for lower-triangle of C */ \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ - } \ - else \ - { \ - PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ - } \ + PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ } \ + a_ir += ps_a_use; \ + c_ir += irstep_c; \ + m_off_cblock += mr_cur; \ +\ } \ } \ } \ diff --git a/frame/util/bli_util_update.c b/frame/util/bli_util_update.c index 0f23424c8..b57c06572 100644 --- a/frame/util/bli_util_update.c +++ b/frame/util/bli_util_update.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,25 +61,35 @@ void PASTEMAC(ch, varname) \ start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \ end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? 
(m_off + m_cur):(n_off + n_cur); \ \ - if( beta_val != 0.0 ) \ + if ( beta_val == 1.0 ) \ { \ for(diag = start, m= start-m_off; diag < end; diag++, m++) \ for(n = 0; n <= diag-n_off; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ \ for(; m < m_cur; m++) \ for(n = 0; n < n_cur; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ + } \ + else if( beta_val == 0.0 )\ + { \ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = 0; n <= diag-n_off; n++) \ + c[m*rs_c + n] = ct[m*rs_ct + n]; \ +\ + for(; m < m_cur; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = ct[m*rs_ct + n]; \ } \ else \ { \ for(diag = start, m= start-m_off; diag < end; diag++, m++) \ for(n = 0; n <= diag-n_off; n++) \ - c[m*rs_c + n] = ct[m*rs_ct + n]; \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ \ for(; m < m_cur; m++) \ for(n = 0; n < n_cur; n++) \ - c[m*rs_c + n] = ct[m*rs_ct + n]; \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ } \ \ return; \ @@ -109,17 +119,17 @@ void PASTEMAC(ch, varname) \ start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \ end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? 
(m_off + m_cur):(n_off + n_cur); \ \ - if( beta_val != 0.0 ) \ + if( beta_val == 1.0 ) \ { \ for(m = 0; m < start-m_off; m++) \ for(n = 0; n < n_cur; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ \ - for(diag = start, m= start-m_off; diag < end; diag++, m++) \ - for(n = diag-n_off; n < n_cur; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = diag-n_off; n < n_cur; n++) \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ } \ - else \ + else if ( beta_val == 0.0 )\ { \ for(m = 0; m < start-m_off; m++) \ for(n = 0; n < n_cur; n++) \ @@ -129,6 +139,16 @@ void PASTEMAC(ch, varname) \ for(n = diag-n_off; n < n_cur; n++) \ c[m*rs_c + n] = ct[m*rs_ct + n]; \ } \ + else \ + { \ + for(m = 0; m < start-m_off; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ +\ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = diag-n_off; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + } \ \ return; \ } diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 3aed8bf5b..1b5fc8699 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -79,7 +79,7 @@ bool bli_cntx_syrksup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* c } else { - if( n < 150 ) return TRUE; + if( n <= 432 ) return TRUE; else return FALSE; } }