Added BLAS Extension APIs: CBLAS_?GEMM3M

AMD-Internal: [CPUPL-1151]

Induced 3M1 method is enabled for CGEMM3M and ZGEMM3M

Change-Id: I8276c5018340d0a45694551f48aad5b735819eae
This commit is contained in:
bhaskarn
2020-10-28 21:23:58 +05:30
committed by Dipal M Zambare
parent 4b56cc94da
commit 711cc0ef35
15 changed files with 994 additions and 5 deletions

View File

@@ -16,6 +16,8 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c
${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm3m.c
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm3m.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemmt.c
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemmt.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c

230
frame/compat/bla_gemm3m.c Normal file
View File

@@ -0,0 +1,230 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
//
// Define BLAS-to-BLIS interfaces.
//
#ifdef BLIS_BLAS3_CALLS_TAPI
#undef GENTFUNC
#define GENTFUNC( ftype, ch, blasname, blisname ) \
\
void PASTEF77(ch,blasname) \
     ( \
       const f77_char* transa, \
       const f77_char* transb, \
       const f77_int*  m, \
       const f77_int*  n, \
       const f77_int*  k, \
       const ftype*    alpha, \
       const ftype*    a, const f77_int* lda, \
       const ftype*    b, const f77_int* ldb, \
       const ftype*    beta, \
             ftype*    c, const f77_int* ldc  \
     ) \
{ \
	/* BLIS-side counterparts of the BLAS trans flags and dimensions. */ \
	trans_t op_a, op_b; \
	dim_t   m_blis, n_blis, k_blis; \
\
	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \
\
	/* Bring the library up before touching any other BLIS routine. */ \
	bli_init_auto(); \
\
	/* Validate the BLAS arguments before any of them is dereferenced */ \
	/* below. NOTE(review): presumably this resolves to the */ \
	/* bla_<blasname>_check macro - confirm against PASTEBLACHK. */ \
	PASTEBLACHK(blasname) \
	( \
	  MKSTR(ch), MKSTR(blasname), \
	  transa, transb, \
	  m, n, k, \
	  lda, ldb, ldc \
	); \
\
	/* Translate the netlib trans characters into trans_t values. */ \
	bli_param_map_netlib_to_blis_trans( *transa, &op_a ); \
	bli_param_map_netlib_to_blis_trans( *transb, &op_b ); \
\
	/* Convert the BLAS integer dimensions into dim_t values. */ \
	bli_convert_blas_dim1( *m, m_blis ); \
	bli_convert_blas_dim1( *n, n_blis ); \
	bli_convert_blas_dim1( *k, k_blis ); \
\
	/* Every operand gets a unit row stride and its BLAS leading */ \
	/* dimension as the column stride. */ \
	const inc_t rs_a = 1, cs_a = *lda; \
	const inc_t rs_b = 1, cs_b = *ldb; \
	const inc_t rs_c = 1, cs_c = *ldc; \
\
	/* Hand the problem to the expert typed API of <blisname>; the two */ \
	/* trailing NULLs are passed through unchanged. */ \
	PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \
	( \
	  op_a, op_b, \
	  m_blis, n_blis, k_blis, \
	  (ftype*)alpha, \
	  (ftype*)a, rs_a, cs_a, \
	  (ftype*)b, rs_b, cs_b, \
	  (ftype*)beta, \
	  (ftype*)c, rs_c, cs_c, \
	  NULL, \
	  NULL  \
	); \
\
	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \
	/* Finalize BLIS. */ \
	bli_finalize_auto(); \
}
#else
#undef GENTFUNC
#define GENTFUNC( ftype, ch, blasname, blisname ) \
\
void PASTEF77(ch,blasname) \
     ( \
       const f77_char* transa, \
       const f77_char* transb, \
       const f77_int*  m, \
       const f77_int*  n, \
       const f77_int*  k, \
       const ftype*    alpha, \
       const ftype*    a, const f77_int* lda, \
       const ftype*    b, const f77_int* ldb, \
       const ftype*    beta, \
             ftype*    c, const f77_int* ldc  \
     ) \
{ \
	/* BLIS-side counterparts of the BLAS trans flags and dimensions. */ \
	trans_t op_a, op_b; \
	dim_t   m_blis, n_blis, k_blis; \
	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \
\
	/* Bring the library up before touching any other BLIS routine. */ \
	bli_init_auto(); \
\
	/* Validate the BLAS arguments before any of them is dereferenced */ \
	/* below. NOTE(review): presumably this resolves to the */ \
	/* bla_<blasname>_check macro - confirm against PASTEBLACHK. */ \
	PASTEBLACHK(blasname) \
	( \
	  MKSTR(ch), MKSTR(blasname), \
	  transa, transb, \
	  m, n, k, \
	  lda, ldb, ldc \
	); \
\
	/* Translate the netlib trans characters into trans_t values. */ \
	bli_param_map_netlib_to_blis_trans( *transa, &op_a ); \
	bli_param_map_netlib_to_blis_trans( *transb, &op_b ); \
\
	/* Convert the BLAS integer dimensions into dim_t values. */ \
	bli_convert_blas_dim1( *m, m_blis ); \
	bli_convert_blas_dim1( *n, n_blis ); \
	bli_convert_blas_dim1( *k, k_blis ); \
\
	/* Every operand gets a unit row stride and its BLAS leading */ \
	/* dimension as the column stride. */ \
	const inc_t rs_a = 1, cs_a = *lda; \
	const inc_t rs_b = 1, cs_b = *ldb; \
	const inc_t rs_c = 1, cs_c = *ldc; \
\
	const num_t dt = PASTEMAC(ch,type); \
\
	/* Stored (untransposed) shapes of A and B. */ \
	dim_t rows_a, cols_a; \
	dim_t rows_b, cols_b; \
\
	bli_set_dims_with_trans( op_a, m_blis, k_blis, &rows_a, &cols_a ); \
	bli_set_dims_with_trans( op_b, k_blis, n_blis, &rows_b, &cols_b ); \
\
	/* Wrap the two scalars in 1x1 objects. */ \
	obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \
	obj_t betao  = BLIS_OBJECT_INITIALIZER_1X1; \
\
	bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \
	bli_obj_init_finish_1x1( dt, (ftype*)beta,  &betao  ); \
\
	/* Wrap the caller's buffers in matrix objects. */ \
	obj_t ao = BLIS_OBJECT_INITIALIZER; \
	obj_t bo = BLIS_OBJECT_INITIALIZER; \
	obj_t co = BLIS_OBJECT_INITIALIZER; \
\
	bli_obj_init_finish( dt, rows_a, cols_a, (ftype*)a, rs_a, cs_a, &ao ); \
	bli_obj_init_finish( dt, rows_b, cols_b, (ftype*)b, rs_b, cs_b, &bo ); \
	bli_obj_init_finish( dt, m_blis, n_blis, (ftype*)c, rs_c, cs_c, &co ); \
\
	/* Record the requested trans/conj operation on A and B. */ \
	bli_obj_set_conjtrans( op_a, &ao ); \
	bli_obj_set_conjtrans( op_b, &bo ); \
\
	/* Call the "<blisname>ind" object-API entry point; the two */ \
	/* trailing NULLs are passed through unchanged. */ \
	PASTEMAC(blisname,ind) \
	( \
	  &alphao, \
	  &ao, \
	  &bo, \
	  &betao, \
	  &co, \
	  NULL, \
	  NULL  \
	); \
\
	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \
	/* Finalize BLIS. */ \
	bli_finalize_auto(); \
}
#endif
// Instantiate the gemm3m wrapper defined by GENTFUNC above, using "gemm" as
// the BLIS operation name. The wrappers are only generated when the BLAS
// compatibility layer is enabled.
// NOTE(review): the _CZ suffix presumably restricts the expansion to the
// scomplex (c) and dcomplex (z) types - confirm against the macro definition.
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTFUNC_BLAS_CZ( gemm3m, gemm )
#endif

59
frame/compat/bla_gemm3m.h Normal file
View File

@@ -0,0 +1,59 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
//
// Prototype BLAS-to-BLIS interfaces.
//
// Each expansion of GENTPROT declares one exported gemm3m entry point for the
// element type ftype. Every argument is passed by pointer; only c is
// non-const.
#undef GENTPROT
#define GENTPROT( ftype, ch, blasname ) \
\
BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \
( \
const f77_char* transa, \
const f77_char* transb, \
const f77_int* m, \
const f77_int* n, \
const f77_int* k, \
const ftype* alpha, \
const ftype* a, const f77_int* lda, \
const ftype* b, const f77_int* ldb, \
const ftype* beta, \
ftype* c, const f77_int* ldc \
);
// NOTE(review): bla_gemm3m.c instantiates the definitions with
// INSERT_GENTFUNC_BLAS_CZ, whereas the prototypes here use
// INSERT_GENTPROT_BLAS - confirm that the set of declared symbols matches the
// set of defined ones.
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTPROT_BLAS( gemm3m )
#endif

View File

@@ -197,7 +197,8 @@
// -- Batch Extension prototypes --
#include "bla_gemm_batch.h"
#include "bla_gemm3m.h"
#include "bla_gemm3m_check.h"
// -- Fortran-compatible APIs to BLIS functions --

View File

@@ -164,4 +164,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/cblas_saxpby.c
${CMAKE_CURRENT_SOURCE_DIR}/cblas_daxpby.c
${CMAKE_CURRENT_SOURCE_DIR}/cblas_caxpby.c
${CMAKE_CURRENT_SOURCE_DIR}/cblas_zaxpby.c
${CMAKE_CURRENT_SOURCE_DIR}/cblas_cgemm3m.c
${CMAKE_CURRENT_SOURCE_DIR}/cblas_zgemm3m.c
)

View File

@@ -685,6 +685,16 @@ void BLIS_EXPORT_BLAS cblas_zgemm_batch(enum CBLAS_ORDER Order,
f77_int *lda_array, const void **B, f77_int *ldb_array,
const void *beta_array, void **C, f77_int *ldc_array,
f77_int group_count, f77_int *group_size);
// -- GEMM3M extension APIs -------
// C interface to cgemm3m. alpha, beta, A, B and C are passed as untyped
// pointers; Order, TransA and TransB select the storage layout and the
// operation applied to A and B.
void BLIS_EXPORT_BLAS cblas_cgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N,
f77_int K, const void *alpha, const void *A,
f77_int lda, const void *B, f77_int ldb,
const void *beta, void *C, f77_int ldc);
// C interface to zgemm3m; same argument list as cblas_cgemm3m.
void BLIS_EXPORT_BLAS cblas_zgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N,
f77_int K, const void *alpha, const void *A,
f77_int lda, const void *B, f77_int ldb,
const void *beta, void *C, f77_int ldc);
// -- AMIN APIs -------
BLIS_EXPORT_BLAS f77_int cblas_isamin(f77_int N, const float *X, f77_int incX);

View File

@@ -0,0 +1,121 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
*
* cblas_cgemm3m.c
*
* This program is a C interface to cgemm3m.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_cgemm3m: C interface to the single-precision complex gemm3m
 * routine (F77_cgemm3m).
 *
 * Order selects the storage layout of A, B and C:
 *   - CblasColMajor: the arguments are forwarded unchanged.
 *   - CblasRowMajor: the Fortran-style routine is called with the operands
 *     exchanged (B first, then A), with M and N exchanged, and with each
 *     transpose flag following its operand (TransB becomes the first flag,
 *     TransA the second).
 *
 * alpha, beta, A, B and C are untyped pointers that are cast to scomplex*.
 *
 * An illegal Order, TransA or TransB value is reported through
 * cblas_xerbla() using the position of the offending argument in this
 * function's argument list (1, 2 or 3 respectively); in that case the
 * routine returns without calling F77_cgemm3m.
 *
 * The globals CBLAS_CallFromC and RowMajorStrg are set for the duration of
 * the call and reset to 0 on every exit path.
 */
void cblas_cgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
                   enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N,
                   f77_int K, const void *alpha, const void *A,
                   f77_int lda, const void *B, f77_int ldb,
                   const void *beta, void *C, f77_int ldc)
{
    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
    char TA, TB;
#ifdef F77_CHAR
    F77_CHAR F77_TA, F77_TB;
#else
    #define F77_TA &TA
    #define F77_TB &TB
#endif

#ifdef F77_INT
    F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
    F77_INT F77_ldc=ldc;
#else
    #define F77_M M
    #define F77_N N
    #define F77_K K
    #define F77_lda lda
    #define F77_ldb ldb
    #define F77_ldc ldc
#endif

    extern int CBLAS_CallFromC;
    extern int RowMajorStrg;
    RowMajorStrg = 0;
    CBLAS_CallFromC = 1;

    if( Order == CblasColMajor )
    {
        if(TransA == CblasTrans) TA='T';
        else if ( TransA == CblasConjTrans ) TA='C';
        else if ( TransA == CblasNoTrans )   TA='N';
        else
        {
            cblas_xerbla(2, "cblas_cgemm3m", "Illegal TransA setting, %d\n", TransA);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

        if(TransB == CblasTrans) TB='T';
        else if ( TransB == CblasConjTrans ) TB='C';
        else if ( TransB == CblasNoTrans )   TB='N';
        else
        {
            cblas_xerbla(3, "cblas_cgemm3m", "Illegal TransB setting, %d\n", TransB);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

#ifdef F77_CHAR
        F77_TA = C2F_CHAR(&TA);
        F77_TB = C2F_CHAR(&TB);
#endif

        F77_cgemm3m(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, (scomplex*)alpha, (scomplex*)A,
                    &F77_lda, (scomplex*)B, &F77_ldb, (scomplex*)beta, (scomplex*)C, &F77_ldc);
    } else if (Order == CblasRowMajor)
    {
        RowMajorStrg = 1;

        /* The operands are exchanged below, so TransA describes the
           second operand (TB) and TransB the first one (TA). */
        if(TransA == CblasTrans) TB='T';
        else if ( TransA == CblasConjTrans ) TB='C';
        else if ( TransA == CblasNoTrans )   TB='N';
        else
        {
            cblas_xerbla(2, "cblas_cgemm3m", "Illegal TransA setting, %d\n", TransA);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

        if(TransB == CblasTrans) TA='T';
        else if ( TransB == CblasConjTrans ) TA='C';
        else if ( TransB == CblasNoTrans )   TA='N';
        else
        {
            /* TransB is the third argument of this function, exactly as
               in the column-major branch. */
            cblas_xerbla(3, "cblas_cgemm3m", "Illegal TransB setting, %d\n", TransB);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

#ifdef F77_CHAR
        F77_TA = C2F_CHAR(&TA);
        F77_TB = C2F_CHAR(&TB);
#endif

        F77_cgemm3m(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, (scomplex*)alpha, (scomplex*)B,
                    &F77_ldb, (scomplex*)A, &F77_lda, (scomplex*)beta, (scomplex*)C, &F77_ldc);
    }
    else cblas_xerbla(1, "cblas_cgemm3m", "Illegal Order setting, %d\n", Order);

    CBLAS_CallFromC = 0;
    RowMajorStrg = 0;
    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
    return;
}
#endif

View File

@@ -186,6 +186,8 @@
#define F77_daxpby daxpby_
#define F77_caxpby caxpby_
#define F77_zaxpby zaxpby_
#define F77_cgemm3m cgemm3m_
#define F77_zgemm3m zgemm3m_
#define F77_isamin_sub isaminsub_
#define F77_idamin_sub idaminsub_

View File

@@ -0,0 +1,119 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
*
* cblas_zgemm3m.c
*
* This program is a C interface to zgemm3m.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_zgemm3m: C interface to the double-precision complex gemm3m
 * routine (F77_zgemm3m).
 *
 * Order selects the storage layout of A, B and C:
 *   - CblasColMajor: the arguments are forwarded unchanged.
 *   - CblasRowMajor: the Fortran-style routine is called with the operands
 *     exchanged (B first, then A), with M and N exchanged, and with each
 *     transpose flag following its operand (TransB becomes the first flag,
 *     TransA the second).
 *
 * alpha, beta, A, B and C are untyped pointers that are cast to dcomplex*.
 *
 * An illegal Order, TransA or TransB value is reported through
 * cblas_xerbla() using the position of the offending argument in this
 * function's argument list (1, 2 or 3 respectively); in that case the
 * routine returns without calling F77_zgemm3m.
 *
 * The globals CBLAS_CallFromC and RowMajorStrg are set for the duration of
 * the call and reset to 0 on every exit path.
 */
void cblas_zgemm3m(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
                   enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N,
                   f77_int K, const void *alpha, const void *A,
                   f77_int lda, const void *B, f77_int ldb,
                   const void *beta, void *C, f77_int ldc)
{
    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
    char TA, TB;
#ifdef F77_CHAR
    F77_CHAR F77_TA, F77_TB;
#else
    #define F77_TA &TA
    #define F77_TB &TB
#endif

#ifdef F77_INT
    F77_INT F77_M=M, F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb;
    F77_INT F77_ldc=ldc;
#else
    #define F77_M M
    #define F77_N N
    #define F77_K K
    #define F77_lda lda
    #define F77_ldb ldb
    #define F77_ldc ldc
#endif

    extern int CBLAS_CallFromC;
    extern int RowMajorStrg;
    RowMajorStrg = 0;
    CBLAS_CallFromC = 1;

    if( Order == CblasColMajor )
    {
        if(TransA == CblasTrans) TA='T';
        else if ( TransA == CblasConjTrans ) TA='C';
        else if ( TransA == CblasNoTrans )   TA='N';
        else
        {
            cblas_xerbla(2, "cblas_zgemm3m", "Illegal TransA setting, %d\n", TransA);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

        if(TransB == CblasTrans) TB='T';
        else if ( TransB == CblasConjTrans ) TB='C';
        else if ( TransB == CblasNoTrans )   TB='N';
        else
        {
            cblas_xerbla(3, "cblas_zgemm3m", "Illegal TransB setting, %d\n", TransB);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

#ifdef F77_CHAR
        F77_TA = C2F_CHAR(&TA);
        F77_TB = C2F_CHAR(&TB);
#endif

        F77_zgemm3m(F77_TA, F77_TB, &F77_M, &F77_N, &F77_K, (dcomplex*)alpha, (dcomplex*)A,
                    &F77_lda, (dcomplex*)B, &F77_ldb, (dcomplex*)beta, (dcomplex*)C, &F77_ldc);
    } else if (Order == CblasRowMajor)
    {
        RowMajorStrg = 1;

        /* The operands are exchanged below, so TransA describes the
           second operand (TB) and TransB the first one (TA). */
        if(TransA == CblasTrans) TB='T';
        else if ( TransA == CblasConjTrans ) TB='C';
        else if ( TransA == CblasNoTrans )   TB='N';
        else
        {
            cblas_xerbla(2, "cblas_zgemm3m", "Illegal TransA setting, %d\n", TransA);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

        if(TransB == CblasTrans) TA='T';
        else if ( TransB == CblasConjTrans ) TA='C';
        else if ( TransB == CblasNoTrans )   TA='N';
        else
        {
            /* TransB is the third argument of this function, exactly as
               in the column-major branch. */
            cblas_xerbla(3, "cblas_zgemm3m", "Illegal TransB setting, %d\n", TransB);
            CBLAS_CallFromC = 0;
            RowMajorStrg = 0;
            AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
            return;
        }

#ifdef F77_CHAR
        F77_TA = C2F_CHAR(&TA);
        F77_TB = C2F_CHAR(&TB);
#endif

        F77_zgemm3m(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, (dcomplex*)alpha, (dcomplex*)B,
                    &F77_ldb, (dcomplex*)A, &F77_lda, (dcomplex*)beta, (dcomplex*)C, &F77_ldc);
    }
    else cblas_xerbla(1, "cblas_zgemm3m", "Illegal Order setting, %d\n", Order);

    CBLAS_CallFromC = 0;
    RowMajorStrg = 0;
    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
    return;
}
#endif

View File

@@ -22,6 +22,7 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmm_check.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_trmv_check.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm_check.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_trsv_check.h
${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm3m_check.h
)

View File

@@ -0,0 +1,88 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifdef BLIS_ENABLE_BLAS
// bla_gemm3m_check: argument validation for the gemm3m BLAS wrappers.
//
// The macro expands to a brace-enclosed block that inspects the pointer
// arguments of the enclosing function. The first violated rule determines
// the value of info that is reported:
//    1  transa is none of "N", "C", "T" (as judged by lsame)
//    2  transb is none of "N", "C", "T"
//    3  *m < 0
//    4  *n < 0
//    5  *k < 0
//    8  *lda < max( 1, rows of A )   rows of A = *m if transa is "N", else *k
//   10  *ldb < max( 1, rows of B )   rows of B = *k if transb is "N", else *n
//   13  *ldc < max( 1, *m )
//
// On a violation the upper-cased routine name (dt_str followed by op_str)
// and info are passed to xerbla, after which the macro executes `return;`.
// It can therefore only be used inside a function returning void, and any
// statement following the macro in that function is skipped on error.
#define bla_gemm3m_check( dt_str, op_str, transa, transb, m, n, k, lda, ldb, ldc ) \
{ \
	f77_int info = 0; \
	f77_int nota,  notb; \
	f77_int conja, conjb; \
	f77_int ta,    tb; \
	f77_int nrowa, nrowb; \
\
	nota  = PASTEF770(lsame)( transa, "N", (ftnlen)1, (ftnlen)1 ); \
	notb  = PASTEF770(lsame)( transb, "N", (ftnlen)1, (ftnlen)1 ); \
	conja = PASTEF770(lsame)( transa, "C", (ftnlen)1, (ftnlen)1 ); \
	conjb = PASTEF770(lsame)( transb, "C", (ftnlen)1, (ftnlen)1 ); \
	ta    = PASTEF770(lsame)( transa, "T", (ftnlen)1, (ftnlen)1 ); \
	tb    = PASTEF770(lsame)( transb, "T", (ftnlen)1, (ftnlen)1 ); \
\
	/* Number of rows of A and B as they are stored in memory. */ \
	if ( nota ) { nrowa = *m; } \
	else        { nrowa = *k; } \
	if ( notb ) { nrowb = *k; } \
	else        { nrowb = *n; } \
\
	if      ( !nota && !conja && !ta ) \
		info = 1; \
	else if ( !notb && !conjb && !tb ) \
		info = 2; \
	else if ( *m < 0 ) \
		info = 3; \
	else if ( *n < 0 ) \
		info = 4; \
	else if ( *k < 0 ) \
		info = 5; \
	else if ( *lda < bli_max( 1, nrowa ) ) \
		info = 8; \
	else if ( *ldb < bli_max( 1, nrowb ) ) \
		info = 10; \
	else if ( *ldc < bli_max( 1, *m    ) ) \
		info = 13; \
\
	if ( info != 0 ) \
	{ \
		char func_str[ BLIS_MAX_BLAS_FUNC_STR_LENGTH ]; \
\
		/* Bounded write: the name can never overrun func_str, whatever */ \
		/* strings the caller passes for dt_str and op_str. */ \
		snprintf( func_str, sizeof( func_str ), "%s%-5s", dt_str, op_str ); \
\
		bli_string_mkupper( func_str ); \
\
		/* The routine name "?GEMM3M" is seven characters long. */ \
		PASTEF770(xerbla)( func_str, &info, (ftnlen)7 ); \
\
		return; \
	} \
}
#endif

View File

@@ -68,7 +68,7 @@ bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] =
/* c z */
/* 3mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
/* 3m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
/* 3m1 */ { {TRUE,TRUE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
/* 4mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },

View File

@@ -16,6 +16,10 @@ add_executable(TestGemm test_gemm.c)
target_link_libraries(TestGemm debug "${PROJECT_NAME}.lib")
target_link_libraries(TestGemm optimized "${PROJECT_NAME}.lib")
add_executable(TestGemm3m test_gemm3m.c)
target_link_libraries(TestGemm3m debug "${PROJECT_NAME}.lib")
target_link_libraries(TestGemm3m optimized "${PROJECT_NAME}.lib")
add_executable(TestGemmt test_gemmt.c)
target_link_libraries(TestGemmt debug "${PROJECT_NAME}.lib")
target_link_libraries(TestGemmt optimized "${PROJECT_NAME}.lib")

View File

@@ -201,7 +201,8 @@ blis: check-env \
test_her2k_blis.x \
test_trmm_blis.x \
test_trsm_blis.x \
test_gemm_batch_blis.x
test_gemm_batch_blis.x \
test_gemm3m_blis.x
openblas: check-env \
test_dotv_openblas.x \
@@ -226,7 +227,8 @@ openblas: check-env \
test_her2k_openblas.x \
test_trmm_openblas.x \
test_trsm_openblas.x \
test_gemm_batch_openblas.x
test_gemm_batch_openblas.x \
test_gemm3m_openblas.x
atlas: check-env \
test_dotv_atlas.x \
@@ -272,7 +274,8 @@ mkl: check-env \
test_her2k_mkl.x \
test_trmm_mkl.x \
test_trsm_mkl.x \
test_gemm_batch_mkl.x
test_gemm_batch_mkl.x \
test_gemm3m_mkl.x
essl: check-env \
test_dotv_essl.x \

347
test/test_gemm3m.c Normal file
View File

@@ -0,0 +1,347 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name of The University of Texas nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifdef WIN32
#include <io.h>
#else
#include <unistd.h>
#endif
#include "blis.h"
#include "cblas.h"
#define CBLAS
//#define FILE_IN_OUT
//#define PRINT
#define MATRIX_INITIALISATION
int main( int argc, char** argv )
{
obj_t a, b, c;
obj_t c_save;
obj_t alpha, beta;
dim_t m, n, k;
dim_t p;
dim_t p_begin, p_end, p_inc;
int m_input, n_input, k_input;
num_t dt;
int r, n_repeats;
trans_t transa;
trans_t transb;
f77_char f77_transa;
f77_char f77_transb;
double dtime;
double dtime_save;
double gflops;
#ifdef FILE_IN_OUT
FILE* fin = NULL;
FILE* fout = NULL;
#endif
//bli_init();
//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
n_repeats = 3;
#ifndef PRINT
p_begin = 200;
p_end = 2000;
p_inc = 100;
m_input = -1;
n_input = -1;
k_input = -1;
#else
p_begin = 16;
p_end = 16;
p_inc = 1;
m_input = 5;
k_input = 6;
n_input = 4;
#endif
dt = BLIS_SCOMPLEX;
//dt = BLIS_DCOMPLEX;
transa = BLIS_NO_TRANSPOSE;
transb = BLIS_NO_TRANSPOSE;
bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
bli_param_map_blis_to_netlib_trans( transb, &f77_transb );
// printf("BLIS Library version is : %s\n", bli_info_get_version_str());
#ifdef FILE_IN_OUT
if (argc < 3)
{
printf("Usage: ./test_gemm_XX.x input.csv output.csv\n");
exit(1);
}
fin = fopen(argv[1], "r");
if (fin == NULL)
{
printf("Error opening the file %s\n", argv[1]);
exit(1);
}
fout = fopen(argv[2], "w");
if (fout == NULL)
{
printf("Error opening output file %s\n", argv[2]);
exit(1);
}
fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\t GEMM_Algo\n");
printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\t GEMM_Algo\n");
inc_t cs_a;
inc_t cs_b;
inc_t cs_c;
while (fscanf(fin, "%lld %lld %lld %lld %lld %lld\n", &m, &k, &n, &cs_a, &cs_b, &cs_c) == 6)
{
if ((m > cs_a) || (k > cs_b) || (m > cs_c)) continue; // leading dimension should be greater than number of rows
bli_obj_create( dt, 1, 1, 0, 0, &alpha);
bli_obj_create( dt, 1, 1, 0, 0, &beta );
bli_obj_create( dt, m, k, 1, cs_a, &a );
bli_obj_create( dt, k, n, 1, cs_b, &b );
bli_obj_create( dt, m, n, 1, cs_c, &c );
bli_obj_create( dt, m, n, 1, cs_c, &c_save );
#ifdef MATRIX_INITIALISATION
bli_randm( &a );
bli_randm( &b );
bli_randm( &c );
#endif
bli_obj_set_conjtrans( transa, &a);
bli_obj_set_conjtrans( transb, &b);
//bli_setsc( 0.0, -1, &alpha );
//bli_setsc( 0.0, 1, &beta );
bli_setsc( -1, 0.0, &alpha );
bli_setsc( 1, 0.0, &beta );
#else
for ( p = p_begin; p <= p_end; p += p_inc )
{
if ( m_input < 0 ) m = p * ( dim_t )abs(m_input);
else m = ( dim_t ) m_input;
if ( n_input < 0 ) n = p * ( dim_t )abs(n_input);
else n = ( dim_t ) n_input;
if ( k_input < 0 ) k = p * ( dim_t )abs(k_input);
else k = ( dim_t ) k_input;
bli_obj_create( dt, 1, 1, 0, 0, &alpha );
bli_obj_create( dt, 1, 1, 0, 0, &beta );
bli_obj_create( dt, m, k, 0, 0, &a );
bli_obj_create( dt, k, n, 0, 0, &b );
bli_obj_create( dt, m, n, 0, 0, &c );
bli_obj_create( dt, m, n, 0, 0, &c_save );
#ifdef MATRIX_INITIALISATION
bli_randm( &a );
bli_randm( &b );
bli_randm( &c );
#endif
bli_obj_set_conjtrans( transa, &a );
bli_obj_set_conjtrans( transb, &b );
bli_setsc( (0.9/1.0), 0.2, &alpha );
bli_setsc( -(1.1/1.0), 0.3, &beta );
#endif
bli_copym( &c, &c_save );
dtime_save = DBL_MAX;
for ( r = 0; r < n_repeats; ++r )
{
bli_copym( &c_save, &c );
dtime = bli_clock();
#ifdef PRINT
bli_printm( "a", &a, "%4.1f", "" );
bli_printm( "b", &b, "%4.1f", "" );
bli_printm( "c", &c, "%4.1f", "" );
#endif
#ifndef CBLAS
if ( bli_is_scomplex( dt ) )
{
f77_int mm = bli_obj_length( &c );
f77_int kk = bli_obj_width_after_trans( &a );
f77_int nn = bli_obj_width( &c );
f77_int lda = bli_obj_col_stride( &a );
f77_int ldb = bli_obj_col_stride( &b );
f77_int ldc = bli_obj_col_stride( &c );
scomplex* alphap = bli_obj_buffer( &alpha );
scomplex* ap = bli_obj_buffer( &a );
scomplex* bp = bli_obj_buffer( &b );
scomplex* betap = bli_obj_buffer( &beta );
scomplex* cp = bli_obj_buffer( &c );
cgemm3m_( &f77_transa,
&f77_transb,
&mm,
&nn,
&kk,
alphap,
ap, &lda,
bp, &ldb,
betap,
cp, &ldc );
}
else if ( bli_is_dcomplex( dt ) )
{
f77_int mm = bli_obj_length( &c );
f77_int kk = bli_obj_width_after_trans( &a );
f77_int nn = bli_obj_width( &c );
f77_int lda = bli_obj_col_stride( &a );
f77_int ldb = bli_obj_col_stride( &b );
f77_int ldc = bli_obj_col_stride( &c );
dcomplex* alphap = bli_obj_buffer( &alpha );
dcomplex* ap = bli_obj_buffer( &a );
dcomplex* bp = bli_obj_buffer( &b );
dcomplex* betap = bli_obj_buffer( &beta );
dcomplex* cp = bli_obj_buffer( &c );
zgemm3m_( &f77_transa,
&f77_transb,
&mm,
&nn,
&kk,
alphap,
ap, &lda,
bp, &ldb,
betap,
cp, &ldc );
}
#else
if ( bli_is_scomplex( dt ) ){
scomplex* ap = bli_obj_buffer( &a );
scomplex* bp = bli_obj_buffer( &b );
scomplex* cp = bli_obj_buffer( &c );
scomplex* alphap = bli_obj_buffer( &alpha );
scomplex* betap = bli_obj_buffer( &beta );
cblas_cgemm3m( CblasColMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
(const void*)alphap,
ap, m,
bp, k,
(const void*)betap,
cp, m );
}else if (bli_is_dcomplex(dt)){
dcomplex* ap = bli_obj_buffer( &a );
dcomplex* bp = bli_obj_buffer( &b );
dcomplex* cp = bli_obj_buffer( &c );
dcomplex* alphap = bli_obj_buffer( &alpha );
dcomplex* betap = bli_obj_buffer( &beta );
cblas_zgemm3m( CblasColMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
(const void*)alphap,
ap, m,
bp, k,
(const void*)betap,
cp, m );
}
#endif
#ifdef PRINT
bli_printm( "c after", &c, "%4.6f", "" );
exit(1);
#endif
dtime_save = bli_clock_min_diff( dtime_save, dtime );
}
gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 );
gflops *= 4.0; //to represent complex ops in gflops
#ifdef BLIS
printf( "data_gemm_blis" );
#else
printf( "data_gemm_%s", BLAS );
#endif
#ifdef FILE_IN_OUT
printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \
( unsigned long )m,
( unsigned long )k,
( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops);
fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f \n", \
( unsigned long )m,
( unsigned long )k,
( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops);
fflush(fout);
#else
printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n",
( unsigned long )(p - p_begin)/p_inc + 1,
( unsigned long )m,
( unsigned long )k,
( unsigned long )n, gflops );
#endif
bli_obj_free( &alpha );
bli_obj_free( &beta );
bli_obj_free( &a );
bli_obj_free( &b );
bli_obj_free( &c );
bli_obj_free( &c_save );
}
//bli_finalize();
#ifdef FILE_IN_OUT
fclose(fin);
fclose(fout);
#endif
return 0;
}