mirror of
https://github.com/amd/blis.git
synced 2026-05-13 10:35:38 +00:00
Optimized SUP code for GEMMT
Details: - Eliminated the IR loop in ref_var2m functions. - Handled the rectangular and triangular portions of C matrix separately. - Added a condition to check and eliminate zero regions inside IC loop. - modified kc selection logic to choose optimal KC in SUP - Updated thresholds to choose between SUP and native. Change-Id: I21908eaa6bc3a8f37bdea29f7bfca7e6fcfee724
This commit is contained in:
committed by
Dipal M Zambare
parent
faeb79f2b9
commit
10ca8710f0
@@ -1497,7 +1497,7 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
function pointer type. */ \
|
||||
PASTECH(ch,gemmsup_ker_ft) \
|
||||
gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
\
|
||||
/* storage-scheme of ct should be same as that of C.
|
||||
Since update routines only support row-major order,
|
||||
@@ -1580,6 +1580,7 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
dim_t m_off = 0; \
|
||||
dim_t n_off = 0; \
|
||||
doff_t diagoffc; \
|
||||
dim_t i, ip; \
|
||||
\
|
||||
/* Loop over the n dimension (NC rows/columns at a time). */ \
|
||||
/*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \
|
||||
@@ -1690,7 +1691,8 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \
|
||||
{ \
|
||||
/* Calculate the thread's current IC block dimension. */ \
|
||||
const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
|
||||
dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
|
||||
dim_t nc_pruned = nc_cur; \
|
||||
\
|
||||
ctype* restrict a_ic = a_pc + ii * icstep_a; \
|
||||
ctype* restrict c_ic = c_jc + ii * icstep_c; \
|
||||
@@ -1699,7 +1701,24 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
\
|
||||
if(bli_gemmt_is_strictly_above_diag( m_off, n_off, mc_cur, nc_cur ) ) continue; \
|
||||
\
|
||||
PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \
|
||||
diagoffc = m_off - n_off; \
|
||||
\
|
||||
if( diagoffc < 0 ) \
|
||||
{ \
|
||||
ip = -diagoffc / MR; \
|
||||
i = ip * MR; \
|
||||
mc_cur = mc_cur - i; \
|
||||
diagoffc = -diagoffc % MR; \
|
||||
m_off += i; \
|
||||
c_ic = c_ic + ( i ) * rs_c; \
|
||||
a_ic = a_ic + ( i ) * rs_a; \
|
||||
} \
|
||||
\
|
||||
if( ( diagoffc + mc_cur ) < nc_cur ) \
|
||||
{ \
|
||||
nc_pruned = diagoffc + mc_cur; \
|
||||
} \
|
||||
\
|
||||
ctype* a_use; \
|
||||
inc_t rs_a_use, cs_a_use, ps_a_use; \
|
||||
\
|
||||
@@ -1755,8 +1774,8 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the JR loop. */ \
|
||||
dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \
|
||||
dim_t jr_left = nc_cur % NR; \
|
||||
dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \
|
||||
dim_t jr_left = nc_pruned % NR; \
|
||||
\
|
||||
/* Compute the JR loop thread range for the current thread. */ \
|
||||
dim_t jr_start, jr_end; \
|
||||
@@ -1785,76 +1804,92 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
ctype* restrict b_jr = b_pc_use + j * ps_b_use; \
|
||||
ctype* restrict c_jr = c_ic + j * jrstep_c; \
|
||||
\
|
||||
const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \
|
||||
const dim_t ir_left = mc_cur % MR; \
|
||||
dim_t i; \
|
||||
dim_t m_zero = 0; \
|
||||
dim_t n_iter_zero = 0; \
|
||||
\
|
||||
/* Loop over the m dimension (MR rows at a time). */ \
|
||||
for(dim_t i = 0; i < ir_iter; i += 1 ) \
|
||||
m_off_cblock = m_off; \
|
||||
n_off_cblock = n_off + j * NR; \
|
||||
\
|
||||
if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mc_cur, nc_cur)) \
|
||||
{ \
|
||||
const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \
|
||||
m_zero = 0; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* compute number of rows that are filled with zeroes and can be ignored */ \
|
||||
n_iter_zero = (n_off_cblock < m_off_cblock)? 0 : (n_off_cblock - m_off)/MR; \
|
||||
m_zero = n_iter_zero * MR; \
|
||||
} \
|
||||
\
|
||||
m_off_cblock = m_off + i * MR; \
|
||||
n_off_cblock = n_off + j * NR; \
|
||||
if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \
|
||||
ctype* restrict a_ir = a_ic_use + i * ps_a_use; \
|
||||
ctype* restrict c_ir = c_jr + i * irstep_c; \
|
||||
if( bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mr_cur, nr_cur ) ) \
|
||||
ctype* restrict a_ir = a_ic_use + n_iter_zero * ps_a_use; \
|
||||
ctype* restrict c_ir = c_jr + n_iter_zero * irstep_c; \
|
||||
\
|
||||
/* Ignore the zero region */ \
|
||||
m_off_cblock += m_zero; \
|
||||
\
|
||||
/* Compute the triangular part */ \
|
||||
for( i = m_zero; (i < mc_cur) && ( m_off_cblock < n_off_cblock + nr_cur); i += MR ) \
|
||||
{ \
|
||||
const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \
|
||||
\
|
||||
/* Invoke the gemmsup millikernel. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
mr_cur, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
/* Scale the bottom edge of C and add the result from above. */ \
|
||||
/* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \
|
||||
if( col_pref ) \
|
||||
{ \
|
||||
/* Invoke the gemmsup millikernel. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
mr_cur, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \
|
||||
nr_cur, mr_cur, \
|
||||
ct, cs_ct, rs_ct, \
|
||||
beta_use, \
|
||||
c_ir, cs_c, rs_c ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemmsup millikernel. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
mr_cur, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
/* Scale the bottom edge of C and add the result from above. */ \
|
||||
/* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \
|
||||
if( col_pref ) \
|
||||
{ \
|
||||
PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \
|
||||
nr_cur, mr_cur, \
|
||||
ct, cs_ct, rs_ct, \
|
||||
beta_use, \
|
||||
c_ir, cs_c, rs_c ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \
|
||||
mr_cur, nr_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c ); \
|
||||
} \
|
||||
PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \
|
||||
mr_cur, nr_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c ); \
|
||||
} \
|
||||
\
|
||||
a_ir += ps_a_use; \
|
||||
c_ir += irstep_c; \
|
||||
m_off_cblock += mr_cur; \
|
||||
} \
|
||||
\
|
||||
/* Invoke the gemmsup millikerneli for remaining rectangular part. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
(i > mc_cur)? 0: mc_cur - i, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
} \
|
||||
} \
|
||||
\
|
||||
@@ -1889,8 +1924,6 @@ PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c
|
||||
|
||||
INSERT_GENTFUNC_L( gemmtsup, ref_var2m )
|
||||
|
||||
|
||||
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC( ctype, ch, opname, uplo, varname ) \
|
||||
\
|
||||
@@ -1978,6 +2011,13 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
stor_id == BLIS_CCC ) KC = KC0; \
|
||||
else if ( stor_id == BLIS_RRC || \
|
||||
stor_id == BLIS_CRC ) KC = KC0; \
|
||||
else if ( stor_id == BLIS_RCR ) \
|
||||
{ \
|
||||
if ( m <= 4*MR ) KC = KC0; \
|
||||
else if ( m <= 36*MR ) KC = KC0 / 2; \
|
||||
else if ( m <= 56*MR ) KC = (( KC0 / 3 ) / 4 ) * 4; \
|
||||
else KC = KC0 / 4; \
|
||||
} \
|
||||
else if ( m <= MR && n <= NR ) KC = KC0; \
|
||||
else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \
|
||||
else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \
|
||||
@@ -2026,8 +2066,6 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
\
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
ctype* restrict a_00 = a; \
|
||||
ctype* restrict b_00 = b; \
|
||||
@@ -2097,6 +2135,7 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
dim_t n_off = 0; \
|
||||
doff_t diagoffc; \
|
||||
dim_t m_off_cblock, n_off_cblock; \
|
||||
dim_t jp, j; \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the JC loop. */ \
|
||||
/*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \
|
||||
@@ -2211,14 +2250,37 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \
|
||||
{ \
|
||||
/* Calculate the thread's current IC block dimension. */ \
|
||||
const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
|
||||
dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \
|
||||
\
|
||||
dim_t nc_pruned = nc_cur; \
|
||||
\
|
||||
m_off = ii; \
|
||||
n_off = jj; \
|
||||
\
|
||||
if(bli_gemmt_is_strictly_below_diag(m_off, n_off, mc_cur, nc_cur)) continue; \
|
||||
\
|
||||
ctype* restrict a_ic = a_pc + ii * icstep_a; \
|
||||
ctype* restrict c_ic = c_jc + ii * icstep_c; \
|
||||
\
|
||||
doff_t diagoffc = m_off - n_off; \
|
||||
\
|
||||
ctype* restrict b_pc_pruned = b_pc_use; \
|
||||
\
|
||||
if(diagoffc > 0 ) \
|
||||
{ \
|
||||
jp = diagoffc / NR; \
|
||||
j = jp * NR; \
|
||||
nc_pruned = nc_cur - j; \
|
||||
n_off += j; \
|
||||
diagoffc = diagoffc % NR; \
|
||||
c_ic = c_ic + ( j ) * cs_c; \
|
||||
b_pc_pruned = b_pc_use + ( jp ) * ps_b_use; \
|
||||
} \
|
||||
\
|
||||
if( ( ( -diagoffc ) + nc_pruned ) < mc_cur ) \
|
||||
{ \
|
||||
mc_cur = -diagoffc + nc_pruned; \
|
||||
} \
|
||||
\
|
||||
ctype* a_use; \
|
||||
inc_t rs_a_use, cs_a_use, ps_a_use; \
|
||||
@@ -2275,8 +2337,8 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the JR loop. */ \
|
||||
dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \
|
||||
dim_t jr_left = nc_cur % NR; \
|
||||
dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \
|
||||
dim_t jr_left = nc_pruned % NR; \
|
||||
\
|
||||
/* Compute the JR loop thread range for the current thread. */ \
|
||||
dim_t jr_start, jr_end; \
|
||||
@@ -2302,77 +2364,89 @@ void PASTEMACT(ch,opname,uplo,varname) \
|
||||
/*
|
||||
ctype* restrict b_jr = b_pc_use + j * jrstep_b; \
|
||||
*/ \
|
||||
ctype* restrict b_jr = b_pc_use + j * ps_b_use; \
|
||||
ctype* restrict b_jr = b_pc_pruned + j * ps_b_use; \
|
||||
ctype* restrict c_jr = c_ic + j * jrstep_c; \
|
||||
dim_t m_rect = 0; \
|
||||
dim_t n_iter_rect = 0; \
|
||||
\
|
||||
const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \
|
||||
const dim_t ir_left = mc_cur % MR; \
|
||||
m_off_cblock = m_off; \
|
||||
n_off_cblock = n_off + j * NR; \
|
||||
\
|
||||
/* Loop over the m dimension (MR rows at a time). */ \
|
||||
for(dim_t i = 0; i < ir_iter; i += 1 ) \
|
||||
if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, mc_cur, nr_cur)) \
|
||||
{ \
|
||||
const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \
|
||||
m_off_cblock = m_off + i * MR; \
|
||||
n_off_cblock = n_off + j * NR; \
|
||||
if( bli_gemmt_is_strictly_below_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \
|
||||
ctype* restrict a_ir = a_ic_use + i * ps_a_use; \
|
||||
ctype* restrict c_ir = c_jr + i * irstep_c; \
|
||||
if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) \
|
||||
m_rect = mc_cur; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* calculate the number of rows in rectangular region of the block */ \
|
||||
n_iter_rect = n_off_cblock < m_off_cblock ? 0: (n_off_cblock - m_off_cblock) / MR; \
|
||||
m_rect = n_iter_rect * MR; \
|
||||
} \
|
||||
\
|
||||
/* Compute the rectangular part */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
m_rect, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ic_use, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
beta_use, \
|
||||
c_jr, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
m_off_cblock = m_off + m_rect; \
|
||||
\
|
||||
ctype* restrict a_ir = a_ic_use + n_iter_rect * ps_a_use; \
|
||||
ctype* restrict c_ir = c_jr + n_iter_rect * irstep_c; \
|
||||
\
|
||||
/* compute the remaining triangular part */ \
|
||||
for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) \
|
||||
{ \
|
||||
const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \
|
||||
\
|
||||
/* Invoke the gemmsup millikernel. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
mr_cur, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
if( col_pref ) \
|
||||
{ \
|
||||
/* Invoke the gemmsup millikernel. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
mr_cur, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \
|
||||
nr_cur, mr_cur, \
|
||||
ct, cs_ct, rs_ct, \
|
||||
beta_use, \
|
||||
c_ir, cs_c, rs_c ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemmsup millikernel. */ \
|
||||
gemmsup_ker \
|
||||
( \
|
||||
conja, \
|
||||
conjb, \
|
||||
mr_cur, \
|
||||
nr_cur, \
|
||||
kc_cur, \
|
||||
alpha_cast, \
|
||||
a_ir, rs_a_use, cs_a_use, \
|
||||
b_jr, rs_b_use, cs_b_use, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* If c and ct are col-major, induce transpose and call update for lower-triangle of C */ \
|
||||
if( col_pref ) \
|
||||
{ \
|
||||
PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \
|
||||
nr_cur, mr_cur, \
|
||||
ct, cs_ct, rs_ct, \
|
||||
beta_use, \
|
||||
c_ir, cs_c, rs_c ); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \
|
||||
mr_cur, nr_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c ); \
|
||||
} \
|
||||
PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \
|
||||
mr_cur, nr_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_use, \
|
||||
c_ir, rs_c, cs_c ); \
|
||||
} \
|
||||
a_ir += ps_a_use; \
|
||||
c_ir += irstep_c; \
|
||||
m_off_cblock += mr_cur; \
|
||||
\
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -61,25 +61,35 @@ void PASTEMAC(ch, varname) \
|
||||
start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \
|
||||
end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \
|
||||
\
|
||||
if( beta_val != 0.0 ) \
|
||||
if ( beta_val == 1.0 ) \
|
||||
{ \
|
||||
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
|
||||
for(n = 0; n <= diag-n_off; n++) \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
c[m*rs_c + n] += ct[m*rs_ct + n]; \
|
||||
\
|
||||
for(; m < m_cur; m++) \
|
||||
for(n = 0; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
c[m*rs_c + n] += ct[m*rs_ct + n]; \
|
||||
} \
|
||||
else if( beta_val == 0.0 )\
|
||||
{ \
|
||||
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
|
||||
for(n = 0; n <= diag-n_off; n++) \
|
||||
c[m*rs_c + n] = ct[m*rs_ct + n]; \
|
||||
\
|
||||
for(; m < m_cur; m++) \
|
||||
for(n = 0; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = ct[m*rs_ct + n]; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
|
||||
for(n = 0; n <= diag-n_off; n++) \
|
||||
c[m*rs_c + n] = ct[m*rs_ct + n]; \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
\
|
||||
for(; m < m_cur; m++) \
|
||||
for(n = 0; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = ct[m*rs_ct + n]; \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
} \
|
||||
\
|
||||
return; \
|
||||
@@ -109,17 +119,17 @@ void PASTEMAC(ch, varname) \
|
||||
start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \
|
||||
end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \
|
||||
\
|
||||
if( beta_val != 0.0 ) \
|
||||
if( beta_val == 1.0 ) \
|
||||
{ \
|
||||
for(m = 0; m < start-m_off; m++) \
|
||||
for(n = 0; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
c[m*rs_c + n] += ct[m*rs_ct + n]; \
|
||||
\
|
||||
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
|
||||
for(n = diag-n_off; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
|
||||
for(n = diag-n_off; n < n_cur; n++) \
|
||||
c[m*rs_c + n] += ct[m*rs_ct + n]; \
|
||||
} \
|
||||
else \
|
||||
else if ( beta_val == 0.0 )\
|
||||
{ \
|
||||
for(m = 0; m < start-m_off; m++) \
|
||||
for(n = 0; n < n_cur; n++) \
|
||||
@@ -129,6 +139,16 @@ void PASTEMAC(ch, varname) \
|
||||
for(n = diag-n_off; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = ct[m*rs_ct + n]; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for(m = 0; m < start-m_off; m++) \
|
||||
for(n = 0; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
\
|
||||
for(diag = start, m= start-m_off; diag < end; diag++, m++) \
|
||||
for(n = diag-n_off; n < n_cur; n++) \
|
||||
c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \
|
||||
} \
|
||||
\
|
||||
return; \
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ bool bli_cntx_syrksup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* c
|
||||
}
|
||||
else
|
||||
{
|
||||
if( n < 150 ) return TRUE;
|
||||
if( n <= 432 ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user