Merge "Added optimized single threaded dtrsm small for left cases" into amd-staging-milan-3.1

This commit is contained in:
Nallani Bhaskar
2021-05-19 00:47:56 -04:00
committed by Gerrit Code Review
4 changed files with 13546 additions and 16385 deletions

View File

@@ -55,50 +55,6 @@ void bli_trsm_front
obj_t b_local;
obj_t c_local;
#ifdef PRINT_SMALL_TRSM_INFO
printf("Side:: %c\n", side ? 'R' : 'L');
if (bli_obj_datatype(*a) == BLIS_FLOAT)
printf("Alpha:: %9.2e\n", *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, *alpha)));
else if (bli_obj_datatype(*a) == BLIS_DOUBLE)
printf("Alpha is double:: %9.2e\n", *((double *)bli_obj_buffer_for_const(BLIS_DOUBLE, *alpha)));
else
printf("Unsupported datatype for Alpha\n");
printf("A:: M = %d, N = %d, elem_size = %d, row_off = %ld, col_off = %ld, rs = %d, cs = %d, trans = %c, TRIANG = %c, unit diag = %c\n", a->dim[0], a->dim[1], bli_obj_elem_size(*a ), bli_obj_row_off(*a), bli_obj_col_off(*a), a->rs, a->cs, bli_obj_has_trans(*a) ? 'Y' : 'N', bli_obj_is_upper(*a) ? 'U' : bli_obj_is_lower(*a) ? 'L' : 'N', bli_obj_has_unit_diag(*a) ? 'Y' : 'N');
#ifdef PRINT_SMALL_TRSM
//bli_printm("a", a, "%4.1f", "");
#endif
printf("B:: M = %d, N = %d, elem_size = %d, row_off = %ld, col_off = %ld, rs = %d, cs = %d, trans = %c\n", b->dim[0], b->dim[1], bli_obj_elem_size(*a ), bli_obj_row_off(*a), bli_obj_col_off(*a), b->rs, b->cs, bli_obj_has_trans(*b) ? 'Y' : 'N');
#ifdef PRINT_SMALL_TRSM
//bli_printm("b", b, "%4.1f", "");
#endif
fflush(stdout);
#endif
#if 0
for (i = 0; i < m; i++) //no. of cols of B
{
for (j = 0; j < n; j++) //no. of rows of B
{
B[i*n + j] = 1001 + j + (i*n);
}
}
for (i = 0; i < m; i++) //no. of cols of B
{
for (j = i; j < m; j++) //no. of rows of B
{
L[i*m + j] = 2001 + j + (i*m);
}
}
#endif
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
gint_t status = bli_trsm_small( side, alpha, a, b, cntx, cntl );
if ( status == BLIS_SUCCESS )
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
return;
}
#endif
// Check parameters.
if ( bli_error_checking_is_enabled() )
bli_trsm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx );

View File

@@ -229,6 +229,7 @@ void PASTEF77(ch,blasname) \
(ftype*)b, rs_b, \
NULL \
); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
return; \
} \
else if(bli_is_trans(blis_transa)) \
@@ -244,6 +245,7 @@ void PASTEF77(ch,blasname) \
(ftype*)b, rs_b, \
NULL \
); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
return; \
} \
} \
@@ -268,6 +270,7 @@ void PASTEF77(ch,blasname) \
PASTEMAC(ch,invscals)( a_conj, b[indx] ); \
} \
}\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
return; \
} \
} \
@@ -290,6 +293,7 @@ void PASTEF77(ch,blasname) \
(ftype*)a, cs_a, rs_a, \
(ftype*)b, cs_b, \
NULL); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
return; \
} \
else if(bli_is_trans(blis_transa)) \
@@ -307,6 +311,7 @@ void PASTEF77(ch,blasname) \
(ftype*)a, cs_a, rs_a, \
(ftype*)b, cs_b, \
NULL); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
return; \
} \
} \
@@ -331,6 +336,7 @@ void PASTEF77(ch,blasname) \
PASTEMAC(ch,invscals)( a_conj, b[indx*cs_b] ); \
}\
} \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
return; \
} \
} \
@@ -374,6 +380,265 @@ void PASTEF77(ch,blasname) \
#endif
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTFUNC_BLAS( trsm, trsm )
#ifdef BLIS_CONFIG_EPYC
void dtrsm_
(
const f77_char* side,
const f77_char* uploa,
const f77_char* transa,
const f77_char* diaga,
const f77_int* m,
const f77_int* n,
const double* alpha,
const double* a, const f77_int* lda,
double* b, const f77_int* ldb
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d),
*side, *uploa,*transa, *diaga, *m, *n,
(void*)alpha,*lda, *ldb);
side_t blis_side;
uplo_t blis_uploa;
trans_t blis_transa;
diag_t blis_diaga;
dim_t m0, n0;
conj_t conja = BLIS_NO_CONJUGATE ;
/* Initialize BLIS. */
bli_init_auto();
/* Perform BLAS parameter checking. */
PASTEBLACHK(trsm)
(
MKSTR(d),
MKSTR(trsm),
side,
uploa,
transa,
diaga,
m,
n,
lda,
ldb
);
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_side( *side, &blis_side );
bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa );
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga );
/* Typecast BLAS integers to BLIS integers. */
bli_convert_blas_dim1( *m, m0 );
bli_convert_blas_dim1( *n, n0 );
/* Set the row and column strides of the matrix operands. */
const inc_t rs_a = 1;
const inc_t cs_a = *lda;
const inc_t rs_b = 1;
const inc_t cs_b = *ldb;
const num_t dt = BLIS_DOUBLE;
if( n0 == 1 )
{
if( blis_side == BLIS_LEFT )
{
if(bli_is_notrans(blis_transa))
{
bli_dtrsv_unf_var2
(
blis_uploa,
blis_transa,
blis_diaga,
m0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, rs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
else if(bli_is_trans(blis_transa))
{
bli_dtrsv_unf_var1
(
blis_uploa,
blis_transa,
blis_diaga,
m0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, rs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) )
{
/* b = alpha * b; */
bli_dscalv_ex
(
conja,
m0,
(double*)alpha,
b, rs_b,
NULL,
NULL
);
if(blis_diaga == BLIS_NONUNIT_DIAG)
{
double inva = 1.0/ *a;
for(int indx = 0; indx < m0; indx ++)
{
b[indx] = ( inva * b[indx] );
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
else if( m0 == 1 )
{
if(blis_side == BLIS_RIGHT)
{
if(bli_is_notrans(blis_transa))
{
if(blis_uploa == BLIS_UPPER)
blis_uploa = BLIS_LOWER;
else
blis_uploa = BLIS_UPPER;
bli_dtrsv_unf_var1
(
blis_uploa,
blis_transa,
blis_diaga,
n0,
(double*)alpha,
(double*)a, cs_a, rs_a,
(double*)b, cs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
else if(bli_is_trans(blis_transa))
{
if(blis_uploa == BLIS_UPPER)
blis_uploa = BLIS_LOWER;
else
blis_uploa = BLIS_UPPER;
bli_dtrsv_unf_var2
(
blis_uploa,
blis_transa,
blis_diaga,
n0,
(double*)alpha,
(double*)a, cs_a, rs_a,
(double*)b, cs_b,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 ))
{
/* b = alpha * b; */
bli_dscalv_ex
(
conja,
n0,
(double*)alpha,
b, cs_b,
NULL,
NULL
);
if(blis_diaga == BLIS_NONUNIT_DIAG)
{
double inva = 1.0/ *a;
for(int indx = 0; indx < n0; indx ++)
{
b[indx*cs_b] = (inva * b[indx*cs_b] );
}
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return;
}
}
const struc_t struca = BLIS_TRIANGULAR;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
dim_t mn0_a;
bli_set_dim_with_side( blis_side, m0, n0, &mn0_a );
bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao );
bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao );
bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo );
bli_obj_set_uplo( blis_uploa, &ao );
bli_obj_set_diag( blis_diaga, &ao );
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_struc( struca, &ao );
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
/* Irrespective of num threads single thread bli_dtrsm_small
* is performing better than other implementations for [m,n]<=128 */
/* ToDo: This condition will be tunned for single thread */
if(m0 <=128 && n0<=128)
{
err_t status;
status = bli_trsm_small
(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL
);
if (status == BLIS_SUCCESS)
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
}
#endif
bli_trsmnat
(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO)
/* Finalize BLIS. */
bli_finalize_auto();
}
/* NOTE(review): dtrsm_ is written out by hand above for BLIS_CONFIG_EPYC, so
 * presumably only the remaining precisions are generated from the macro
 * template here (s via GENTFUNC, c/z via the _CZ inserter) -- confirm against
 * the macro definitions. */
GENTFUNC( float, s, trsm, trsm )
INSERT_GENTFUNC_BLAS_CZ( trsm, trsm )
#else
INSERT_GENTFUNC_BLAS( trsm, trsm )
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -260,3 +260,13 @@ void bli_dgemm_ref_k1_nn
double* c, const inc_t ldc
);
/*
 * Prototype for the small-matrix TRSM routine; the definition is not part of
 * this header.
 *
 *   side   which side the triangular operand is applied from.
 *   alpha  scalar operand, passed as an object.
 *   a      first matrix operand object.
 *   b      second matrix operand object.
 *   cntx   context pointer.
 *   cntl   control-tree pointer.
 *
 * Returns an err_t status.
 * NOTE(review): callers appear to treat BLIS_SUCCESS as "problem handled" and
 * any other value as "fall back to the generic TRSM path" -- confirm against
 * the definition.
 */
err_t bli_trsm_small
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);