From 5f5bc2498937d7ac5a64ff97fb48464e4dc4005a Mon Sep 17 00:00:00 2001 From: Mangala V Date: Mon, 15 May 2023 23:58:24 +0530 Subject: [PATCH] Bug fix: AVX2 code being invoked on non-avx2 machine for ZGEMM API Prevented calling avx2 based bli_zgemm_ref_k1_nn code on non-supported systems. Changed the name of the function bli_zgemm_ref_k1_nn to bli_zgemm_4x6_avx2_k1_nn(). Changed the name of the function bli_dgemm_ref_k1_nn to bli_dgemm_8x6_avx2_k1_nn(). Thanks to Kiran Varaganti for identifying and helping to fix the issue. AMD-Internal: [CPUPL-3352] Change-Id: I02530ab197ed84c96cbad4f7dd56eedca0109c35 --- frame/compat/bla_gemm_amd.c | 296 ++++++++++-------- kernels/zen/3/CMakeLists.txt | 6 +- ...bli_dgemm_ref_k1.c => bli_dgemm_avx2_k1.c} | 4 +- ...bli_zgemm_ref_k1.c => bli_zgemm_avx2_k1.c} | 4 +- kernels/zen/bli_kernels_zen.h | 4 +- 5 files changed, 169 insertions(+), 145 deletions(-) rename kernels/zen/3/{bli_dgemm_ref_k1.c => bli_dgemm_avx2_k1.c} (99%) rename kernels/zen/3/{bli_zgemm_ref_k1.c => bli_zgemm_avx2_k1.c} (99%) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 68f765976..afbecd2a5 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -432,41 +432,41 @@ void dgemm_blis_impl double* c, const f77_int* ldc ) { - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; - /* Initialize BLIS. */ - bli_init_auto(); + /* Initialize BLIS. */ + bli_init_auto(); - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(d), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(d), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); - /* Quick return if possible. */ - if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0)) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0)) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } /* If alpha is zero scale C by beta and return early. */ if( PASTEMAC(d,eq0)( *alpha )) @@ -494,7 +494,7 @@ void dgemm_blis_impl return; } - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); @@ -564,92 +564,92 @@ void dgemm_blis_impl if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) { - bli_dgemm_ref_k1_nn( m0, n0, k0, - (double*)alpha, - (double*)a, *lda, - (double*)b, *ldb, - (double*)beta, - c, *ldc - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - return; + bli_dgemm_8x6_avx2_k1_nn( m0, n0, k0, + (double*)alpha, + (double*)a, *lda, + (double*)b, *ldb, + (double*)beta, + c, *ldc + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + return; } if (n0 == 1) { - if (bli_is_notrans(blis_transa)) - { - bli_dgemv_unf_var2( - BLIS_NO_TRANSPOSE, - bli_extract_conj(blis_transb), - m0, k0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } + if (bli_is_notrans(blis_transa)) + { + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - return; + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + return; } else if (m0 == 1) { - if (bli_is_notrans(blis_transb)) - { - bli_dgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var2( - blis_transb, - bli_extract_conj(blis_transa), - k0, n0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - return; + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + return; } const num_t dt = BLIS_DOUBLE; @@ -687,26 +687,26 @@ void dgemm_blis_impl if (is_parallel) #endif { - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; } - // The code below will be called when number of threads = 1. +// The code below will be called when number of threads = 1. #ifdef BLIS_ENABLE_SMALL_MATRIX @@ -813,18 +813,18 @@ void zgemm_blis_impl dcomplex* c, const f77_int* ldc ) { - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; - /* Initialize BLIS. */ - bli_init_auto(); + /* Initialize BLIS. */ + bli_init_auto(); - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - /* Perform BLAS parameter checking. */ + /* Perform BLAS parameter checking. */ PASTEBLACHK(gemm) ( MKSTR(z), @@ -924,6 +924,30 @@ void zgemm_blis_impl //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool is_parallel = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. + // This function is invoked on all architectures including 'generic'. + // Non-AVX2+FMA3 platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx2fma3_supported() == FALSE) + { + + // Will call parallelized zgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Invoking the API for input sizes with k=1. - For single thread, the API has no constraints before invoking. @@ -933,7 +957,7 @@ void zgemm_blis_impl && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) { - bli_zgemm_ref_k1_nn( m0, n0, k0, + bli_zgemm_4x6_avx2_k1_nn( m0, n0, k0, (dcomplex*)alpha, (dcomplex*)a, *lda, (dcomplex*)b, *ldb, diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt index b7187e59e..741d46e2c 100644 --- a/kernels/zen/3/CMakeLists.txt +++ b/kernels/zen/3/CMakeLists.txt @@ -1,11 +1,11 @@ -##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved.## add_library(zen_3 OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_ref_k1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_avx2_k1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_avx2_k1.c ) target_compile_options(zen_3 PRIVATE /arch:AVX2) diff --git a/kernels/zen/3/bli_dgemm_ref_k1.c b/kernels/zen/3/bli_dgemm_avx2_k1.c similarity index 99% rename from kernels/zen/3/bli_dgemm_ref_k1.c rename to kernels/zen/3/bli_dgemm_avx2_k1.c index 14fa99ada..b225fdad1 100644 --- a/kernels/zen/3/bli_dgemm_ref_k1.c +++ b/kernels/zen/3/bli_dgemm_avx2_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,7 @@ #define D_MR 8 #define D_NR 6 -void bli_dgemm_ref_k1_nn +void bli_dgemm_8x6_avx2_k1_nn ( dim_t m, dim_t n, diff --git a/kernels/zen/3/bli_zgemm_ref_k1.c b/kernels/zen/3/bli_zgemm_avx2_k1.c similarity index 99% rename from kernels/zen/3/bli_zgemm_ref_k1.c rename to kernels/zen/3/bli_zgemm_avx2_k1.c index 60353cced..a6a92f9a5 100644 --- a/kernels/zen/3/bli_zgemm_ref_k1.c +++ b/kernels/zen/3/bli_zgemm_avx2_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,7 +107,7 @@ NEG_PERM_M_FRINGE(rin_0,rn); \ rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ -void bli_zgemm_ref_k1_nn +void bli_zgemm_4x6_avx2_k1_nn ( dim_t m, dim_t n, diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index c90dbc0e0..e6a2f33f9 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -303,7 +303,7 @@ err_t bli_zgemm_small_At cntl_t* cntl ); -void bli_dgemm_ref_k1_nn +void bli_dgemm_8x6_avx2_k1_nn ( dim_t m, dim_t n, @@ -315,7 +315,7 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); -void bli_zgemm_ref_k1_nn +void bli_zgemm_4x6_avx2_k1_nn ( dim_t m, dim_t n,