diff --git a/common.mk b/common.mk index 2da306d79..90c3da83f 100644 --- a/common.mk +++ b/common.mk @@ -1009,9 +1009,11 @@ BLIS_H_FLAT := $(BASE_INC_PATH)/$(BLIS_H) # # Isolate the path to cblas.h by filtering the file from the list of framework -# header files. +# header files, and then strip the filename to obtain the directory in which +# cblas.h resides. CBLAS_H := cblas.h CBLAS_H_SRC_PATH := $(filter %/$(CBLAS_H), $(FRAME_H99_FILES)) +CBLAS_H_DIRPATH := $(dir $(CBLAS_H_SRC_PATH)) # Construct the path to what will be the intermediate flattened/monolithic # cblas.h file. @@ -1037,7 +1039,8 @@ REF_KER_H_PATHS := $(strip $(foreach header, $(REF_KER_HEADERS), \ $(FRAME_H99_FILES))))) # Add -I to each header path so we can specify our include search paths to the -# C compiler. Then add frame/include since it's needed for bli_oapi_w[o]_cntx.h. +# C compiler. Then add frame/include since it's needed when compiling source +# files that #include bli_oapi_ba.h or bli_oapi_ex.h. REF_KER_I_PATHS := $(strip $(patsubst %, -I%, $(REF_KER_H_PATHS))) REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include @@ -1046,6 +1049,13 @@ REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include # now #include the monolithic/flattened blis.h instead. CINCFLAGS := -I$(BASE_INC_PATH) $(REF_KER_I_PATHS) +# If CBLAS is enabled, we also include the path to the cblas.h directory so +# that the compiler will be able to find cblas.h as the CBLAS source code is +# being compiled. +ifeq ($(MK_ENABLE_CBLAS),yes) +CINCFLAGS += -I$(CBLAS_H_DIRPATH) +endif + # Obtain a list of header paths in the configured sandbox. Then add -I to each # header path. CSBOXINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) diff --git a/frame/compat/bli_blas.h b/frame/compat/bli_blas.h index 1ce976453..a65953c11 100644 --- a/frame/compat/bli_blas.h +++ b/frame/compat/bli_blas.h @@ -113,6 +113,7 @@ #include "bla_amax.h" #include "bla_asum.h" #include "bla_axpy.h" +#include "bla_axpby.h" #include "bla_copy.h" #include "bla_dot.h" #include "bla_nrm2.h" @@ -199,6 +200,11 @@ #include "bla_trsm_check.h" #include "bla_gemmt_check.h" +// -- Batch prototypes -- + +#include "bla_gemm_batch.h" + + // -- Fortran-compatible APIs to BLIS functions -- #include "b77_thread.h" diff --git a/frame/compat/cblas/src/cblas.h b/frame/compat/cblas/src/cblas.h index 85e24674d..cee74233c 100644 --- a/frame/compat/cblas/src/cblas.h +++ b/frame/compat/cblas/src/cblas.h @@ -1,3 +1,4 @@ + #ifndef CBLAS_H #define CBLAS_H #include @@ -595,6 +596,62 @@ void BLIS_EXPORT_BLAS cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, void BLIS_EXPORT_BLAS cblas_xerbla(f77_int p, const char *rout, const char *form, ...); + +/* + * =========================================================================== + * BLAS Extension prototypes + * =========================================================================== + */ + +// -- APIs to operations unique to BLIS -- + +void BLIS_EXPORT_BLAS cblas_saxpby(f77_int N, float alpha, const float *X, + f77_int incX, float beta, float *Y, f77_int incY); +void BLIS_EXPORT_BLAS cblas_daxpby(f77_int N, double alpha, const double *X, + f77_int incX, double beta, double *Y, f77_int incY); +void BLIS_EXPORT_BLAS cblas_caxpby(f77_int N, const void *alpha, + const void *X, f77_int incX, const void* beta, + void *Y, f77_int incY); +void BLIS_EXPORT_BLAS cblas_zaxpby(f77_int N, const void *alpha, + const void *X, f77_int incX, const void *beta, + void *Y, f77_int incY); + +// -- Batch APIs -- + +void BLIS_EXPORT_BLAS cblas_sgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const float *alpha_array, const float **A, + f77_int *lda_array, const float **B, f77_int *ldb_array, + const float *beta_array, float **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); +void BLIS_EXPORT_BLAS cblas_dgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const double *alpha_array, + const double **A,f77_int *lda_array, + const double **B, f77_int *ldb_array, + const double *beta_array, double **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); +void BLIS_EXPORT_BLAS cblas_cgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, const void **A, + f77_int *lda_array, const void **B, f77_int *ldb_array, + const void *beta_array, void **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); +void BLIS_EXPORT_BLAS cblas_zgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, const void **A, + f77_int *lda_array, const void **B, f77_int *ldb_array, + const void *beta_array, void **C, f77_int *ldc_array, + f77_int group_count, f77_int *group_size); + #ifdef __cplusplus } #endif diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 5e94fdf2c..e534d2054 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -14,7 +14,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -200,9 +200,20 @@ /* * BLAS extensions */ -#define F77_sgemmt sgemmt_ -#define F77_dgemmt dgemmt_ -#define F77_cgemmt cgemmt_ -#define F77_zgemmt zgemmt_ +#define F77_sgemmt sgemmt_ +#define F77_dgemmt dgemmt_ +#define F77_cgemmt cgemmt_ +#define F77_zgemmt zgemmt_ + +#define F77_saxpby saxpby_ +#define F77_daxpby daxpby_ +#define F77_caxpby caxpby_ +#define F77_zaxpby zaxpby_ + +#define F77_sgemm_batch sgemm_batch_ +#define F77_dgemm_batch dgemm_batch_ +#define F77_cgemm_batch cgemm_batch_ +#define F77_zgemm_batch zgemm_batch_ + #endif /* CBLAS_F77_H */ diff --git a/frame/compat/cblas/src/cblas_sgemm.c b/frame/compat/cblas/src/cblas_sgemm.c index 89d0f07a8..bf40b9c0d 100644 --- a/frame/compat/cblas/src/cblas_sgemm.c +++ b/frame/compat/cblas/src/cblas_sgemm.c @@ -7,6 +7,8 @@ * Written by Keita Teranishi * 4/8/1998 * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * */ #include "cblas.h" @@ -17,12 +19,12 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, f77_int lda, const float *B, f77_int ldb, float beta, float *C, f77_int ldc) { - char TA, TB; + char TA, TB; #ifdef F77_CHAR F77_CHAR F77_TA, F77_TB; #else - #define F77_TA &TA - #define F77_TB &TB + #define F77_TA &TA + #define F77_TB &TB #endif #ifdef F77_INT @@ -36,7 +38,7 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, #define F77_ldb ldb #define F77_ldc ldc #endif - + extern int CBLAS_CallFromC; extern int RowMajorStrg; RowMajorStrg = 0; @@ -46,9 +48,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransA == CblasTrans) TA='T'; else if ( TransA == CblasConjTrans ) TA='C'; else if ( TransA == CblasNoTrans ) TA='N'; - else + else { - cblas_xerbla(2, "cblas_sgemm", + cblas_xerbla(2, "cblas_sgemm", "Illegal TransA setting, %d\n", TransA); CBLAS_CallFromC = 0; RowMajorStrg = 0; @@ -58,9 +60,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransB == CblasTrans) TB='T'; else if ( TransB == CblasConjTrans ) TB='C'; else if ( TransB == CblasNoTrans ) TB='N'; - else + else { - cblas_xerbla(3, "cblas_sgemm", + cblas_xerbla(3, "cblas_sgemm", "Illegal TransB setting, %d\n", TransB); CBLAS_CallFromC = 0; RowMajorStrg = 0; @@ -79,9 +81,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransA == CblasTrans) TB='T'; else if ( TransA == CblasConjTrans ) TB='C'; else if ( TransA == CblasNoTrans ) TB='N'; - else + else { - cblas_xerbla(2, "cblas_sgemm", + cblas_xerbla(2, "cblas_sgemm", "Illegal TransA setting, %d\n", TransA); CBLAS_CallFromC = 0; RowMajorStrg = 0; @@ -90,10 +92,10 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, if(TransB == CblasTrans) TA='T'; else if ( TransB == CblasConjTrans ) TA='C'; else if ( TransB == CblasNoTrans ) TA='N'; - else + else { - cblas_xerbla(2, "cblas_sgemm", - "Illegal TransA setting, %d\n", TransA); + cblas_xerbla(2, "cblas_sgemm", + "Illegal TransB setting, %d\n", TransB); CBLAS_CallFromC = 0; RowMajorStrg = 0; return; @@ -104,7 +106,7 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, #endif F77_sgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B, &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); - } else + } else cblas_xerbla(1, "cblas_sgemm", "Illegal Order setting, %d\n", Order); CBLAS_CallFromC = 0; diff --git a/frame/compat/cblas/src/extra/cblas_caxpby.c b/frame/compat/cblas/src/extra/cblas_caxpby.c new file mode 100644 index 000000000..e8400d91b --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_caxpby.c @@ -0,0 +1,27 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_caxpby.c + * + * The program is a C interface to caxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_caxpby( f77_int N, const void *alpha, + const void *X, f77_int incX, + const void *beta, + void *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_caxpby( &F77_N, (scomplex*)alpha, (scomplex*)X, &F77_incX, (scomplex*)beta, (scomplex*)Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_cgemm_batch.c b/frame/compat/cblas/src/extra/cblas_cgemm_batch.c new file mode 100644 index 000000000..18dd0bad5 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_cgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_cgemm_batch.c + * This program is a C interface to cgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, + const void **A_array, f77_int *lda_array, + const void **B_array, f77_int *ldb_array, + const void *beta_array, + void **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_cgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_cgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_cgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + (const scomplex*)alpha_array, + (const scomplex**)A_array, F77_lda, + (const scomplex**)B_array, F77_ldb, + (const scomplex*)beta_array, + (scomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_cgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_cgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_cgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + (const scomplex*)alpha_array, + (const scomplex**)B_array, F77_ldb, + (const scomplex**)A_array, F77_lda, + (const scomplex*)beta_array, + (scomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_cgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_daxpby.c b/frame/compat/cblas/src/extra/cblas_daxpby.c new file mode 100644 index 000000000..8fbea4d5a --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_daxpby.c @@ -0,0 +1,26 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_daxpby.c + * + * The program is a C interface to daxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_daxpby( f77_int N, double alpha, + const double *X, f77_int incX, + double beta, + double *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_daxpby( &F77_N, &alpha, X, &F77_incX, &beta, Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_dgemm_batch.c b/frame/compat/cblas/src/extra/cblas_dgemm_batch.c new file mode 100644 index 000000000..a2bed3b1a --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_dgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_dgemm_batch.c + * This program is a C interface to dgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const double *alpha_array, + const double **A_array, f77_int *lda_array, + const double **B_array, f77_int *ldb_array, + const double *beta_array, + double **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_dgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_dgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_dgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + alpha_array, + A_array, F77_lda, + B_array, F77_ldb, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_dgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_dgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_dgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + alpha_array, + B_array, F77_ldb, + A_array, F77_lda, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_dgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_saxpby.c b/frame/compat/cblas/src/extra/cblas_saxpby.c new file mode 100644 index 000000000..685282123 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_saxpby.c @@ -0,0 +1,28 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_saxpby.c + * + * The program is a C interface to saxpby. + * It calls the fortran wrapper before calling saxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_saxpby( f77_int N, float alpha, + const float *X, f77_int incX, + float beta, + float *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_saxpby( &F77_N, &alpha, X, &F77_incX, &beta, Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_sgemm_batch.c b/frame/compat/cblas/src/extra/cblas_sgemm_batch.c new file mode 100644 index 000000000..3e8517db2 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_sgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_sgemm_batch.c + * This program is a C interface to sgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const float *alpha_array, + const float **A_array, f77_int *lda_array, + const float **B_array, f77_int *ldb_array, + const float *beta_array, + float **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_sgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_sgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_sgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + alpha_array, + A_array, F77_lda, + B_array, F77_ldb, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_sgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_sgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_sgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + alpha_array, + B_array, F77_ldb, + A_array, F77_lda, + beta_array, + C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_sgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_zaxpby.c b/frame/compat/cblas/src/extra/cblas_zaxpby.c new file mode 100644 index 000000000..483607ec9 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_zaxpby.c @@ -0,0 +1,27 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * cblas_zaxpby.c + * + * The program is a C interface to zaxpby. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * + */ +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zaxpby( f77_int N, const void *alpha, + const void *X, f77_int incX, + const void *beta, + void *Y, f77_int incY) +{ +#ifdef F77_INT + F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; +#else + #define F77_N N + #define F77_incX incX + #define F77_incY incY +#endif + F77_zaxpby( &F77_N, (dcomplex*)alpha, (dcomplex*)X, &F77_incX, (dcomplex*)beta, (dcomplex*)Y, &F77_incY); +} +#endif diff --git a/frame/compat/cblas/src/extra/cblas_zgemm_batch.c b/frame/compat/cblas/src/extra/cblas_zgemm_batch.c new file mode 100644 index 000000000..2d188a9f0 --- /dev/null +++ b/frame/compat/cblas/src/extra/cblas_zgemm_batch.c @@ -0,0 +1,168 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_zgemm_batch.c + * This program is a C interface to zgemm_batch. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemm_batch(enum CBLAS_ORDER Order, + enum CBLAS_TRANSPOSE *TransA_array, + enum CBLAS_TRANSPOSE *TransB_array, + f77_int *M_array, f77_int *N_array, + f77_int *K_array, const void *alpha_array, + const void **A_array, f77_int *lda_array, + const void **B_array, f77_int *ldb_array, + const void *beta_array, + void **C_array, f77_int *ldc_array, + f77_int group_count, f77_int *group_size) +{ + char TA[group_count], TB[group_count]; +#ifdef F77_CHAR + F77_CHAR F77_TA[group_count], F77_TB[group_count]; +#else + #define F77_TA TA + #define F77_TB TB +#endif + +#ifdef F77_INT + F77_INT F77_GRP_COUNT = group_count; + F77_INT F77_M[F77_GRP_COUNT], F77_N[F77_GRP_COUNT], F77_K[F77_GRP_COUNT]; + F77_INT F77_lda[F77_GRP_COUNT], F77_ldb[F77_GRP_COUNT], F77_ldc[F77_GRP_COUNT]; + F77_INT F77_GRP_SIZE[F77_GRP_COUNT]; +#else + #define F77_GRP_COUNT group_count + #define F77_M M_array + #define F77_N N_array + #define F77_K K_array + #define F77_lda lda_array + #define F77_ldb ldb_array + #define F77_ldc ldc_array + #define F77_GRP_SIZE group_size +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + dim_t i; + if( Order == CblasColMajor ) + { + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TA[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_zgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB_array[i] == CblasTrans) TB[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(3, "cblas_zgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA[i] = C2F_CHAR(TA+i); + F77_TB[i] = C2F_CHAR(TB+i); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE[i] = group_size[i]; +#endif + } + + F77_zgemm_batch(F77_TA, F77_TB, + F77_M, F77_N, F77_K, + (const dcomplex*)alpha_array, + (const dcomplex**)A_array, F77_lda, + (const dcomplex**)B_array, F77_ldb, + (const dcomplex*)beta_array, + (dcomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } + else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + dim_t i; + + for(i = 0; i < group_count; i++) + { + if(TransA_array[i] == CblasTrans) TB[i]='T'; + else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C'; + else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N'; + else + { + cblas_xerbla(2, "cblas_zgemm_batch", + "Illegal TransA setting %d for group %d\n", TransA_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB_array[i] == CblasTrans) TA[i]='T'; + else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C'; + else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N'; + else + { + cblas_xerbla(2, "cblas_zgemm_batch", + "Illegal TransB setting %d for group %d\n", TransB_array[i], i); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + +#ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); +#endif + +#ifdef F77_INT + F77_M[i] = M_array[i]; + F77_N[i] = N_array[i]; + F77_K[i] = K_array[i]; + F77_lda[i] = lda_array[i]; + F77_ldb[i] = ldb_array[i]; + F77_ldc[i] = ldc_array[i]; + F77_GRP_SIZE = group_size[i]; +#endif + } + + F77_zgemm_batch(F77_TA, F77_TB, + F77_N, F77_M, F77_K, + (const dcomplex*)alpha_array, + (const dcomplex**)B_array, F77_ldb, + (const dcomplex**)A_array, F77_lda, + (const dcomplex*)beta_array, + (dcomplex**)C_array, F77_ldc, + &F77_GRP_COUNT, F77_GRP_SIZE); + } else + cblas_xerbla(1, "cblas_zgemm_batch", + "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; +} +#endif diff --git a/frame/compat/extra/bla_axpby.c b/frame/compat/extra/bla_axpby.c new file mode 100644 index 000000000..d96d75d74 --- /dev/null +++ b/frame/compat/extra/bla_axpby.c @@ -0,0 +1,89 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( axpby, axpbyv ) +#endif diff --git a/frame/compat/extra/bla_axpby.h b/frame/compat/extra/bla_axpby.h new file mode 100644 index 000000000..ab2952be9 --- /dev/null +++ b/frame/compat/extra/bla_axpby.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROT +#define GENTPROT( ftype, ch, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROT_BLAS( axpby ) +#endif + diff --git a/frame/compat/extra/bla_gemm_batch.c b/frame/compat/extra/bla_gemm_batch.c new file mode 100644 index 000000000..be84572a3 --- /dev/null +++ b/frame/compat/extra/bla_gemm_batch.c @@ -0,0 +1,254 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa_array, \ + const f77_char* transb_array, \ + const f77_int* m_array, \ + const f77_int* n_array, \ + const f77_int* k_array, \ + const ftype* alpha_array, \ + const ftype** a_array, const f77_int* lda_array, \ + const ftype** b_array, const f77_int* ldb_array, \ + const ftype* beta_array, \ + ftype** c_array, const f77_int* ldc_array, \ + const f77_int* group_count, \ + const f77_int* group_size \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + for ( f77_int gi = 0; gi < *group_count; gi++ ) \ + { \ + PASTEBLACHK(blisname) \ + ( \ + MKSTR(ch), \ + MKSTR(blisname), \ + transa_array+gi, \ + transb_array+gi, \ + m_array+gi, \ + n_array+gi, \ + k_array+gi, \ + lda_array+gi, \ + ldb_array+gi, \ + ldc_array+gi \ + ); \ + } \ +\ + f77_int idx = 0; \ +\ + for ( f77_int i = 0; i < *group_count; i++ ) \ + { \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( transa_array[i], &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( transb_array[i], &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( m_array[i], m0 ); \ + bli_convert_blas_dim1( n_array[i], n0 ); \ + bli_convert_blas_dim1( k_array[i], k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = lda_array[i]; \ + rs_b = 1; \ + cs_b = ldb_array[i]; \ + rs_c = 1; \ + cs_c = ldc_array[i]; \ +\ + for ( f77_int j = 0; j < group_size[i]; j++ ) \ + { \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)(alpha_array + i), \ + (ftype*)*(a_array + idx), rs_a, cs_a, \ + (ftype*)*(b_array + idx), rs_b, cs_b, \ + (ftype*)(beta_array + i), \ + (ftype*)*(c_array + idx), rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + idx++; \ + } \ + } \ +\ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa_array, \ + const f77_char* transb_array, \ + const f77_int* m_array, \ + const f77_int* n_array, \ + const f77_int* k_array, \ + const ftype* alpha_array, \ + const ftype** a_array, const f77_int* lda_array, \ + const ftype** b_array, const f77_int* ldb_array, \ + const ftype* beta_array, \ + ftype** c_array, const f77_int* ldc_array, \ + const f77_int* group_count, \ + const f77_int* group_size ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + for ( f77_int gi = 0; gi < *group_count; gi++ ) \ + { \ + PASTEBLACHK(blisname) \ + ( \ + MKSTR(ch), \ + MKSTR(blisname), \ + transa_array+gi, \ + transb_array+gi, \ + m_array+gi, \ + n_array+gi, \ + k_array+gi, \ + lda_array+gi, \ + ldb_array+gi, \ + ldc_array+gi \ + ); \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + f77_int idx = 0, i, j; \ +\ + for ( i = 0; i < *group_count; i++ ) \ + { \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( transa_array[i], &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( transb_array[i], &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( m_array[i], m0 ); \ + bli_convert_blas_dim1( n_array[i], n0 ); \ + bli_convert_blas_dim1( k_array[i], k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = lda_array[i]; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = ldb_array[i]; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = ldc_array[i]; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)(alpha_array + i), &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)(beta_array + i), &betao ); \ +\ + for( j = 0; j < group_size[i]; j++ ) \ + { \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)*(a_array + idx), rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)*(b_array + idx), rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)*(c_array + idx), rs_c, cs_c, &co ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + idx++; \ + } \ + } \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( gemm_batch, gemm ) +#endif + diff --git a/frame/compat/extra/bla_gemm_batch.h b/frame/compat/extra/bla_gemm_batch.h new file mode 100644 index 000000000..f997f4b8e --- /dev/null +++ b/frame/compat/extra/bla_gemm_batch.h @@ -0,0 +1,61 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROT +#define GENTPROT( ftype, ch, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa_array, \ + const f77_char* transb_array, \ + const f77_int* m_array, \ + const f77_int* n_array, \ + const f77_int* k_array, \ + const ftype* alpha_array, \ + const ftype** a_array, const f77_int* lda_array, \ + const ftype** b_array, const f77_int* ldb_array, \ + const ftype* beta_array, \ + ftype** c_array, const f77_int* ldc_array, \ + const f77_int* group_count, \ + const f77_int* group_size \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROT_BLAS( gemm_batch ) +#endif + diff --git a/test/Makefile b/test/Makefile index bbd817f2d..ae998ccde 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. +# Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -143,9 +143,9 @@ CFLAGS += -I$(TEST_SRC_PATH) # # Define the operations we will test. -TEST_OPS := dotv axpyv \ +TEST_OPS := dotv axpyv axpbyv\ gemv ger hemv her her2 trmv trsv \ - gemm hemm herk her2k trmm trsm + gemm gemm_batch hemm herk her2k trmm trsm # Optionally test gemmt, which some libraries might not implement. ifeq ($(BUILD_GEMMT),yes) diff --git a/test/test_axpbyv.c b/test/test_axpbyv.c new file mode 100644 index 000000000..28be2542c --- /dev/null +++ b/test/test_axpbyv.c @@ -0,0 +1,293 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define PRINT +#ifdef BLIS_ENABLE_CBLAS +//#define CHECK_CBLAS +#endif + +#ifdef CHECK_CBLAS +#include "cblas.h" +#endif + +/* + * BLIS interface API will be called by default. + * To call BLAS API, modify line 159 to '#if 0'. + * To call cblas API, modify line 159 to '#if 0'and define the + * macro 'CHECK_CBLAS' in line 44 + * + *Sample prototype for BLAS interface API is as follows: + * n alpha x incx beta y incy + *void daxpbyv_( int*, double*, double*, int*, double*, double*, int* ); + */ + +int main( int argc, char** argv ) +{ + obj_t x, y; + obj_t y_save; + obj_t alpha, beta; + dim_t n; + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input; + num_t dt_x, dt_y; + num_t dt_alpha, dt_beta; + int r, n_repeats; + num_t dt; + + double dtime; + double dtime_save; + double gflops; + + bli_init(); + + n_repeats = 3; + +#ifndef PRINT + p_begin = 40; + p_end = 4000; + p_inc = 40; + + n_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + n_input = 15; +#endif + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + + dt_x = dt_y = dt_alpha = dt_beta = dt; + + // Begin with initializing the last entry to zero so that + // matlab allocates space for the entire array once up-front. + for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ; +#ifdef BLIS + printf( "data_axpbyv_blis" ); +#else + printf( "data_axpbyv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )0, 0.0 ); + + //for ( p = p_begin; p <= p_end; p += p_inc ) + for ( p = p_end; p_begin <= p; p -= p_inc ) + { + + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + + bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt_beta, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt_x, n, 1, 0, 0, &x ); + bli_obj_create( dt_y, n, 1, 0, 0, &y ); + bli_obj_create( dt_y, n, 1, 0, 0, &y_save ); + + bli_randm( &x ); + bli_randm( &y ); + + bli_setsc( (0.9/1.0), 0.2, &alpha ); + bli_setsc( -(1.1/1.0), 0.3, &beta ); + + bli_copym( &y, &y_save ); + + dtime_save = 1.0e9; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &y_save, &y ); + + dtime = bli_clock(); + +#ifdef PRINT + bli_printm( "alpha", &alpha, "%4.1f", "" ); + bli_printm( "beta" , &beta, "%4.1f", "" ); + + bli_printm( "x", &x, "%4.1f", "" ); + bli_printm( "y", &y, "%4.1f", "" ); +#endif + +#ifdef BLIS + + bli_axpbyv( &alpha, + &x, + &beta, + &y ); +#else + if ( bli_is_float( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float alphap = *(( float * )bli_obj_buffer( &alpha )); + float betap = *(( float * )bli_obj_buffer( &beta )); + float* xp = bli_obj_buffer( &x ); + float* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_saxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + saxpby_( &nn, + &alphap, + xp, &incx, + &betap, + yp, &incy ); + +#endif + } + else if ( bli_is_double( dt ) ) + { + + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double alphap = *(( double * )bli_obj_buffer( &alpha )); + double betap = *(( double * )bli_obj_buffer( &beta )); + double* xp = bli_obj_buffer( &x ); + double* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_daxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + daxpby_( &nn, + &alphap, + xp, &incx, + &betap, + yp, &incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + void* alphap = bli_obj_buffer( &alpha ); + void* betap = bli_obj_buffer( &beta ); + void* xp = bli_obj_buffer( &x ); + void* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_caxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + caxpby_( &nn, + ( scomplex* )alphap, + ( scomplex* )xp, &incx, + ( scomplex* )betap, + ( scomplex* )yp, &incy ); +#endif + } + else if ( bli_is_dcomplex( dt )) + { + f77_int nn = bli_obj_length( &x ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + void* alphap = bli_obj_buffer( &alpha ); + void* betap = bli_obj_buffer( &beta ); + void* xp = bli_obj_buffer( &x ); + void* yp = bli_obj_buffer( &y ); +#ifdef CHECK_CBLAS + cblas_zaxpby( nn, + alphap, + xp, incx, + betap, + yp, incy ); +#else + zaxpby_( &nn, + ( dcomplex* )alphap, + ( dcomplex* )xp, &incx, + ( dcomplex* )betap, + ( dcomplex* )yp, &incy ); +#endif + } +#endif + +#ifdef PRINT + bli_printm( "y after", &y, "%4.1f", "" ); + exit(1); +#endif + + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 ); + +#ifdef BLIS + printf( "data_axpbyv_blis" ); +#else + printf( "data_axpbyv_%s", BLAS ); +#endif + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )n, gflops ); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + bli_obj_free( &y_save ); + } + + bli_finalize(); + + return 0; +} diff --git a/test/test_gemm_batch.c b/test/test_gemm_batch.c new file mode 100644 index 000000000..5660e4150 --- /dev/null +++ b/test/test_gemm_batch.c @@ -0,0 +1,584 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +//#define CHECK_CBLAS +#ifdef CHECK_CBLAS +#include "cblas.h" +#endif + +/* Format for FILE input + * For each input set, first line contains 'storage scheme' + * and 'group count' seperated by space. + * Following 'group_count' number of lines contains all the parameters of + * each group separated by space in each line in the following order: + * tA tB m n k lda ldb ldc alpha_r alpha_i beta_r beta_i group_size + * + * Example: + * c 2 + * n n 4 8 4 4 4 4 1.1 0.0 0.9 0.0 2 + * n n 3 3 6 3 6 3 1.0 0.0 2.0 0.0 2 + * + */ + +//#define FILE_IN_OUT +#ifndef FILE_IN_OUT +#define GRP_COUNT 2 +#endif + +//#define PRINT + +int main( int argc, char** argv ) +{ + num_t dt; + + char stor_scheme; + dim_t i, j, idx; + dim_t r, n_repeats; + + double dtime; + double dtime_save; + double gflops; + + dim_t total_count = 0; + +#if 1 + dt = BLIS_FLOAT; + //dt = BLIS_DOUBLE; +#else + dt = BLIS_SCOMPLEX; + //dt = BLIS_DCOMPLEX; +#endif + + n_repeats = 1; + +#ifdef FILE_IN_OUT + FILE* fin = NULL; + FILE* fout = NULL; + + if(argc < 3) + { + printf("Usage: ./test_gemm_batch_XX.x input.csv output.csv\n"); + exit(1); + } + + fin = fopen(argv[1], "r"); + if( fin == NULL ) + { + printf("Error opening input file %s \n", argv[1]); + exit(1); + } + + fout = fopen(argv[2], "w"); + if(fout == NULL) + { + printf("Error opening output file %s\n",argv[2]); + exit(1); + } + + dim_t GRP_COUNT; + + fprintf(fout, "m\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t grp_size\n"); + + while(fscanf(fin, "%c %ld\n", &stor_scheme, &GRP_COUNT) == 2) + { + char transa[GRP_COUNT]; + char transb[GRP_COUNT]; + + dim_t m[GRP_COUNT]; + dim_t n[GRP_COUNT]; + dim_t k[GRP_COUNT]; + + dim_t lda[GRP_COUNT]; + dim_t ldb[GRP_COUNT]; + dim_t ldc[GRP_COUNT]; + + double alpha_real[GRP_COUNT]; + double alpha_imag[GRP_COUNT]; + double beta_real[GRP_COUNT]; + double beta_imag[GRP_COUNT]; + + dim_t group_size[GRP_COUNT]; + obj_t alpha[GRP_COUNT], beta[GRP_COUNT]; + + total_count = 0; + for(i = 0; i < GRP_COUNT; i++) + { + fscanf(fin, "%c %c %ld %ld %ld %ld %ld %ld %lf %lf %lf %lf %ld\n", &transa[i], &transb[i], &m[i], &n[i], &k[i], &lda[i], &ldb[i], &ldc[i], &alpha_real[i], &alpha_imag[i], &beta_real[i], &beta_imag[i], &group_size[i]); + + total_count += group_size[i]; + } +#else + printf("m\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t grp_size\n"); + + stor_scheme = 'c'; + + dim_t m[GRP_COUNT] = {4, 3}; + dim_t n[GRP_COUNT] = {8, 3}; + dim_t k[GRP_COUNT] = {4, 6}; + + dim_t lda[GRP_COUNT] = {4, 3}; + dim_t ldb[GRP_COUNT] = {4, 6}; + dim_t ldc[GRP_COUNT] = {4, 3}; + + char transa[GRP_COUNT] = {'N', 'N'}; + char transb[GRP_COUNT] = {'N', 'N'}; + + double alpha_real[GRP_COUNT] = {1.1, 1.0}; + double alpha_imag[GRP_COUNT] = {0.0, 0.0}; + + double beta_real[GRP_COUNT] = {0.9, 2.0}; + double beta_imag[GRP_COUNT] = {0.0, 0.0}; + + dim_t group_size[GRP_COUNT] = {2,2}; + + obj_t alpha[GRP_COUNT], beta[GRP_COUNT]; + + total_count = 0; + for(i = 0; i < GRP_COUNT; i++) + total_count += group_size[i]; + +#endif + obj_t a[total_count], b[total_count]; + obj_t c[total_count], c_save[total_count]; + f77_int f77_m[GRP_COUNT], f77_n[GRP_COUNT], f77_k[GRP_COUNT]; + f77_int f77_lda[GRP_COUNT], f77_ldb[GRP_COUNT], f77_ldc[GRP_COUNT]; + f77_int f77_group_size[GRP_COUNT]; + f77_int f77_group_count = GRP_COUNT; +#ifdef CHECK_CBLAS + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa[GRP_COUNT]; + enum CBLAS_TRANSPOSE cblas_transb[GRP_COUNT]; + + if(stor_scheme == 'R' || stor_scheme == 'r') + cblas_order = CblasRowMajor; + else + cblas_order = CblasColMajor; + +#else + f77_char f77_transa[GRP_COUNT]; + f77_char f77_transb[GRP_COUNT]; + + if(stor_scheme == 'r' || stor_scheme == 'R' ) + { + printf("BLAS Interface doesn't support row-major order\n"); +#ifdef FILE_IN_OUT + continue; +#else + exit(1); +#endif + } +#endif + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + bli_obj_create(dt, 1, 1, 0, 0, &alpha[i]); + bli_obj_create(dt, 1, 1, 0, 0, &beta[i] ); + + bli_setsc(alpha_real[i], alpha_imag[i], &alpha[i]); + bli_setsc(beta_real[i], beta_imag[i], &beta[i] ); + + trans_t blis_transa, blis_transb; + if(transa[i] == 't' || transa[i] == 'T') + blis_transa = BLIS_TRANSPOSE; + else if (transa[i] == 'c' || transa[i] == 'C') + blis_transa = BLIS_CONJ_TRANSPOSE; + else if ( transa[i] == 'n' || transa[i] == 'N') + blis_transa = BLIS_NO_TRANSPOSE; + else + { + printf("Illegal transA setting %c for group %ld\n", transa[i], i); + exit(1); + } + + if(transb[i] == 't' || transb[i] == 'T') + blis_transb = BLIS_TRANSPOSE; + else if (transb[i] == 'c' || transb[i] == 'C') + blis_transb = BLIS_CONJ_TRANSPOSE; + else if (transb[i] == 'n' || transb[i] == 'N') + blis_transb = BLIS_NO_TRANSPOSE; + else + { + printf("Illegal transB setting %c for group %ld\n", transb[i], i); + exit(1); + } +#ifdef CHECK_CBLAS + if(bli_is_trans( blis_transa )) + cblas_transa[i] = CblasTrans; + else if (bli_is_conjtrans( blis_transa )) + cblas_transa[i] = CblasConjTrans; + else + cblas_transa[i] = CblasNoTrans; + + if(bli_is_trans( blis_transb )) + cblas_transb[i] = CblasTrans; + else if (bli_is_conjtrans( blis_transb )) + cblas_transb[i] = CblasConjTrans; + else + cblas_transb[i] = CblasNoTrans; +#else + bli_param_map_blis_to_netlib_trans( blis_transa, &f77_transa[i]); + bli_param_map_blis_to_netlib_trans( blis_transb, &f77_transb[i]); + +#endif + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + bli_set_dims_with_trans( blis_transa, m[i], k[i], &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k[i], n[i], &m0_b, &n0_b ); + if(stor_scheme == 'C' || stor_scheme == 'c') + { + for(j = 0; j < group_size[i]; j++) + { + bli_obj_create(dt, m0_a, n0_a, 1, lda[i], &a[idx]); + bli_obj_create(dt, m0_b, n0_b, 1, ldb[i], &b[idx]); + bli_obj_create(dt, m[i], n[i], 1, ldc[i], &c[idx]); + bli_obj_create(dt, m[i], n[i], 1, ldc[i], &c_save[idx]); + + bli_randm( &a[idx] ); + bli_randm( &b[idx] ); + bli_randm( &c[idx] ); + + bli_obj_set_conjtrans(blis_transa, &a[idx]); + bli_obj_set_conjtrans(blis_transb, &b[idx]); + idx++; + } + } + else if(stor_scheme == 'R' || stor_scheme == 'r') + { + for(j = 0; j < group_size[i]; j++) + { + bli_obj_create(dt, m0_a, n0_a, lda[i], 1, &a[idx]); + bli_obj_create(dt, m0_b, n0_b, ldb[i], 1, &b[idx]); + bli_obj_create(dt, m[i], n[i], ldc[i], 1, &c[idx]); + bli_obj_create(dt, m[i], n[i], ldc[i], 1, &c_save[idx]); + + bli_randm( &a[idx] ); + bli_randm( &b[idx] ); + bli_randm( &c[idx] ); + + bli_obj_set_conjtrans(blis_transa, &a[idx]); + bli_obj_set_conjtrans(blis_transb, &b[idx]); + idx++; + } + } + f77_m[i] = m[i]; + f77_n[i] = n[i]; + f77_k[i] = k[i]; + f77_lda[i] = lda[i]; + f77_ldb[i] = ldb[i]; + f77_ldc[i] = ldc[i]; + f77_group_size[i] = group_size[i]; + + } + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + bli_copym(&c[idx], &c_save[idx]); + idx++; + } + + dtime_save = DBL_MAX; + + for( r = 0; r < n_repeats; ++r ) + { + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + bli_copym( &c_save[idx], &c[idx]); + idx++; + } + + dtime = bli_clock(); + +#ifdef PRINT + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + printf("Group: %ld Member: %ld\n", i, j); + + bli_printm("a", &a[idx], "%4.1f", ""); + bli_printm("b", &b[idx], "%4.1f", ""); + bli_printm("c", &c[idx], "%4.1f", ""); + + idx++; + } +#endif + + if(bli_is_float(dt)) + { + const float *ap[total_count], *bp[total_count]; + float *cp[total_count]; + float alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(float*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(float*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } + +#ifdef CHECK_CBLAS + cblas_sgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + sgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + + } + else if(bli_is_double(dt)) + { + const double *ap[total_count], *bp[total_count]; + double *cp[total_count]; + double alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(double*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(double*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } +#ifdef CHECK_CBLAS + cblas_dgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + dgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + + } + else if(bli_is_scomplex(dt)) + { + const scomplex *ap[total_count], *bp[total_count]; + scomplex *cp[total_count]; + scomplex alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(scomplex*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(scomplex*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } +#ifdef CHECK_CBLAS + cblas_cgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + (const void*)alphap, + (const void**)ap, f77_lda, + (const void**)bp, f77_ldb, + (const void*)betap, (void**)cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + cgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + } + else if(bli_is_dcomplex(dt)) + { + const dcomplex *ap[total_count], *bp[total_count]; + dcomplex *cp[total_count]; + dcomplex alphap[GRP_COUNT], betap[GRP_COUNT]; + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + for(j = 0; j < group_size[i]; j++) + { + ap[idx] = bli_obj_buffer( &a[idx] ); + bp[idx] = bli_obj_buffer( &b[idx] ); + cp[idx] = bli_obj_buffer( &c[idx] ); + + idx++; + } + alphap[i] = *(dcomplex*)bli_obj_buffer_for_1x1(dt, &alpha[i]); + betap[i] = *(dcomplex*)bli_obj_buffer_for_1x1(dt, &beta[i] ); + } + +#ifdef CHECK_CBLAS + cblas_zgemm_batch( cblas_order, + cblas_transa, + cblas_transb, + f77_m, f77_n, f77_k, + (const void*)alphap, + (const void**)ap, f77_lda, + (const void**)bp, f77_ldb, + (const void*)betap, (void**)cp, f77_ldc, + f77_group_count, + f77_group_size + ); +#else + zgemm_batch_( f77_transa, + f77_transb, + f77_m, f77_n, f77_k, + alphap, ap, f77_lda, + bp, f77_ldb, + betap, cp, f77_ldc, + &f77_group_count, + f77_group_size + ); +#endif + } +#ifdef PRINT + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + for(j = 0; j < group_size[i]; j++) + { + printf("Group: %ld Member: %ld\n", i, j); + bli_printm("c after", &c[idx], "%4.1f", ""); + + idx++; + } +#endif + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + dim_t fp_ops = 0; + for(i = 0; i < GRP_COUNT; i++) + fp_ops += 2.0 * m[i] * k[i] * n[i] * group_size[i]; + + gflops = fp_ops / (dtime_save * 1.0e9 ); + + if(bli_is_complex( dt ) ) gflops *= 4.0; + +#ifdef FILE_IN_OUT + fprintf(fout, "Stor_scheme = %c, group_count = %lu, gflops = %7.2f\n", stor_scheme, GRP_COUNT, gflops); + for(i = 0; i < GRP_COUNT; i++) + fprintf(fout, "%4lu \t %4lu\t %4lu\t %4lu\t %4lu\t %4lu\t %c\t %c\t %4lu\n", m[i], n[i], k[i], lda[i], ldb[i], ldc[i], transa[i], transb[i], group_size[i]); + + fflush(fout); +#else + printf( "Stor_scheme = %c, group_count = %d, gflops = %7.2f\n", stor_scheme, GRP_COUNT, gflops); + for(i = 0; i < GRP_COUNT; i++) + printf("%4lu \t %4lu\t %4lu\t %4lu\t %4lu\t %4lu\t %c\t %c\t %4lu\n", m[i], n[i], k[i], lda[i], ldb[i], ldc[i], transa[i], transb[i], group_size[i]); + +#endif + + idx = 0; + for(i = 0; i < GRP_COUNT; i++) + { + bli_obj_free( &alpha[i]); + bli_obj_free( &beta[i] ); + + for(j = 0; j < group_size[i]; j++ ) + { + bli_obj_free( &a[idx]); + bli_obj_free( &b[idx]); + bli_obj_free( &c[idx]); + bli_obj_free( &c_save[idx]); + + idx++; + } + } +#ifdef FILE_IN_OUT + } + fclose(fin); + fclose(fout); +#endif + return 0; +} +