GEMM: Early return when alpha = zero

Add test in top-level GEMM BLAS layer for alpha = zero. Scale C
appropriately if true and return. This ensures that A and B are not
referenced, and thus that any Inf or NaN values in A or B are not
propagated to C. This was tested in some lower level code paths,
but not consistently.

Solution throughout GEMM codebase assumes scalm handles scale value
(beta from GEMM) equal to 0 without propagating inf/NaNs in C matrix.

Also call AOCL_DTL exits and bli_finalize_auto() in a more consistent
way.

AMD-Internal: [CPUPL-3053]
Change-Id: I4009f311951eb1ce9416cf846e9fa93b7c9219cc
This commit is contained in:
Edward Smyth
2023-03-07 05:13:24 -05:00
parent 5bd2a777ba
commit 873b4f93fd
2 changed files with 250 additions and 24 deletions

View File

@@ -97,6 +97,32 @@ void PASTEF77S(ch,blasname) \
bli_finalize_auto(); \
return; \
} \
\
/* If alpha is zero scale C by beta and return early. */ \
if( PASTEMAC(ch,eq0)( *alpha )) \
{ \
bli_convert_blas_dim1(*m, m0); \
bli_convert_blas_dim1(*n, n0); \
const inc_t rs_c = 1; \
const inc_t cs_c = *ldc; \
\
PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \
0, \
BLIS_NONUNIT_DIAG, \
BLIS_DENSE, \
m0, \
n0, \
(ftype*) beta, \
(ftype*) c, rs_c, cs_c, \
NULL, NULL \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
\
/* Map BLAS chars to their corresponding BLIS enumerated type value. */ \
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \
@@ -133,7 +159,7 @@ void PASTEF77S(ch,blasname) \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
} \
@@ -215,6 +241,32 @@ void PASTEF77S(ch,blasname) \
bli_finalize_auto(); \
return; \
} \
\
/* If alpha is zero scale C by beta and return early. */ \
if( PASTEMAC(ch,eq0)( *alpha )) \
{ \
bli_convert_blas_dim1(*m, m0); \
bli_convert_blas_dim1(*n, n0); \
const inc_t rs_c = 1; \
const inc_t cs_c = *ldc; \
\
PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \
0, \
BLIS_NONUNIT_DIAG, \
BLIS_DENSE, \
m0, \
n0, \
(ftype*) beta, \
(ftype*) c, rs_c, cs_c, \
NULL, NULL \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
\
/* Map BLAS chars to their corresponding BLIS enumerated type value. */ \
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \
@@ -264,6 +316,9 @@ void PASTEF77S(ch,blasname) \
); \
} \
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
else if( m0 == 1 ) \
@@ -297,6 +352,9 @@ void PASTEF77S(ch,blasname) \
); \
} \
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
\
@@ -333,7 +391,7 @@ void PASTEF77S(ch,blasname) \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
} \
@@ -412,6 +470,32 @@ void dzgemm_
return;
}
/* If alpha is zero scale C by beta and return early. */
if( PASTEMAC(z,eq0)( *alpha ))
{
bli_convert_blas_dim1(*m, m0);
bli_convert_blas_dim1(*n, n0);
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
PASTEMAC2(z,scalm,_ex)( BLIS_NO_CONJUGATE,
0,
BLIS_NONUNIT_DIAG,
BLIS_DENSE,
m0,
n0,
(dcomplex*) beta,
(dcomplex*) c, rs_c, cs_c,
NULL, NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
@@ -454,11 +538,12 @@ void dzgemm_
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_conjtrans( blis_transb, &bo );
// fall back on native path when zgemm is not handled in sup path.
// fall back on native path when zgemm is not handled in sup path.
bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
}// end of dzgemm_

View File

@@ -97,6 +97,32 @@ void PASTEF77S(ch,blasname) \
bli_finalize_auto(); \
return; \
} \
\
/* If alpha is zero scale C by beta and return early. */ \
if( PASTEMAC(ch,eq0)( *alpha )) \
{ \
bli_convert_blas_dim1(*m, m0); \
bli_convert_blas_dim1(*n, n0); \
const inc_t rs_c = 1; \
const inc_t cs_c = *ldc; \
\
PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \
0, \
BLIS_NONUNIT_DIAG, \
BLIS_DENSE, \
m0, \
n0, \
(ftype*) beta, \
(ftype*) c, rs_c, cs_c, \
NULL, NULL \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
\
/* Map BLAS chars to their corresponding BLIS enumerated type value. */ \
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \
@@ -133,7 +159,7 @@ void PASTEF77S(ch,blasname) \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
} \
@@ -216,6 +242,32 @@ void PASTEF77S(ch,blasname) \
bli_finalize_auto(); \
return; \
} \
\
/* If alpha is zero scale C by beta and return early. */ \
if( PASTEMAC(ch,eq0)( *alpha )) \
{ \
bli_convert_blas_dim1(*m, m0); \
bli_convert_blas_dim1(*n, n0); \
const inc_t rs_c = 1; \
const inc_t cs_c = *ldc; \
\
PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \
0, \
BLIS_NONUNIT_DIAG, \
BLIS_DENSE, \
m0, \
n0, \
(ftype*) beta, \
(ftype*) c, rs_c, cs_c, \
NULL, NULL \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
\
/* Map BLAS chars to their corresponding BLIS enumerated type value. */ \
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \
@@ -265,6 +317,9 @@ void PASTEF77S(ch,blasname) \
); \
} \
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
else if( m0 == 1 ) \
@@ -298,6 +353,9 @@ void PASTEF77S(ch,blasname) \
); \
} \
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
return; \
} \
\
@@ -334,7 +392,7 @@ void PASTEF77S(ch,blasname) \
); \
\
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
} \
@@ -374,9 +432,6 @@ void dgemm_blis_impl
double* c, const f77_int* ldc
)
{
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
@@ -413,6 +468,32 @@ void dgemm_blis_impl
return;
}
/* If alpha is zero scale C by beta and return early. */
if( PASTEMAC(d,eq0)( *alpha ))
{
bli_convert_blas_dim1(*m, m0);
bli_convert_blas_dim1(*n, n0);
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
PASTEMAC2(d,scalm,_ex)( BLIS_NO_CONJUGATE,
0,
BLIS_NONUNIT_DIAG,
BLIS_DENSE,
m0,
n0,
(double*) beta,
(double*) c, rs_c, cs_c,
NULL, NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(*transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(*transb, &blis_transb);
@@ -475,7 +556,6 @@ void dgemm_blis_impl
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
@@ -495,7 +575,6 @@ void dgemm_blis_impl
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
@@ -531,7 +610,9 @@ void dgemm_blis_impl
}
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
else if (m0 == 1)
@@ -565,6 +646,9 @@ void dgemm_blis_impl
);
}
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
@@ -661,7 +745,6 @@ void dgemm_blis_impl
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
}
@@ -672,6 +755,9 @@ void dgemm_blis_impl
if (status == BLIS_SUCCESS)
{
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
@@ -764,6 +850,32 @@ void zgemm_blis_impl
return;
}
/* If alpha is zero scale C by beta and return early. */
if( PASTEMAC(z,eq0)( *alpha ))
{
bli_convert_blas_dim1(*m, m0);
bli_convert_blas_dim1(*n, n0);
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
PASTEMAC2(z,scalm,_ex)( BLIS_NO_CONJUGATE,
0,
BLIS_NONUNIT_DIAG,
BLIS_DENSE,
m0,
n0,
(dcomplex*) beta,
(dcomplex*) c, rs_c, cs_c,
NULL, NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
@@ -831,7 +943,6 @@ void zgemm_blis_impl
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
@@ -851,7 +962,9 @@ void zgemm_blis_impl
c, rs_c,
((void *)0));
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
}
@@ -870,6 +983,9 @@ void zgemm_blis_impl
c, cs_c,
((void *)0));
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
}
@@ -918,18 +1034,17 @@ void zgemm_blis_impl
if (status == BLIS_SUCCESS)
{
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
// fall back on native path when zgemm is not handled in sup path.
bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
return;
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
}// end of zgemm_
@@ -1005,6 +1120,32 @@ void dzgemm_blis_impl
return;
}
/* If alpha is zero scale C by beta and return early. */
if( PASTEMAC(z,eq0)( *alpha ))
{
bli_convert_blas_dim1(*m, m0);
bli_convert_blas_dim1(*n, n0);
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
PASTEMAC2(z,scalm,_ex)( BLIS_NO_CONJUGATE,
0,
BLIS_NONUNIT_DIAG,
BLIS_DENSE,
m0,
n0,
(dcomplex*) beta,
(dcomplex*) c, rs_c, cs_c,
NULL, NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
@@ -1047,11 +1188,11 @@ void dzgemm_blis_impl
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_conjtrans( blis_transb, &bo );
// fall back on native path when zgemm is not handled in sup path.
// fall back on native path when zgemm is not handled in sup path.
bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
}// end of dzgemm_