Bug fix: AVX2 code was being invoked on non-AVX2 machines for the ZGEMM API

Prevented the AVX2-based bli_zgemm_ref_k1_nn kernel from being called on
systems that do not support AVX2 and FMA3.
Renamed the function bli_zgemm_ref_k1_nn() to bli_zgemm_4x6_avx2_k1_nn().
Renamed the function bli_dgemm_ref_k1_nn() to bli_dgemm_8x6_avx2_k1_nn().

Thanks to Kiran Varaganti <Kiran.Varaganti@amd.com>
for identifying and helping to fix the issue.

AMD-Internal: [CPUPL-3352]
Change-Id: I02530ab197ed84c96cbad4f7dd56eedca0109c35
This commit is contained in:
Mangala V
2023-05-15 23:58:24 +05:30
parent 2c4f032e0f
commit 5f5bc24989
5 changed files with 169 additions and 145 deletions

View File

@@ -432,41 +432,41 @@ void dgemm_blis_impl
double* c, const f77_int* ldc
)
{
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
/* Initialize BLIS. */
bli_init_auto();
/* Initialize BLIS. */
bli_init_auto();
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm)
(
MKSTR(d),
MKSTR(gemm),
transa,
transb,
m,
n,
k,
lda,
ldb,
ldc
);
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm)
(
MKSTR(d),
MKSTR(gemm),
transa,
transb,
m,
n,
k,
lda,
ldb,
ldc
);
/* Quick return if possible. */
if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0))
{
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/* Quick return if possible. */
if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0))
{
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/* If alpha is zero scale C by beta and return early. */
if( PASTEMAC(d,eq0)( *alpha ))
@@ -494,7 +494,7 @@ void dgemm_blis_impl
return;
}
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans(*transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(*transb, &blis_transb);
@@ -564,92 +564,92 @@ void dgemm_blis_impl
if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb))
{
bli_dgemm_ref_k1_nn( m0, n0, k0,
(double*)alpha,
(double*)a, *lda,
(double*)b, *ldb,
(double*)beta,
c, *ldc
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
bli_dgemm_8x6_avx2_k1_nn( m0, n0, k0,
(double*)alpha,
(double*)a, *lda,
(double*)b, *ldb,
(double*)beta,
c, *ldc
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
if (n0 == 1)
{
if (bli_is_notrans(blis_transa))
{
bli_dgemv_unf_var2(
BLIS_NO_TRANSPOSE,
bli_extract_conj(blis_transb),
m0, k0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
(double*)beta,
c, rs_c,
((void*)0)
);
}
else
{
bli_dgemv_unf_var1(
blis_transa,
bli_extract_conj(blis_transb),
k0, m0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
(double*)beta,
c, rs_c,
((void*)0)
);
}
if (bli_is_notrans(blis_transa))
{
bli_dgemv_unf_var2(
BLIS_NO_TRANSPOSE,
bli_extract_conj(blis_transb),
m0, k0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
(double*)beta,
c, rs_c,
((void*)0)
);
}
else
{
bli_dgemv_unf_var1(
blis_transa,
bli_extract_conj(blis_transb),
k0, m0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
(double*)beta,
c, rs_c,
((void*)0)
);
}
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
else if (m0 == 1)
{
if (bli_is_notrans(blis_transb))
{
bli_dgemv_unf_var1(
blis_transb,
bli_extract_conj(blis_transa),
n0, k0,
(double*)alpha,
(double*)b, cs_b, rs_b,
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
(double*)beta,
c, cs_c,
((void*)0)
);
}
else
{
bli_dgemv_unf_var2(
blis_transb,
bli_extract_conj(blis_transa),
k0, n0,
(double*)alpha,
(double*)b, cs_b, rs_b,
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
(double*)beta,
c, cs_c,
((void*)0)
);
}
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
if (bli_is_notrans(blis_transb))
{
bli_dgemv_unf_var1(
blis_transb,
bli_extract_conj(blis_transa),
n0, k0,
(double*)alpha,
(double*)b, cs_b, rs_b,
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
(double*)beta,
c, cs_c,
((void*)0)
);
}
else
{
bli_dgemv_unf_var2(
blis_transb,
bli_extract_conj(blis_transa),
k0, n0,
(double*)alpha,
(double*)b, cs_b, rs_b,
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
(double*)beta,
c, cs_c,
((void*)0)
);
}
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS */
bli_finalize_auto();
return;
}
const num_t dt = BLIS_DOUBLE;
@@ -687,26 +687,26 @@ void dgemm_blis_impl
if (is_parallel)
#endif
{
// Will call parallelized dgemm code - sup & native
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
(
&alphao,
&ao,
&bo,
&betao,
&co,
NULL,
NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
// Will call parallelized dgemm code - sup & native
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
(
&alphao,
&ao,
&bo,
&betao,
&co,
NULL,
NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
// The code below will be called when number of threads = 1.
// The code below will be called when number of threads = 1.
#ifdef BLIS_ENABLE_SMALL_MATRIX
@@ -813,18 +813,18 @@ void zgemm_blis_impl
dcomplex* c, const f77_int* ldc
)
{
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
/* Initialize BLIS. */
bli_init_auto();
/* Initialize BLIS. */
bli_init_auto();
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k,
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k,
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
/* Perform BLAS parameter checking. */
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm)
(
MKSTR(z),
@@ -924,6 +924,30 @@ void zgemm_blis_impl
//dim_t nt = bli_thread_get_num_threads(); // get number of threads
bool is_parallel = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked.
// This function is invoked on all architectures including 'generic'.
// Non-AVX2+FMA3 platforms will use the kernels derived from the context.
if (bli_cpuid_is_avx2fma3_supported() == FALSE)
{
// Will call parallelized zgemm code - sup & native
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
(
&alphao,
&ao,
&bo,
&betao,
&co,
NULL,
NULL
);
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
/*
Invoking the API for input sizes with k=1.
- For single thread, the API has no constraints before invoking.
@@ -933,7 +957,7 @@ void zgemm_blis_impl
&& bli_is_notrans(blis_transa)
&& bli_is_notrans(blis_transb))
{
bli_zgemm_ref_k1_nn( m0, n0, k0,
bli_zgemm_4x6_avx2_k1_nn( m0, n0, k0,
(dcomplex*)alpha,
(dcomplex*)a, *lda,
(dcomplex*)b, *ldb,

View File

@@ -1,11 +1,11 @@
##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.##
##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved.##
add_library(zen_3
OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_ref_k1.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_avx2_k1.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_avx2_k1.c
)
target_compile_options(zen_3 PRIVATE /arch:AVX2)

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -40,7 +40,7 @@
#define D_MR 8
#define D_NR 6
void bli_dgemm_ref_k1_nn
void bli_dgemm_8x6_avx2_k1_nn
(
dim_t m,
dim_t n,

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -107,7 +107,7 @@
NEG_PERM_M_FRINGE(rin_0,rn); \
rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \
void bli_zgemm_ref_k1_nn
void bli_zgemm_4x6_avx2_k1_nn
(
dim_t m,
dim_t n,

View File

@@ -303,7 +303,7 @@ err_t bli_zgemm_small_At
cntl_t* cntl
);
void bli_dgemm_ref_k1_nn
void bli_dgemm_8x6_avx2_k1_nn
(
dim_t m,
dim_t n,
@@ -315,7 +315,7 @@ void bli_dgemm_ref_k1_nn
double* c, const inc_t ldc
);
void bli_zgemm_ref_k1_nn
void bli_zgemm_4x6_avx2_k1_nn
(
dim_t m,
dim_t n,