From 873b4f93fde498ce793c96a423ccd0b351bd5f12 Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Tue, 7 Mar 2023 05:13:24 -0500 Subject: [PATCH] GEMM: Early return when alpha = zero Add test in top-level GEMM BLAS layer for alpha = zero. Scale C appropriately if true and return. This ensures that A and B are not referenced, and thus that any Inf or NaN values in A or B are not propagated to C. This was tested in some lower-level code paths, but not consistently. The solution throughout the GEMM codebase assumes that scalm handles a scale value (beta from GEMM) equal to 0 without propagating Inf or NaN values in the C matrix. Also call AOCL_DTL exits and bli_finalize_auto() in a more consistent way. AMD-Internal: [CPUPL-3053] Change-Id: I4009f311951eb1ce9416cf846e9fa93b7c9219cc --- frame/compat/bla_gemm.c | 93 +++++++++++++++++- frame/compat/bla_gemm_amd.c | 181 ++++++++++++++++++++++++++++++++---- 2 files changed, 250 insertions(+), 24 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 36bd5df40..e7576096c 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -97,6 +97,32 @@ void PASTEF77S(ch,blasname) \ bli_finalize_auto(); \ return; \ } \ +\ + /* If alpha is zero scale C by beta and return early. */ \ + if( PASTEMAC(ch,eq0)( *alpha )) \ + { \ + bli_convert_blas_dim1(*m, m0); \ + bli_convert_blas_dim1(*n, n0); \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m0, \ + n0, \ + (ftype*) beta, \ + (ftype*) c, rs_c, cs_c, \ + NULL, NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -133,7 +159,7 @@ void PASTEF77S(ch,blasname) \ ); \ \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ } \ @@ -215,6 +241,32 @@ void PASTEF77S(ch,blasname) \ bli_finalize_auto(); \ return; \ } \ +\ + /* If alpha is zero scale C by beta and return early. */ \ + if( PASTEMAC(ch,eq0)( *alpha )) \ + { \ + bli_convert_blas_dim1(*m, m0); \ + bli_convert_blas_dim1(*n, n0); \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m0, \ + n0, \ + (ftype*) beta, \ + (ftype*) c, rs_c, cs_c, \ + NULL, NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -264,6 +316,9 @@ void PASTEF77S(ch,blasname) \ ); \ } \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ return; \ } \ else if( m0 == 1 ) \ @@ -297,6 +352,9 @@ void PASTEF77S(ch,blasname) \ ); \ } \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ return; \ } \ \ @@ -333,7 +391,7 @@ void PASTEF77S(ch,blasname) \ ); \ \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ } \ @@ -412,6 +470,32 @@ void dzgemm_ return; } + /* If alpha is zero scale C by beta and return early. */ + if( PASTEMAC(z,eq0)( *alpha )) + { + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + PASTEMAC2(z,scalm,_ex)( BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m0, + n0, + (dcomplex*) beta, + (dcomplex*) c, rs_c, cs_c, + NULL, NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -454,11 +538,12 @@ void dzgemm_ bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_conjtrans( blis_transb, &bo ); - // fall back on native path when zgemm is not handled in sup path. + // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 25ef492cb..c0e774db4 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -97,6 +97,32 @@ void PASTEF77S(ch,blasname) \ bli_finalize_auto(); \ return; \ } \ +\ + /* If alpha is zero scale C by beta and return early. 
*/ \ + if( PASTEMAC(ch,eq0)( *alpha )) \ + { \ + bli_convert_blas_dim1(*m, m0); \ + bli_convert_blas_dim1(*n, n0); \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m0, \ + n0, \ + (ftype*) beta, \ + (ftype*) c, rs_c, cs_c, \ + NULL, NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -133,7 +159,7 @@ void PASTEF77S(ch,blasname) \ ); \ \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);\ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ } \ @@ -216,6 +242,32 @@ void PASTEF77S(ch,blasname) \ bli_finalize_auto(); \ return; \ } \ +\ + /* If alpha is zero scale C by beta and return early. */ \ + if( PASTEMAC(ch,eq0)( *alpha )) \ + { \ + bli_convert_blas_dim1(*m, m0); \ + bli_convert_blas_dim1(*n, n0); \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + PASTEMAC2(ch,scalm,_ex)( BLIS_NO_CONJUGATE, \ + 0, \ + BLIS_NONUNIT_DIAG, \ + BLIS_DENSE, \ + m0, \ + n0, \ + (ftype*) beta, \ + (ftype*) c, rs_c, cs_c, \ + NULL, NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -265,6 +317,9 @@ void PASTEF77S(ch,blasname) \ ); \ } \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. 
*/ \ + bli_finalize_auto(); \ return; \ } \ else if( m0 == 1 ) \ @@ -298,6 +353,9 @@ void PASTEF77S(ch,blasname) \ ); \ } \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ return; \ } \ \ @@ -334,7 +392,7 @@ void PASTEF77S(ch,blasname) \ ); \ \ AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ } \ @@ -374,9 +432,6 @@ void dgemm_blis_impl double* c, const f77_int* ldc ) { - - - trans_t blis_transa; trans_t blis_transb; dim_t m0, n0, k0; @@ -413,6 +468,32 @@ void dgemm_blis_impl return; } + /* If alpha is zero scale C by beta and return early. */ + if( PASTEMAC(d,eq0)( *alpha )) + { + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + PASTEMAC2(d,scalm,_ex)( BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m0, + n0, + (double*) beta, + (double*) c, rs_c, cs_c, + NULL, NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); @@ -475,7 +556,6 @@ void dgemm_blis_impl ); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. 
*/ bli_finalize_auto(); @@ -495,7 +575,6 @@ void dgemm_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS */ bli_finalize_auto(); - return; } @@ -531,7 +610,9 @@ void dgemm_blis_impl } AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); return; } else if (m0 == 1) @@ -565,6 +646,9 @@ void dgemm_blis_impl ); } AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); return; } @@ -661,7 +745,6 @@ void dgemm_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); - return; } } @@ -672,6 +755,9 @@ void dgemm_blis_impl if (status == BLIS_SUCCESS) { AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); return; } @@ -764,6 +850,32 @@ void zgemm_blis_impl return; } + /* If alpha is zero scale C by beta and return early. */ + if( PASTEMAC(z,eq0)( *alpha )) + { + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + PASTEMAC2(z,scalm,_ex)( BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m0, + n0, + (dcomplex*) beta, + (dcomplex*) c, rs_c, cs_c, + NULL, NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -831,7 +943,6 @@ void zgemm_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS */ bli_finalize_auto(); - return; } @@ -851,7 +962,9 @@ void zgemm_blis_impl c, rs_c, ((void *)0)); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); return; } } @@ -870,6 +983,9 @@ void zgemm_blis_impl c, cs_c, ((void *)0)); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); return; } } @@ -918,18 +1034,17 @@ void zgemm_blis_impl if (status == BLIS_SUCCESS) { AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); return; } // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); }// end of zgemm_ @@ -1005,6 +1120,32 @@ void dzgemm_blis_impl return; } + /* If alpha is zero scale C by beta and return early. */ + if( PASTEMAC(z,eq0)( *alpha )) + { + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + PASTEMAC2(z,scalm,_ex)( BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m0, + n0, + (dcomplex*) beta, + (dcomplex*) c, rs_c, cs_c, + NULL, NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -1047,11 +1188,11 @@ void dzgemm_blis_impl bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_conjtrans( blis_transb, &bo ); - // fall back on native path when zgemm is not handled in sup path. + // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_