mirror of
https://github.com/amd/blis.git
synced 2026-05-12 18:15:37 +00:00
Bug fix: AVX2 code being invoked on non-avx2 machine for ZGEMM API
Prevented calling avx2 based bli_zgemm_ref_k1_nn code on non-supported systems. Changed the name of the function bli_zgemm_ref_k1_nn to bli_zgemm_4x6_avx2_k1_nn(). Changed the name of the function bli_dgemm_ref_k1_nn to bli_dgemm_8x6_avx2_k1_nn(). Thanks to Kiran Varaganti <Kiran.Varaganti@amd.com> for identifying and helping to fix the issue. AMD-Internal: [CPUPL-3352] Change-Id: I02530ab197ed84c96cbad4f7dd56eedca0109c35
This commit is contained in:
@@ -432,41 +432,41 @@ void dgemm_blis_impl
|
||||
double* c, const f77_int* ldc
|
||||
)
|
||||
{
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
dim_t m0, n0, k0;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
dim_t m0, n0, k0;
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
|
||||
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \
|
||||
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
|
||||
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \
|
||||
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
|
||||
|
||||
/* Perform BLAS parameter checking. */
|
||||
PASTEBLACHK(gemm)
|
||||
(
|
||||
MKSTR(d),
|
||||
MKSTR(gemm),
|
||||
transa,
|
||||
transb,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
lda,
|
||||
ldb,
|
||||
ldc
|
||||
);
|
||||
/* Perform BLAS parameter checking. */
|
||||
PASTEBLACHK(gemm)
|
||||
(
|
||||
MKSTR(d),
|
||||
MKSTR(gemm),
|
||||
transa,
|
||||
transb,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
lda,
|
||||
ldb,
|
||||
ldc
|
||||
);
|
||||
|
||||
/* Quick return if possible. */
|
||||
if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0))
|
||||
{
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
/* Quick return if possible. */
|
||||
if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0))
|
||||
{
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
/* If alpha is zero scale C by beta and return early. */
|
||||
if( PASTEMAC(d,eq0)( *alpha ))
|
||||
@@ -494,7 +494,7 @@ void dgemm_blis_impl
|
||||
return;
|
||||
}
|
||||
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
|
||||
bli_param_map_netlib_to_blis_trans(*transa, &blis_transa);
|
||||
bli_param_map_netlib_to_blis_trans(*transb, &blis_transb);
|
||||
|
||||
@@ -564,92 +564,92 @@ void dgemm_blis_impl
|
||||
|
||||
if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb))
|
||||
{
|
||||
bli_dgemm_ref_k1_nn( m0, n0, k0,
|
||||
(double*)alpha,
|
||||
(double*)a, *lda,
|
||||
(double*)b, *ldb,
|
||||
(double*)beta,
|
||||
c, *ldc
|
||||
);
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
bli_dgemm_8x6_avx2_k1_nn( m0, n0, k0,
|
||||
(double*)alpha,
|
||||
(double*)a, *lda,
|
||||
(double*)b, *ldb,
|
||||
(double*)beta,
|
||||
c, *ldc
|
||||
);
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
if (n0 == 1)
|
||||
{
|
||||
if (bli_is_notrans(blis_transa))
|
||||
{
|
||||
bli_dgemv_unf_var2(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
bli_extract_conj(blis_transb),
|
||||
m0, k0,
|
||||
(double*)alpha,
|
||||
(double*)a, rs_a, cs_a,
|
||||
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
|
||||
(double*)beta,
|
||||
c, rs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_dgemv_unf_var1(
|
||||
blis_transa,
|
||||
bli_extract_conj(blis_transb),
|
||||
k0, m0,
|
||||
(double*)alpha,
|
||||
(double*)a, rs_a, cs_a,
|
||||
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
|
||||
(double*)beta,
|
||||
c, rs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
if (bli_is_notrans(blis_transa))
|
||||
{
|
||||
bli_dgemv_unf_var2(
|
||||
BLIS_NO_TRANSPOSE,
|
||||
bli_extract_conj(blis_transb),
|
||||
m0, k0,
|
||||
(double*)alpha,
|
||||
(double*)a, rs_a, cs_a,
|
||||
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
|
||||
(double*)beta,
|
||||
c, rs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_dgemv_unf_var1(
|
||||
blis_transa,
|
||||
bli_extract_conj(blis_transb),
|
||||
k0, m0,
|
||||
(double*)alpha,
|
||||
(double*)a, rs_a, cs_a,
|
||||
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
|
||||
(double*)beta,
|
||||
c, rs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
else if (m0 == 1)
|
||||
{
|
||||
if (bli_is_notrans(blis_transb))
|
||||
{
|
||||
bli_dgemv_unf_var1(
|
||||
blis_transb,
|
||||
bli_extract_conj(blis_transa),
|
||||
n0, k0,
|
||||
(double*)alpha,
|
||||
(double*)b, cs_b, rs_b,
|
||||
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
|
||||
(double*)beta,
|
||||
c, cs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_dgemv_unf_var2(
|
||||
blis_transb,
|
||||
bli_extract_conj(blis_transa),
|
||||
k0, n0,
|
||||
(double*)alpha,
|
||||
(double*)b, cs_b, rs_b,
|
||||
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
|
||||
(double*)beta,
|
||||
c, cs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
if (bli_is_notrans(blis_transb))
|
||||
{
|
||||
bli_dgemv_unf_var1(
|
||||
blis_transb,
|
||||
bli_extract_conj(blis_transa),
|
||||
n0, k0,
|
||||
(double*)alpha,
|
||||
(double*)b, cs_b, rs_b,
|
||||
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
|
||||
(double*)beta,
|
||||
c, cs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
bli_dgemv_unf_var2(
|
||||
blis_transb,
|
||||
bli_extract_conj(blis_transa),
|
||||
k0, n0,
|
||||
(double*)alpha,
|
||||
(double*)b, cs_b, rs_b,
|
||||
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
|
||||
(double*)beta,
|
||||
c, cs_c,
|
||||
((void*)0)
|
||||
);
|
||||
}
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
const num_t dt = BLIS_DOUBLE;
|
||||
@@ -687,26 +687,26 @@ void dgemm_blis_impl
|
||||
if (is_parallel)
|
||||
#endif
|
||||
{
|
||||
// Will call parallelized dgemm code - sup & native
|
||||
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
|
||||
(
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
&betao,
|
||||
&co,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
// Will call parallelized dgemm code - sup & native
|
||||
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
|
||||
(
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
&betao,
|
||||
&co,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
// The code below will be called when number of threads = 1.
|
||||
// The code below will be called when number of threads = 1.
|
||||
|
||||
#ifdef BLIS_ENABLE_SMALL_MATRIX
|
||||
|
||||
@@ -813,18 +813,18 @@ void zgemm_blis_impl
|
||||
dcomplex* c, const f77_int* ldc
|
||||
)
|
||||
{
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
dim_t m0, n0, k0;
|
||||
trans_t blis_transa;
|
||||
trans_t blis_transb;
|
||||
dim_t m0, n0, k0;
|
||||
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
/* Initialize BLIS. */
|
||||
bli_init_auto();
|
||||
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
|
||||
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k,
|
||||
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1)
|
||||
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k,
|
||||
(void*)alpha, *lda, *ldb, (void*)beta, *ldc);
|
||||
|
||||
/* Perform BLAS parameter checking. */
|
||||
/* Perform BLAS parameter checking. */
|
||||
PASTEBLACHK(gemm)
|
||||
(
|
||||
MKSTR(z),
|
||||
@@ -924,6 +924,30 @@ void zgemm_blis_impl
|
||||
//dim_t nt = bli_thread_get_num_threads(); // get number of threads
|
||||
bool is_parallel = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked.
|
||||
|
||||
// This function is invoked on all architectures including 'generic'.
|
||||
// Non-AVX2+FMA3 platforms will use the kernels derived from the context.
|
||||
if (bli_cpuid_is_avx2fma3_supported() == FALSE)
|
||||
{
|
||||
|
||||
// Will call parallelized zgemm code - sup & native
|
||||
PASTEMAC(gemm, BLIS_OAPI_EX_SUF)
|
||||
(
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
&betao,
|
||||
&co,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
|
||||
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
|
||||
/*
|
||||
Invoking the API for input sizes with k=1.
|
||||
- For single thread, the API has no constraints before invoking.
|
||||
@@ -933,7 +957,7 @@ void zgemm_blis_impl
|
||||
&& bli_is_notrans(blis_transa)
|
||||
&& bli_is_notrans(blis_transb))
|
||||
{
|
||||
bli_zgemm_ref_k1_nn( m0, n0, k0,
|
||||
bli_zgemm_4x6_avx2_k1_nn( m0, n0, k0,
|
||||
(dcomplex*)alpha,
|
||||
(dcomplex*)a, *lda,
|
||||
(dcomplex*)b, *ldb,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.##
|
||||
##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved.##
|
||||
|
||||
add_library(zen_3
|
||||
OBJECT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_ref_k1.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_avx2_k1.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_avx2_k1.c
|
||||
)
|
||||
target_compile_options(zen_3 PRIVATE /arch:AVX2)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -40,7 +40,7 @@
|
||||
#define D_MR 8
|
||||
#define D_NR 6
|
||||
|
||||
void bli_dgemm_ref_k1_nn
|
||||
void bli_dgemm_8x6_avx2_k1_nn
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -107,7 +107,7 @@
|
||||
NEG_PERM_M_FRINGE(rin_0,rn); \
|
||||
rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \
|
||||
|
||||
void bli_zgemm_ref_k1_nn
|
||||
void bli_zgemm_4x6_avx2_k1_nn
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
@@ -303,7 +303,7 @@ err_t bli_zgemm_small_At
|
||||
cntl_t* cntl
|
||||
);
|
||||
|
||||
void bli_dgemm_ref_k1_nn
|
||||
void bli_dgemm_8x6_avx2_k1_nn
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
@@ -315,7 +315,7 @@ void bli_dgemm_ref_k1_nn
|
||||
double* c, const inc_t ldc
|
||||
);
|
||||
|
||||
void bli_zgemm_ref_k1_nn
|
||||
void bli_zgemm_4x6_avx2_k1_nn
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
|
||||
Reference in New Issue
Block a user