Added BLAS/CBLAS APIs for axpby, gemm_batch. (#566)

Details:
- Expanded the BLAS compatibility layer to include support for 
  ?axpby_() and ?gemm_batch_(). The former is a straightforward
  BLAS-like interface into the axpbyv operation while the latter
  implements a batched gemm via loops over bli_?gemm(). Also
  expanded the CBLAS compatibility layer to include support for
  cblas_?axpby() and cblas_?gemm_batch(), which serve as wrappers to 
  the corresponding (new) BLAS-like APIs. Thanks to Meghana Vankadari
  for submitting these new APIs via #566.
- Fixed a long-standing bug in common.mk that for some reason never
  manifested until now. Previously, CBLAS source files were compiled
  *without* the location of cblas.h being specified via a -I flag.
  I'm not sure why this worked, but it may be due to the fact that
  the cblas.h file resided in the same directory as all of the CBLAS
  source, and perhaps compilers implicitly add a -I flag for the
  directory that corresponds to the location of the source file being
  compiled. This bug only showed up because some CBLAS-like source code
  was moved into an 'extra' subdirectory of that frame/compat/cblas/src
  directory. After moving the code, compilation for those files failed
  (because the cblas.h header file, presumably, could not be found in
  the same location). This bug was fixed within common.mk by explicitly
  adding the cblas.h directory to the list of -I flags passed to the
  compiler.
- Added test_axpbyv.c and test_gemm_batch.c files to 'test' directory,
  and updated test/Makefile to build those drivers.
- Fixed typo in error message string in cblas_sgemm.c.
This commit is contained in:
Meghana-vankadari
2021-11-12 04:16:14 +05:30
committed by GitHub
parent 28b0982ea7
commit 7bc8ab485e
20 changed files with 2226 additions and 25 deletions

View File

@@ -1009,9 +1009,11 @@ BLIS_H_FLAT := $(BASE_INC_PATH)/$(BLIS_H)
#
# Isolate the path to cblas.h by filtering the file from the list of framework
# header files.
# header files, and then strip the filename to obtain the directory in which
# cblas.h resides.
CBLAS_H := cblas.h
CBLAS_H_SRC_PATH := $(filter %/$(CBLAS_H), $(FRAME_H99_FILES))
CBLAS_H_DIRPATH := $(dir $(CBLAS_H_SRC_PATH))
# Construct the path to what will be the intermediate flattened/monolithic
# cblas.h file.
@@ -1037,7 +1039,8 @@ REF_KER_H_PATHS := $(strip $(foreach header, $(REF_KER_HEADERS), \
$(FRAME_H99_FILES)))))
# Add -I to each header path so we can specify our include search paths to the
# C compiler. Then add frame/include since it's needed for bli_oapi_w[o]_cntx.h.
# C compiler. Then add frame/include since it's needed when compiling source
# files that #include bli_oapi_ba.h or bli_oapi_ex.h.
REF_KER_I_PATHS := $(strip $(patsubst %, -I%, $(REF_KER_H_PATHS)))
REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include
@@ -1046,6 +1049,13 @@ REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include
# now #include the monolithic/flattened blis.h instead.
CINCFLAGS := -I$(BASE_INC_PATH) $(REF_KER_I_PATHS)
# If CBLAS is enabled, we also include the path to the cblas.h directory so
# that the compiler will be able to find cblas.h as the CBLAS source code is
# being compiled.
ifeq ($(MK_ENABLE_CBLAS),yes)
CINCFLAGS += -I$(CBLAS_H_DIRPATH)
endif
# Obtain a list of header paths in the configured sandbox. Then add -I to each
# header path.
CSBOXINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS)))

View File

@@ -113,6 +113,7 @@
#include "bla_amax.h"
#include "bla_asum.h"
#include "bla_axpy.h"
#include "bla_axpby.h"
#include "bla_copy.h"
#include "bla_dot.h"
#include "bla_nrm2.h"
@@ -199,6 +200,11 @@
#include "bla_trsm_check.h"
#include "bla_gemmt_check.h"
// -- Batch prototypes --
#include "bla_gemm_batch.h"
// -- Fortran-compatible APIs to BLIS functions --
#include "b77_thread.h"

View File

@@ -1,3 +1,4 @@
#ifndef CBLAS_H
#define CBLAS_H
#include <stddef.h>
@@ -595,6 +596,62 @@ void BLIS_EXPORT_BLAS cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo,
void BLIS_EXPORT_BLAS cblas_xerbla(f77_int p, const char *rout, const char *form, ...);
/*
* ===========================================================================
* BLAS Extension prototypes
* ===========================================================================
*/
// -- APIs to operations unique to BLIS --
// cblas_?axpby: C interfaces to the ?axpby routines. The real variants
// (s, d) take alpha and beta by value; the complex variants (c, z) take
// alpha, beta, X and Y as void pointers. X is read-only (const), Y is the
// writable operand, and incX/incY are the respective vector increments.
void BLIS_EXPORT_BLAS cblas_saxpby(f77_int N, float alpha, const float *X,
f77_int incX, float beta, float *Y, f77_int incY);
void BLIS_EXPORT_BLAS cblas_daxpby(f77_int N, double alpha, const double *X,
f77_int incX, double beta, double *Y, f77_int incY);
void BLIS_EXPORT_BLAS cblas_caxpby(f77_int N, const void *alpha,
const void *X, f77_int incX, const void* beta,
void *Y, f77_int incY);
void BLIS_EXPORT_BLAS cblas_zaxpby(f77_int N, const void *alpha,
const void *X, f77_int incX, const void *beta,
void *Y, f77_int incY);
// -- Batch APIs --
// cblas_?gemm_batch: C interfaces to the ?gemm_batch routines.
//   Order        - CblasColMajor or CblasRowMajor.
//   TransA_array, TransB_array, M_array, N_array, K_array, lda_array,
//   ldb_array, ldc_array, group_size
//                - the wrappers read one entry per group from each of these
//                  (indices 0 .. group_count-1).
//   alpha_array, beta_array, A, B, C
//                - forwarded to the underlying ?gemm_batch routine; the
//                  function definitions name the last three A_array, B_array
//                  and C_array.
//   group_count  - number of groups.
void BLIS_EXPORT_BLAS cblas_sgemm_batch(enum CBLAS_ORDER Order,
enum CBLAS_TRANSPOSE *TransA_array,
enum CBLAS_TRANSPOSE *TransB_array,
f77_int *M_array, f77_int *N_array,
f77_int *K_array, const float *alpha_array, const float **A,
f77_int *lda_array, const float **B, f77_int *ldb_array,
const float *beta_array, float **C, f77_int *ldc_array,
f77_int group_count, f77_int *group_size);
void BLIS_EXPORT_BLAS cblas_dgemm_batch(enum CBLAS_ORDER Order,
enum CBLAS_TRANSPOSE *TransA_array,
enum CBLAS_TRANSPOSE *TransB_array,
f77_int *M_array, f77_int *N_array,
f77_int *K_array, const double *alpha_array,
const double **A,f77_int *lda_array,
const double **B, f77_int *ldb_array,
const double *beta_array, double **C, f77_int *ldc_array,
f77_int group_count, f77_int *group_size);
// The complex variants pass scalars and matrices through void pointers.
void BLIS_EXPORT_BLAS cblas_cgemm_batch(enum CBLAS_ORDER Order,
enum CBLAS_TRANSPOSE *TransA_array,
enum CBLAS_TRANSPOSE *TransB_array,
f77_int *M_array, f77_int *N_array,
f77_int *K_array, const void *alpha_array, const void **A,
f77_int *lda_array, const void **B, f77_int *ldb_array,
const void *beta_array, void **C, f77_int *ldc_array,
f77_int group_count, f77_int *group_size);
void BLIS_EXPORT_BLAS cblas_zgemm_batch(enum CBLAS_ORDER Order,
enum CBLAS_TRANSPOSE *TransA_array,
enum CBLAS_TRANSPOSE *TransB_array,
f77_int *M_array, f77_int *N_array,
f77_int *K_array, const void *alpha_array, const void **A,
f77_int *lda_array, const void **B, f77_int *ldb_array,
const void *beta_array, void **C, f77_int *ldc_array,
f77_int group_count, f77_int *group_size);
#ifdef __cplusplus
}
#endif

View File

@@ -14,7 +14,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -200,9 +200,20 @@
/*
* BLAS extensions
*/
#define F77_sgemmt sgemmt_
#define F77_dgemmt dgemmt_
#define F77_cgemmt cgemmt_
#define F77_zgemmt zgemmt_
#define F77_sgemmt sgemmt_
#define F77_dgemmt dgemmt_
#define F77_cgemmt cgemmt_
#define F77_zgemmt zgemmt_
/* Fortran-style symbol names for the axpby and gemm_batch extension
   routines; the cblas_?axpby and cblas_?gemm_batch wrappers call through
   these macros. */
#define F77_saxpby saxpby_
#define F77_daxpby daxpby_
#define F77_caxpby caxpby_
#define F77_zaxpby zaxpby_
#define F77_sgemm_batch sgemm_batch_
#define F77_dgemm_batch dgemm_batch_
#define F77_cgemm_batch cgemm_batch_
#define F77_zgemm_batch zgemm_batch_
#endif /* CBLAS_F77_H */

View File

@@ -7,6 +7,8 @@
* Written by Keita Teranishi
* 4/8/1998
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
@@ -17,12 +19,12 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
f77_int lda, const float *B, f77_int ldb,
float beta, float *C, f77_int ldc)
{
char TA, TB;
char TA, TB;
#ifdef F77_CHAR
F77_CHAR F77_TA, F77_TB;
#else
#define F77_TA &TA
#define F77_TB &TB
#define F77_TA &TA
#define F77_TB &TB
#endif
#ifdef F77_INT
@@ -36,7 +38,7 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
#define F77_ldb ldb
#define F77_ldc ldc
#endif
extern int CBLAS_CallFromC;
extern int RowMajorStrg;
RowMajorStrg = 0;
@@ -46,9 +48,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
if(TransA == CblasTrans) TA='T';
else if ( TransA == CblasConjTrans ) TA='C';
else if ( TransA == CblasNoTrans ) TA='N';
else
else
{
cblas_xerbla(2, "cblas_sgemm",
cblas_xerbla(2, "cblas_sgemm",
"Illegal TransA setting, %d\n", TransA);
CBLAS_CallFromC = 0;
RowMajorStrg = 0;
@@ -58,9 +60,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
if(TransB == CblasTrans) TB='T';
else if ( TransB == CblasConjTrans ) TB='C';
else if ( TransB == CblasNoTrans ) TB='N';
else
else
{
cblas_xerbla(3, "cblas_sgemm",
cblas_xerbla(3, "cblas_sgemm",
"Illegal TransB setting, %d\n", TransB);
CBLAS_CallFromC = 0;
RowMajorStrg = 0;
@@ -79,9 +81,9 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
if(TransA == CblasTrans) TB='T';
else if ( TransA == CblasConjTrans ) TB='C';
else if ( TransA == CblasNoTrans ) TB='N';
else
else
{
cblas_xerbla(2, "cblas_sgemm",
cblas_xerbla(2, "cblas_sgemm",
"Illegal TransA setting, %d\n", TransA);
CBLAS_CallFromC = 0;
RowMajorStrg = 0;
@@ -90,10 +92,10 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
if(TransB == CblasTrans) TA='T';
else if ( TransB == CblasConjTrans ) TA='C';
else if ( TransB == CblasNoTrans ) TA='N';
else
else
{
cblas_xerbla(2, "cblas_sgemm",
"Illegal TransA setting, %d\n", TransA);
cblas_xerbla(2, "cblas_sgemm",
"Illegal TransB setting, %d\n", TransB);
CBLAS_CallFromC = 0;
RowMajorStrg = 0;
return;
@@ -104,7 +106,7 @@ void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA,
#endif
F77_sgemm(F77_TA, F77_TB, &F77_N, &F77_M, &F77_K, &alpha, B, &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc);
} else
} else
cblas_xerbla(1, "cblas_sgemm",
"Illegal Order setting, %d\n", Order);
CBLAS_CallFromC = 0;

View File

@@ -0,0 +1,27 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
* cblas_caxpby.c
*
* The program is a C interface to caxpby.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_caxpby: C interface to the single-precision complex axpby routine.
 *
 * Forwards N, alpha, X, incX, beta, Y and incY to F77_caxpby. alpha, X,
 * beta and Y arrive as void pointers and are reinterpreted as scomplex
 * pointers; the integer arguments are passed by address.
 *
 * When F77_INT is defined the integer arguments are first copied into
 * F77_INT locals; otherwise F77_N, F77_incX and F77_incY are plain aliases
 * for the incoming parameters.
 */
void cblas_caxpby( f77_int N, const void *alpha,
                   const void *X, f77_int incX,
                   const void *beta,
                   void *Y, f77_int incY)
{
#ifdef F77_INT
    F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
#else
    #define F77_N N
    #define F77_incX incX
    #define F77_incY incY
#endif

    /* alpha, X and beta are read-only inputs: keep their const qualifier
       through the cast instead of silently discarding it. Only Y is
       writable. */
    F77_caxpby( &F77_N,
                (const scomplex*)alpha,
                (const scomplex*)X, &F77_incX,
                (const scomplex*)beta,
                (scomplex*)Y, &F77_incY );
}
#endif

View File

@@ -0,0 +1,168 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
*
* cblas_cgemm_batch.c
* This program is a C interface to cgemm_batch.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_cgemm_batch: C interface to the single-precision complex batched
 * gemm routine (F77_cgemm_batch).
 *
 * Parameters:
 *   Order         - CblasColMajor or CblasRowMajor. Any other value is
 *                   reported through cblas_xerbla() as argument 1.
 *   TransA_array,
 *   TransB_array  - per-group transposition settings; each entry is
 *                   translated to 'N', 'T' or 'C'. An unrecognized entry is
 *                   reported through cblas_xerbla() (argument 2 for TransA,
 *                   3 for TransB) and the function returns without calling
 *                   F77_cgemm_batch.
 *   M_array, N_array, K_array, lda_array, ldb_array, ldc_array, group_size
 *                 - per-group integers, read at indices 0..group_count-1.
 *                   They are copied into F77_INT arrays when F77_INT is
 *                   defined and forwarded as-is otherwise.
 *   alpha_array, beta_array, A_array, B_array, C_array
 *                 - forwarded after being cast to scomplex pointers.
 *   group_count   - number of groups.
 *
 * For row-major storage the A and B operands (with their leading
 * dimensions and transposition codes) and the M and N arrays are exchanged
 * in the call to F77_cgemm_batch.
 *
 * The globals CBLAS_CallFromC and RowMajorStrg are reset to 0 on every
 * return path.
 */
void cblas_cgemm_batch(enum CBLAS_ORDER Order,
                       enum CBLAS_TRANSPOSE *TransA_array,
                       enum CBLAS_TRANSPOSE *TransB_array,
                       f77_int *M_array, f77_int *N_array,
                       f77_int *K_array, const void *alpha_array,
                       const void **A_array, f77_int *lda_array,
                       const void **B_array, f77_int *ldb_array,
                       const void *beta_array,
                       void **C_array, f77_int *ldc_array,
                       f77_int group_count, f77_int *group_size)
{
    /* A variable-length array must have a positive size. Size the scratch
       arrays with at least one element so that a zero or negative
       group_count cannot produce an invalid array; the loops below still
       iterate over group_count entries only. */
    const f77_int n_alloc = ( group_count > 0 ? group_count : 1 );

    char TA[n_alloc], TB[n_alloc];
#ifdef F77_CHAR
    F77_CHAR F77_TA[n_alloc], F77_TB[n_alloc];
#else
    #define F77_TA TA
    #define F77_TB TB
#endif

#ifdef F77_INT
    F77_INT F77_GRP_COUNT = group_count;
    F77_INT F77_M[n_alloc], F77_N[n_alloc], F77_K[n_alloc];
    F77_INT F77_lda[n_alloc], F77_ldb[n_alloc], F77_ldc[n_alloc];
    F77_INT F77_GRP_SIZE[n_alloc];
#else
    #define F77_GRP_COUNT group_count
    #define F77_M M_array
    #define F77_N N_array
    #define F77_K K_array
    #define F77_lda lda_array
    #define F77_ldb ldb_array
    #define F77_ldc ldc_array
    #define F77_GRP_SIZE group_size
#endif

    extern int CBLAS_CallFromC;
    extern int RowMajorStrg;
    RowMajorStrg = 0;
    CBLAS_CallFromC = 1;

    dim_t i;

    if( Order == CblasColMajor )
    {
        for(i = 0; i < group_count; i++)
        {
            if(TransA_array[i] == CblasTrans) TA[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* The group index is cast to int to match the %d
                   conversion; dim_t may be wider than int. */
                cblas_xerbla(2, "cblas_cgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TB[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(3, "cblas_cgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_cgemm_batch(F77_TA, F77_TB,
                        F77_M, F77_N, F77_K,
                        (const scomplex*)alpha_array,
                        (const scomplex**)A_array, F77_lda,
                        (const scomplex**)B_array, F77_ldb,
                        (const scomplex*)beta_array,
                        (scomplex**)C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else if (Order == CblasRowMajor)
    {
        RowMajorStrg = 1;

        for(i = 0; i < group_count; i++)
        {
            /* Row-major: A's transposition code is stored in TB and B's in
               TA, because the operands are exchanged in the call below. */
            if(TransA_array[i] == CblasTrans) TB[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(2, "cblas_cgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TA[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* TransB_array is the third argument, as in the
                   column-major branch. */
                cblas_xerbla(3, "cblas_cgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            /* Convert entry i only; F77_TA and F77_TB are arrays. */
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_cgemm_batch(F77_TA, F77_TB,
                        F77_N, F77_M, F77_K,
                        (const scomplex*)alpha_array,
                        (const scomplex**)B_array, F77_ldb,
                        (const scomplex**)A_array, F77_lda,
                        (const scomplex*)beta_array,
                        (scomplex**)C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else
    {
        cblas_xerbla(1, "cblas_cgemm_batch",
                     "Illegal Order setting, %d\n", Order);
    }

    CBLAS_CallFromC = 0;
    RowMajorStrg = 0;
}
#endif

View File

@@ -0,0 +1,26 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
* cblas_daxpby.c
*
* The program is a C interface to daxpby.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc.
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_daxpby: C interface to the double-precision real axpby routine.
 *
 * Passes N, alpha, X, incX, beta, Y and incY on to F77_daxpby. The scalars
 * alpha and beta and the integer arguments are handed over by address; X
 * and Y are forwarded unchanged.
 */
void cblas_daxpby( f77_int N, double alpha,
                   const double *X, f77_int incX,
                   double beta,
                   double *Y, f77_int incY)
{
#ifndef F77_INT
    /* No separate Fortran integer type: alias the incoming parameters. */
    #define F77_N N
    #define F77_incX incX
    #define F77_incY incY
#else
    /* Copy the integer arguments into the Fortran integer type. */
    F77_INT F77_N = N;
    F77_INT F77_incX = incX;
    F77_INT F77_incY = incY;
#endif

    F77_daxpby( &F77_N,
                &alpha,
                X, &F77_incX,
                &beta,
                Y, &F77_incY );
}
#endif

View File

@@ -0,0 +1,168 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
*
* cblas_dgemm_batch.c
* This program is a C interface to dgemm_batch.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_dgemm_batch: C interface to the double-precision real batched gemm
 * routine (F77_dgemm_batch).
 *
 * Parameters:
 *   Order         - CblasColMajor or CblasRowMajor. Any other value is
 *                   reported through cblas_xerbla() as argument 1.
 *   TransA_array,
 *   TransB_array  - per-group transposition settings; each entry is
 *                   translated to 'N', 'T' or 'C'. An unrecognized entry is
 *                   reported through cblas_xerbla() (argument 2 for TransA,
 *                   3 for TransB) and the function returns without calling
 *                   F77_dgemm_batch.
 *   M_array, N_array, K_array, lda_array, ldb_array, ldc_array, group_size
 *                 - per-group integers, read at indices 0..group_count-1.
 *                   They are copied into F77_INT arrays when F77_INT is
 *                   defined and forwarded as-is otherwise.
 *   alpha_array, beta_array, A_array, B_array, C_array
 *                 - forwarded unchanged.
 *   group_count   - number of groups.
 *
 * For row-major storage the A and B operands (with their leading
 * dimensions and transposition codes) and the M and N arrays are exchanged
 * in the call to F77_dgemm_batch.
 *
 * The globals CBLAS_CallFromC and RowMajorStrg are reset to 0 on every
 * return path.
 */
void cblas_dgemm_batch(enum CBLAS_ORDER Order,
                       enum CBLAS_TRANSPOSE *TransA_array,
                       enum CBLAS_TRANSPOSE *TransB_array,
                       f77_int *M_array, f77_int *N_array,
                       f77_int *K_array, const double *alpha_array,
                       const double **A_array, f77_int *lda_array,
                       const double **B_array, f77_int *ldb_array,
                       const double *beta_array,
                       double **C_array, f77_int *ldc_array,
                       f77_int group_count, f77_int *group_size)
{
    /* A variable-length array must have a positive size. Size the scratch
       arrays with at least one element so that a zero or negative
       group_count cannot produce an invalid array; the loops below still
       iterate over group_count entries only. */
    const f77_int n_alloc = ( group_count > 0 ? group_count : 1 );

    char TA[n_alloc], TB[n_alloc];
#ifdef F77_CHAR
    F77_CHAR F77_TA[n_alloc], F77_TB[n_alloc];
#else
    #define F77_TA TA
    #define F77_TB TB
#endif

#ifdef F77_INT
    F77_INT F77_GRP_COUNT = group_count;
    F77_INT F77_M[n_alloc], F77_N[n_alloc], F77_K[n_alloc];
    F77_INT F77_lda[n_alloc], F77_ldb[n_alloc], F77_ldc[n_alloc];
    F77_INT F77_GRP_SIZE[n_alloc];
#else
    #define F77_GRP_COUNT group_count
    #define F77_M M_array
    #define F77_N N_array
    #define F77_K K_array
    #define F77_lda lda_array
    #define F77_ldb ldb_array
    #define F77_ldc ldc_array
    #define F77_GRP_SIZE group_size
#endif

    extern int CBLAS_CallFromC;
    extern int RowMajorStrg;
    RowMajorStrg = 0;
    CBLAS_CallFromC = 1;

    dim_t i;

    if( Order == CblasColMajor )
    {
        for(i = 0; i < group_count; i++)
        {
            if(TransA_array[i] == CblasTrans) TA[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* The group index is cast to int to match the %d
                   conversion; dim_t may be wider than int. */
                cblas_xerbla(2, "cblas_dgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TB[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(3, "cblas_dgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_dgemm_batch(F77_TA, F77_TB,
                        F77_M, F77_N, F77_K,
                        alpha_array,
                        A_array, F77_lda,
                        B_array, F77_ldb,
                        beta_array,
                        C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else if (Order == CblasRowMajor)
    {
        RowMajorStrg = 1;

        for(i = 0; i < group_count; i++)
        {
            /* Row-major: A's transposition code is stored in TB and B's in
               TA, because the operands are exchanged in the call below. */
            if(TransA_array[i] == CblasTrans) TB[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(2, "cblas_dgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TA[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* TransB_array is the third argument, as in the
                   column-major branch. */
                cblas_xerbla(3, "cblas_dgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            /* Convert entry i only; F77_TA and F77_TB are arrays. */
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_dgemm_batch(F77_TA, F77_TB,
                        F77_N, F77_M, F77_K,
                        alpha_array,
                        B_array, F77_ldb,
                        A_array, F77_lda,
                        beta_array,
                        C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else
    {
        cblas_xerbla(1, "cblas_dgemm_batch",
                     "Illegal Order setting, %d\n", Order);
    }

    CBLAS_CallFromC = 0;
    RowMajorStrg = 0;
}
#endif

View File

@@ -0,0 +1,28 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
* cblas_saxpby.c
*
* The program is a C interface to saxpby.
* It forwards its arguments to the Fortran-style routine F77_saxpby.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc.
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_saxpby: C interface to the single-precision real axpby routine.
 *
 * Passes N, alpha, X, incX, beta, Y and incY on to F77_saxpby. The scalars
 * alpha and beta and the integer arguments are handed over by address; X
 * and Y are forwarded unchanged.
 */
void cblas_saxpby( f77_int N, float alpha,
                   const float *X, f77_int incX,
                   float beta,
                   float *Y, f77_int incY)
{
#ifndef F77_INT
    /* No separate Fortran integer type: alias the incoming parameters. */
    #define F77_N N
    #define F77_incX incX
    #define F77_incY incY
#else
    /* Copy the integer arguments into the Fortran integer type. */
    F77_INT F77_N = N;
    F77_INT F77_incX = incX;
    F77_INT F77_incY = incY;
#endif

    F77_saxpby( &F77_N,
                &alpha,
                X, &F77_incX,
                &beta,
                Y, &F77_incY );
}
#endif

View File

@@ -0,0 +1,168 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
*
* cblas_sgemm_batch.c
* This program is a C interface to sgemm_batch.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_sgemm_batch: C interface to the single-precision real batched gemm
 * routine (F77_sgemm_batch).
 *
 * Parameters:
 *   Order         - CblasColMajor or CblasRowMajor. Any other value is
 *                   reported through cblas_xerbla() as argument 1.
 *   TransA_array,
 *   TransB_array  - per-group transposition settings; each entry is
 *                   translated to 'N', 'T' or 'C'. An unrecognized entry is
 *                   reported through cblas_xerbla() (argument 2 for TransA,
 *                   3 for TransB) and the function returns without calling
 *                   F77_sgemm_batch.
 *   M_array, N_array, K_array, lda_array, ldb_array, ldc_array, group_size
 *                 - per-group integers, read at indices 0..group_count-1.
 *                   They are copied into F77_INT arrays when F77_INT is
 *                   defined and forwarded as-is otherwise.
 *   alpha_array, beta_array, A_array, B_array, C_array
 *                 - forwarded unchanged.
 *   group_count   - number of groups.
 *
 * For row-major storage the A and B operands (with their leading
 * dimensions and transposition codes) and the M and N arrays are exchanged
 * in the call to F77_sgemm_batch.
 *
 * The globals CBLAS_CallFromC and RowMajorStrg are reset to 0 on every
 * return path.
 */
void cblas_sgemm_batch(enum CBLAS_ORDER Order,
                       enum CBLAS_TRANSPOSE *TransA_array,
                       enum CBLAS_TRANSPOSE *TransB_array,
                       f77_int *M_array, f77_int *N_array,
                       f77_int *K_array, const float *alpha_array,
                       const float **A_array, f77_int *lda_array,
                       const float **B_array, f77_int *ldb_array,
                       const float *beta_array,
                       float **C_array, f77_int *ldc_array,
                       f77_int group_count, f77_int *group_size)
{
    /* A variable-length array must have a positive size. Size the scratch
       arrays with at least one element so that a zero or negative
       group_count cannot produce an invalid array; the loops below still
       iterate over group_count entries only. */
    const f77_int n_alloc = ( group_count > 0 ? group_count : 1 );

    char TA[n_alloc], TB[n_alloc];
#ifdef F77_CHAR
    F77_CHAR F77_TA[n_alloc], F77_TB[n_alloc];
#else
    #define F77_TA TA
    #define F77_TB TB
#endif

#ifdef F77_INT
    F77_INT F77_GRP_COUNT = group_count;
    F77_INT F77_M[n_alloc], F77_N[n_alloc], F77_K[n_alloc];
    F77_INT F77_lda[n_alloc], F77_ldb[n_alloc], F77_ldc[n_alloc];
    F77_INT F77_GRP_SIZE[n_alloc];
#else
    #define F77_GRP_COUNT group_count
    #define F77_M M_array
    #define F77_N N_array
    #define F77_K K_array
    #define F77_lda lda_array
    #define F77_ldb ldb_array
    #define F77_ldc ldc_array
    #define F77_GRP_SIZE group_size
#endif

    extern int CBLAS_CallFromC;
    extern int RowMajorStrg;
    RowMajorStrg = 0;
    CBLAS_CallFromC = 1;

    dim_t i;

    if( Order == CblasColMajor )
    {
        for(i = 0; i < group_count; i++)
        {
            if(TransA_array[i] == CblasTrans) TA[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* The group index is cast to int to match the %d
                   conversion; dim_t may be wider than int. */
                cblas_xerbla(2, "cblas_sgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TB[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(3, "cblas_sgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_sgemm_batch(F77_TA, F77_TB,
                        F77_M, F77_N, F77_K,
                        alpha_array,
                        A_array, F77_lda,
                        B_array, F77_ldb,
                        beta_array,
                        C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else if (Order == CblasRowMajor)
    {
        RowMajorStrg = 1;

        for(i = 0; i < group_count; i++)
        {
            /* Row-major: A's transposition code is stored in TB and B's in
               TA, because the operands are exchanged in the call below. */
            if(TransA_array[i] == CblasTrans) TB[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(2, "cblas_sgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TA[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* TransB_array is the third argument, as in the
                   column-major branch. */
                cblas_xerbla(3, "cblas_sgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            /* Convert entry i only; F77_TA and F77_TB are arrays. */
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_sgemm_batch(F77_TA, F77_TB,
                        F77_N, F77_M, F77_K,
                        alpha_array,
                        B_array, F77_ldb,
                        A_array, F77_lda,
                        beta_array,
                        C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else
    {
        cblas_xerbla(1, "cblas_sgemm_batch",
                     "Illegal Order setting, %d\n", Order);
    }

    CBLAS_CallFromC = 0;
    RowMajorStrg = 0;
}
#endif

View File

@@ -0,0 +1,27 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
* cblas_zaxpby.c
*
* The program is a C interface to zaxpby.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_zaxpby: C interface to the double-precision complex axpby routine.
 *
 * Forwards N, alpha, X, incX, beta, Y and incY to F77_zaxpby. alpha, X,
 * beta and Y arrive as void pointers and are reinterpreted as dcomplex
 * pointers; the integer arguments are passed by address.
 *
 * When F77_INT is defined the integer arguments are first copied into
 * F77_INT locals; otherwise F77_N, F77_incX and F77_incY are plain aliases
 * for the incoming parameters.
 */
void cblas_zaxpby( f77_int N, const void *alpha,
                   const void *X, f77_int incX,
                   const void *beta,
                   void *Y, f77_int incY)
{
#ifdef F77_INT
    F77_INT F77_N=N, F77_incX=incX, F77_incY=incY;
#else
    #define F77_N N
    #define F77_incX incX
    #define F77_incY incY
#endif

    /* alpha, X and beta are read-only inputs: keep their const qualifier
       through the cast instead of silently discarding it. Only Y is
       writable. */
    F77_zaxpby( &F77_N,
                (const dcomplex*)alpha,
                (const dcomplex*)X, &F77_incX,
                (const dcomplex*)beta,
                (dcomplex*)Y, &F77_incY );
}
#endif

View File

@@ -0,0 +1,168 @@
#include "blis.h"
#ifdef BLIS_ENABLE_CBLAS
/*
*
* cblas_zgemm_batch.c
* This program is a C interface to zgemm_batch.
*
* Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
*
*/
#include "cblas.h"
#include "cblas_f77.h"
/*
 * cblas_zgemm_batch: C interface to the double-precision complex batched
 * gemm routine (F77_zgemm_batch).
 *
 * Parameters:
 *   Order         - CblasColMajor or CblasRowMajor. Any other value is
 *                   reported through cblas_xerbla() as argument 1.
 *   TransA_array,
 *   TransB_array  - per-group transposition settings; each entry is
 *                   translated to 'N', 'T' or 'C'. An unrecognized entry is
 *                   reported through cblas_xerbla() (argument 2 for TransA,
 *                   3 for TransB) and the function returns without calling
 *                   F77_zgemm_batch.
 *   M_array, N_array, K_array, lda_array, ldb_array, ldc_array, group_size
 *                 - per-group integers, read at indices 0..group_count-1.
 *                   They are copied into F77_INT arrays when F77_INT is
 *                   defined and forwarded as-is otherwise.
 *   alpha_array, beta_array, A_array, B_array, C_array
 *                 - forwarded after being cast to dcomplex pointers.
 *   group_count   - number of groups.
 *
 * For row-major storage the A and B operands (with their leading
 * dimensions and transposition codes) and the M and N arrays are exchanged
 * in the call to F77_zgemm_batch.
 *
 * The globals CBLAS_CallFromC and RowMajorStrg are reset to 0 on every
 * return path.
 */
void cblas_zgemm_batch(enum CBLAS_ORDER Order,
                       enum CBLAS_TRANSPOSE *TransA_array,
                       enum CBLAS_TRANSPOSE *TransB_array,
                       f77_int *M_array, f77_int *N_array,
                       f77_int *K_array, const void *alpha_array,
                       const void **A_array, f77_int *lda_array,
                       const void **B_array, f77_int *ldb_array,
                       const void *beta_array,
                       void **C_array, f77_int *ldc_array,
                       f77_int group_count, f77_int *group_size)
{
    /* A variable-length array must have a positive size. Size the scratch
       arrays with at least one element so that a zero or negative
       group_count cannot produce an invalid array; the loops below still
       iterate over group_count entries only. */
    const f77_int n_alloc = ( group_count > 0 ? group_count : 1 );

    char TA[n_alloc], TB[n_alloc];
#ifdef F77_CHAR
    F77_CHAR F77_TA[n_alloc], F77_TB[n_alloc];
#else
    #define F77_TA TA
    #define F77_TB TB
#endif

#ifdef F77_INT
    F77_INT F77_GRP_COUNT = group_count;
    F77_INT F77_M[n_alloc], F77_N[n_alloc], F77_K[n_alloc];
    F77_INT F77_lda[n_alloc], F77_ldb[n_alloc], F77_ldc[n_alloc];
    F77_INT F77_GRP_SIZE[n_alloc];
#else
    #define F77_GRP_COUNT group_count
    #define F77_M M_array
    #define F77_N N_array
    #define F77_K K_array
    #define F77_lda lda_array
    #define F77_ldb ldb_array
    #define F77_ldc ldc_array
    #define F77_GRP_SIZE group_size
#endif

    extern int CBLAS_CallFromC;
    extern int RowMajorStrg;
    RowMajorStrg = 0;
    CBLAS_CallFromC = 1;

    dim_t i;

    if( Order == CblasColMajor )
    {
        for(i = 0; i < group_count; i++)
        {
            if(TransA_array[i] == CblasTrans) TA[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* The group index is cast to int to match the %d
                   conversion; dim_t may be wider than int. */
                cblas_xerbla(2, "cblas_zgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TB[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(3, "cblas_zgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_zgemm_batch(F77_TA, F77_TB,
                        F77_M, F77_N, F77_K,
                        (const dcomplex*)alpha_array,
                        (const dcomplex**)A_array, F77_lda,
                        (const dcomplex**)B_array, F77_ldb,
                        (const dcomplex*)beta_array,
                        (dcomplex**)C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else if (Order == CblasRowMajor)
    {
        RowMajorStrg = 1;

        for(i = 0; i < group_count; i++)
        {
            /* Row-major: A's transposition code is stored in TB and B's in
               TA, because the operands are exchanged in the call below. */
            if(TransA_array[i] == CblasTrans) TB[i]='T';
            else if ( TransA_array[i] == CblasConjTrans ) TB[i]='C';
            else if ( TransA_array[i] == CblasNoTrans ) TB[i]='N';
            else
            {
                cblas_xerbla(2, "cblas_zgemm_batch",
                             "Illegal TransA setting %d for group %d\n", TransA_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

            if(TransB_array[i] == CblasTrans) TA[i]='T';
            else if ( TransB_array[i] == CblasConjTrans ) TA[i]='C';
            else if ( TransB_array[i] == CblasNoTrans ) TA[i]='N';
            else
            {
                /* TransB_array is the third argument, as in the
                   column-major branch. */
                cblas_xerbla(3, "cblas_zgemm_batch",
                             "Illegal TransB setting %d for group %d\n", TransB_array[i], (int)i);
                CBLAS_CallFromC = 0;
                RowMajorStrg = 0;
                return;
            }

#ifdef F77_CHAR
            /* Convert entry i only; F77_TA and F77_TB are arrays. */
            F77_TA[i] = C2F_CHAR(TA+i);
            F77_TB[i] = C2F_CHAR(TB+i);
#endif

#ifdef F77_INT
            F77_M[i] = M_array[i];
            F77_N[i] = N_array[i];
            F77_K[i] = K_array[i];
            F77_lda[i] = lda_array[i];
            F77_ldb[i] = ldb_array[i];
            F77_ldc[i] = ldc_array[i];
            F77_GRP_SIZE[i] = group_size[i];
#endif
        }

        F77_zgemm_batch(F77_TA, F77_TB,
                        F77_N, F77_M, F77_K,
                        (const dcomplex*)alpha_array,
                        (const dcomplex**)B_array, F77_ldb,
                        (const dcomplex**)A_array, F77_lda,
                        (const dcomplex*)beta_array,
                        (dcomplex**)C_array, F77_ldc,
                        &F77_GRP_COUNT, F77_GRP_SIZE);
    }
    else
    {
        cblas_xerbla(1, "cblas_zgemm_batch",
                     "Illegal Order setting, %d\n", Order);
    }

    CBLAS_CallFromC = 0;
    RowMajorStrg = 0;
}
#endif

View File

@@ -0,0 +1,89 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
//
// Define BLAS-to-BLIS interfaces.
//
// GENTFUNC is the function template for the BLAS-compatible axpby routines.
// Each expansion defines a function named PASTEF77(ch,blasname) that takes
// all of its arguments by pointer and then:
//   1. calls bli_init_auto();
//   2. converts *n into the dim_t n0 via bli_convert_blas_dim1();
//   3. converts each (vector, increment) pair into (x0, incx0) and
//      (y0, incy0) via bli_convert_blas_incv();
//   4. calls PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) with BLIS_NO_CONJUGATE
//      and NULL for its two trailing arguments (presumably the context and
//      runtime objects -- TODO confirm against the typed expert API);
//   5. calls bli_finalize_auto().
// The template is instantiated through INSERT_GENTFUNC_BLAS( axpby, axpbyv )
// only when BLIS_ENABLE_BLAS is defined.
#undef GENTFUNC
#define GENTFUNC( ftype, ch, blasname, blisname ) \
\
void PASTEF77(ch,blasname) \
( \
const f77_int* n, \
const ftype* alpha, \
const ftype* x, const f77_int* incx, \
const ftype* beta, \
ftype* y, const f77_int* incy \
) \
{ \
dim_t n0; \
ftype* x0; \
ftype* y0; \
inc_t incx0; \
inc_t incy0; \
\
/* Initialize BLIS. */ \
bli_init_auto(); \
\
/* Convert/typecast negative values of n to zero. */ \
bli_convert_blas_dim1( *n, n0 ); \
\
/* If the input increments are negative, adjust the pointers so we can
use positive increments instead. */ \
bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \
bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \
\
/* Call BLIS interface. */ \
PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \
( \
BLIS_NO_CONJUGATE, \
n0, \
(ftype*)alpha, \
x0, incx0, \
(ftype*)beta, \
y0, incy0, \
NULL, \
NULL \
); \
\
/* Finalize BLIS. */ \
bli_finalize_auto(); \
}
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTFUNC_BLAS( axpby, axpbyv )
#endif

View File

@@ -0,0 +1,54 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
//
// Prototype BLAS-to-BLIS interfaces.
//
// Prototype template for the exported BLAS-style entry point whose name is
// formed by PASTEF77(ch,blasname). Every argument is passed by pointer:
//   n          - vector length
//   alpha      - scalar associated with x
//   x, incx    - input vector and its increment (read-only)
//   beta       - scalar associated with y
//   y, incy    - vector that is updated (the only non-const operand) and
//                its increment
// The terminating semicolon is part of the macro body.
#undef GENTPROT
#define GENTPROT( ftype, ch, blasname ) \
\
BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \
( \
const f77_int* n, \
const ftype* alpha, \
const ftype* x, const f77_int* incx, \
const ftype* beta, \
ftype* y, const f77_int* incy \
);
// Emit the axpby prototypes only when BLIS_ENABLE_BLAS is defined.
// NOTE(review): INSERT_GENTPROT_BLAS presumably expands GENTPROT once per
// supported datatype -- confirm against its definition.
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTPROT_BLAS( axpby )
#endif

View File

@@ -0,0 +1,254 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
//
// Define BLAS-to-BLIS interfaces.
//
#ifdef BLIS_BLAS3_CALLS_TAPI
/*
   Batched gemm BLAS wrapper, typed-API variant (this definition is the one
   selected when BLIS_BLAS3_CALLS_TAPI is defined).

   The batch consists of *group_count groups. Group i takes element i of
   each per-group array (transa/transb, m/n/k, alpha, lda/ldb, beta, ldc)
   and applies it to group_size[i] consecutive entries of the pointer arrays
   a_array, b_array and c_array; a running index walks those pointer arrays
   across all groups. Every matrix is described with row stride 1 and
   column stride equal to its leading dimension.

   All groups are parameter-checked before any computation is performed.
*/
#undef GENTFUNC
#define GENTFUNC( ftype, ch, blasname, blisname ) \
\
void PASTEF77(ch,blasname) \
     ( \
       const f77_char* transa_array, \
       const f77_char* transb_array, \
       const f77_int*  m_array, \
       const f77_int*  n_array, \
       const f77_int*  k_array, \
       const ftype*    alpha_array, \
       const ftype**   a_array, const f77_int* lda_array, \
       const ftype**   b_array, const f77_int* ldb_array, \
       const ftype*    beta_array, \
             ftype**   c_array, const f77_int* ldc_array, \
       const f77_int*  group_count, \
       const f77_int*  group_size \
     ) \
{ \
	trans_t blis_transa; \
	trans_t blis_transb; \
	dim_t   m0, n0, k0; \
	inc_t   rs_a, cs_a; \
	inc_t   rs_b, cs_b; \
	inc_t   rs_c, cs_c; \
\
	/* Initialize BLIS. */ \
	bli_init_auto(); \
\
	/* Perform BLAS parameter checking for every group. */ \
	/* The per-group arguments are written as &array[gi] rather than */ \
	/* array+gi: if the checker reached through PASTEBLACHK() is a */ \
	/* function-like macro that applies unary '*' to an unparenthesized */ \
	/* parameter, "*array+gi" would parse as "(*array)+gi" and inspect */ \
	/* the wrong value, whereas "*&array[gi]" always yields element gi. */ \
	/* NOTE(review): confirm against the checker's definition. */ \
	for ( f77_int gi = 0; gi < *group_count; gi++ ) \
	{ \
		PASTEBLACHK(blisname) \
		( \
		  MKSTR(ch), \
		  MKSTR(blisname), \
		  &transa_array[gi], \
		  &transb_array[gi], \
		  &m_array[gi], \
		  &n_array[gi], \
		  &k_array[gi], \
		  &lda_array[gi], \
		  &ldb_array[gi], \
		  &ldc_array[gi]  \
		); \
	} \
\
	/* Running index into a_array/b_array/c_array across all groups. */ \
	f77_int idx = 0; \
\
	for ( f77_int i = 0; i < *group_count; i++ ) \
	{ \
		/* Map BLAS chars to their corresponding BLIS enumerated type value. */ \
		bli_param_map_netlib_to_blis_trans( transa_array[i], &blis_transa ); \
		bli_param_map_netlib_to_blis_trans( transb_array[i], &blis_transb ); \
\
		/* Typecast BLAS integers to BLIS integers. */ \
		bli_convert_blas_dim1( m_array[i], m0 ); \
		bli_convert_blas_dim1( n_array[i], n0 ); \
		bli_convert_blas_dim1( k_array[i], k0 ); \
\
		/* Set the row and column strides of the matrix operands. */ \
		rs_a = 1; \
		cs_a = lda_array[i]; \
		rs_b = 1; \
		cs_b = ldb_array[i]; \
		rs_c = 1; \
		cs_c = ldc_array[i]; \
\
		/* Every matrix in the group shares the group's parameters. */ \
		for ( f77_int j = 0; j < group_size[i]; j++ ) \
		{ \
			/* Call BLIS interface. */ \
			PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \
			( \
			  blis_transa, \
			  blis_transb, \
			  m0, \
			  n0, \
			  k0, \
			  (ftype*)&alpha_array[i], \
			  (ftype*)a_array[idx], rs_a, cs_a, \
			  (ftype*)b_array[idx], rs_b, cs_b, \
			  (ftype*)&beta_array[i], \
			  (ftype*)c_array[idx], rs_c, cs_c, \
			  NULL, \
			  NULL  \
			); \
\
			idx++; \
		} \
	} \
\
	/* Finalize BLIS. */ \
	bli_finalize_auto(); \
}
#else
/*
   Batched gemm BLAS wrapper, object-API variant (this definition is the one
   selected when BLIS_BLAS3_CALLS_TAPI is not defined).

   The batch consists of *group_count groups. Group i takes element i of
   each per-group array (transa/transb, m/n/k, alpha, lda/ldb, beta, ldc)
   and applies it to group_size[i] consecutive entries of the pointer arrays
   a_array, b_array and c_array; a running index walks those pointer arrays
   across all groups. Every matrix is described with row stride 1 and
   column stride equal to its leading dimension.

   All groups are parameter-checked before any computation is performed.
*/
#undef GENTFUNC
#define GENTFUNC( ftype, ch, blasname, blisname ) \
\
void PASTEF77(ch,blasname) \
     ( \
       const f77_char* transa_array, \
       const f77_char* transb_array, \
       const f77_int*  m_array, \
       const f77_int*  n_array, \
       const f77_int*  k_array, \
       const ftype*    alpha_array, \
       const ftype**   a_array, const f77_int* lda_array, \
       const ftype**   b_array, const f77_int* ldb_array, \
       const ftype*    beta_array, \
             ftype**   c_array, const f77_int* ldc_array, \
       const f77_int*  group_count, \
       const f77_int*  group_size ) \
{ \
	trans_t blis_transa; \
	trans_t blis_transb; \
	dim_t   m0, n0, k0; \
\
	/* Initialize BLIS. */ \
	bli_init_auto(); \
\
	/* Perform BLAS parameter checking for every group. */ \
	/* The per-group arguments are written as &array[gi] rather than */ \
	/* array+gi: if the checker reached through PASTEBLACHK() is a */ \
	/* function-like macro that applies unary '*' to an unparenthesized */ \
	/* parameter, "*array+gi" would parse as "(*array)+gi" and inspect */ \
	/* the wrong value, whereas "*&array[gi]" always yields element gi. */ \
	/* NOTE(review): confirm against the checker's definition. */ \
	for ( f77_int gi = 0; gi < *group_count; gi++ ) \
	{ \
		PASTEBLACHK(blisname) \
		( \
		  MKSTR(ch), \
		  MKSTR(blisname), \
		  &transa_array[gi], \
		  &transb_array[gi], \
		  &m_array[gi], \
		  &n_array[gi], \
		  &k_array[gi], \
		  &lda_array[gi], \
		  &ldb_array[gi], \
		  &ldc_array[gi]  \
		); \
	} \
\
	const num_t dt = PASTEMAC(ch,type); \
\
	/* idx is the running index into a_array/b_array/c_array. */ \
	f77_int idx = 0, i, j; \
\
	for ( i = 0; i < *group_count; i++ ) \
	{ \
		/* Map BLAS chars to their corresponding BLIS enumerated type value. */ \
		bli_param_map_netlib_to_blis_trans( transa_array[i], &blis_transa ); \
		bli_param_map_netlib_to_blis_trans( transb_array[i], &blis_transb ); \
\
		/* Typecast BLAS integers to BLIS integers. */ \
		bli_convert_blas_dim1( m_array[i], m0 ); \
		bli_convert_blas_dim1( n_array[i], n0 ); \
		bli_convert_blas_dim1( k_array[i], k0 ); \
\
		/* Set the row and column strides of the matrix operands. */ \
		const inc_t rs_a = 1; \
		const inc_t cs_a = lda_array[i]; \
		const inc_t rs_b = 1; \
		const inc_t cs_b = ldb_array[i]; \
		const inc_t rs_c = 1; \
		const inc_t cs_c = ldc_array[i]; \
\
		/* The scalar objects are built once per group and reused for */ \
		/* every matrix in that group. */ \
		obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \
		obj_t betao  = BLIS_OBJECT_INITIALIZER_1X1; \
\
		/* Dimensions of A and B as stored, given their trans settings. */ \
		dim_t m0_a, n0_a; \
		dim_t m0_b, n0_b; \
\
		bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \
		bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \
\
		bli_obj_init_finish_1x1( dt, (ftype*)&alpha_array[i], &alphao ); \
		bli_obj_init_finish_1x1( dt, (ftype*)&beta_array[i],  &betao  ); \
\
		for ( j = 0; j < group_size[i]; j++ ) \
		{ \
			obj_t ao = BLIS_OBJECT_INITIALIZER; \
			obj_t bo = BLIS_OBJECT_INITIALIZER; \
			obj_t co = BLIS_OBJECT_INITIALIZER; \
\
			/* Wrap the caller's buffers in objects. */ \
			bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a_array[idx], rs_a, cs_a, &ao ); \
			bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b_array[idx], rs_b, cs_b, &bo ); \
			bli_obj_init_finish( dt, m0,   n0,   (ftype*)c_array[idx], rs_c, cs_c, &co ); \
\
			bli_obj_set_conjtrans( blis_transa, &ao ); \
			bli_obj_set_conjtrans( blis_transb, &bo ); \
\
			/* Call BLIS interface. */ \
			PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \
			( \
			  &alphao, \
			  &ao, \
			  &bo, \
			  &betao, \
			  &co, \
			  NULL, \
			  NULL  \
			); \
\
			idx++; \
		} \
	} \
\
	/* Finalize BLIS. */ \
	bli_finalize_auto(); \
}
#endif
// Instantiate whichever GENTFUNC variant was selected above, with BLAS name
// 'gemm_batch' and BLIS operation name 'gemm'. Compiled only when
// BLIS_ENABLE_BLAS is defined.
// NOTE(review): INSERT_GENTFUNC_BLAS presumably expands GENTFUNC once per
// supported datatype -- confirm against its definition.
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTFUNC_BLAS( gemm_batch, gemm )
#endif

View File

@@ -0,0 +1,61 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
//
// Prototype BLAS-to-BLIS interfaces.
//
// Prototype template for the exported batched-gemm entry point whose name is
// formed by PASTEF77(ch,blasname). Every argument is passed by pointer:
//   transa_array, transb_array - per-group transposition characters
//   m_array, n_array, k_array  - per-group problem dimensions
//   alpha_array, beta_array    - per-group scalars
//   a_array, b_array           - arrays of pointers to the input matrices
//   c_array                    - array of pointers to the matrices that are
//                                updated (the only non-const operand)
//   lda_array, ldb_array,
//   ldc_array                  - per-group leading dimensions
//   group_count                - number of groups
//   group_size                 - number of matrices in each group
// The terminating semicolon is part of the macro body.
#undef GENTPROT
#define GENTPROT( ftype, ch, blasname ) \
\
BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \
( \
const f77_char* transa_array, \
const f77_char* transb_array, \
const f77_int* m_array, \
const f77_int* n_array, \
const f77_int* k_array, \
const ftype* alpha_array, \
const ftype** a_array, const f77_int* lda_array, \
const ftype** b_array, const f77_int* ldb_array, \
const ftype* beta_array, \
ftype** c_array, const f77_int* ldc_array, \
const f77_int* group_count, \
const f77_int* group_size \
);
// Emit the gemm_batch prototypes only when BLIS_ENABLE_BLAS is defined.
// NOTE(review): INSERT_GENTPROT_BLAS presumably expands GENTPROT once per
// supported datatype -- confirm against its definition.
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTPROT_BLAS( gemm_batch )
#endif

View File

@@ -1,11 +1,11 @@
#
#
# BLIS
# BLIS
# An object-based framework for developing high-performance BLAS-like
# libraries.
#
# Copyright (C) 2014, The University of Texas at Austin
# Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc.
# Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -143,9 +143,9 @@ CFLAGS += -I$(TEST_SRC_PATH)
#
# Define the operations we will test.
TEST_OPS := dotv axpyv \
TEST_OPS := dotv axpyv axpbyv\
gemv ger hemv her her2 trmv trsv \
gemm hemm herk her2k trmm trsm
gemm gemm_batch hemm herk her2k trmm trsm
# Optionally test gemmt, which some libraries might not implement.
ifeq ($(BUILD_GEMMT),yes)

293
test/test_axpbyv.c Normal file
View File

@@ -0,0 +1,293 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifdef WIN32
#include <io.h>
#else
#include <unistd.h>
#endif
#include "blis.h"
//#define PRINT
#ifdef BLIS_ENABLE_CBLAS
//#define CHECK_CBLAS
#endif
#ifdef CHECK_CBLAS
#include "cblas.h"
#endif
/*
* The BLIS object API (bli_axpbyv) is called when the macro 'BLIS' is
* defined at compile time.
* Otherwise the BLAS API (?axpby_) is called, unless the macro
* 'CHECK_CBLAS' is also defined (uncomment its '#define' above, which
* takes effect only when BLIS_ENABLE_CBLAS is defined), in which case the
* CBLAS API (cblas_?axpby) is called instead.
*
* Sample prototype for the BLAS interface API is as follows:
*                n     alpha    x        incx  beta     y        incy
* void daxpby_( int*, double*, double*, int*, double*, double*, int* );
*/
/*
 * Timing driver for the axpbyv operation.
 *
 * For each problem-size parameter p, stepping from p_end down to p_begin by
 * p_inc, the driver creates scalars alpha and beta and vectors x and y of
 * length n, runs the operation n_repeats times (restoring y from a saved
 * copy before each run), keeps the smallest elapsed time, and prints one
 * line of the form
 *
 *   data_axpbyv_<impl>( <index>, 1:2 ) = [ <n> <gflops> ];
 *
 * The implementation that is timed is selected at compile time: the BLIS
 * object API when 'BLIS' is defined, otherwise the BLAS interface, or the
 * CBLAS interface when 'CHECK_CBLAS' is defined. When 'PRINT' is defined a
 * single small problem is run, its operands are printed, and the program
 * exits after the first call.
 *
 * argc and argv are unused. Returns 0 on normal completion.
 */
int main( int argc, char** argv )
{
	obj_t  x, y;
	obj_t  y_save;
	obj_t  alpha, beta;
	dim_t  n;
	dim_t  p;
	dim_t  p_begin, p_end, p_inc;
	int    n_input;
	num_t  dt_x, dt_y;
	num_t  dt_alpha, dt_beta;
	int    r, n_repeats;
	num_t  dt;
	double dtime;
	double dtime_save;
	double gflops;

	bli_init();

	n_repeats = 3;

#ifndef PRINT
	p_begin = 40;
	p_end   = 4000;
	p_inc   = 40;

	/* A negative n_input means "scale p by |n_input|". */
	n_input = -1;
#else
	p_begin = 16;
	p_end   = 16;
	p_inc   = 1;

	/* A non-negative n_input is used directly as the vector length. */
	n_input = 15;
#endif

#if 1
	dt = BLIS_FLOAT;
	//dt = BLIS_DOUBLE;
#else
	//dt = BLIS_SCOMPLEX;
	dt = BLIS_DCOMPLEX;
#endif

	dt_x = dt_y = dt_alpha = dt_beta = dt;

	// Begin with initializing the last entry to zero so that
	// matlab allocates space for the entire array once up-front.
	for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ;
#ifdef BLIS
	printf( "data_axpbyv_blis" );
#else
	printf( "data_axpbyv_%s", BLAS );
#endif
	/* The cast covers the whole index expression so that the argument is
	   an unsigned long, as "%lu" requires, whatever the width of dim_t. */
	printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n",
	        ( unsigned long )( (p - p_begin)/p_inc + 1 ),
	        ( unsigned long )0, 0.0 );

	//for ( p = p_begin; p <= p_end; p += p_inc )
	for ( p = p_end; p_begin <= p; p -= p_inc )
	{
		if ( n_input < 0 ) n = p * ( dim_t )abs(n_input);
		else               n =     ( dim_t )    n_input;

		bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha );
		bli_obj_create( dt_beta,  1, 1, 0, 0, &beta );

		bli_obj_create( dt_x, n, 1, 0, 0, &x );
		bli_obj_create( dt_y, n, 1, 0, 0, &y );
		bli_obj_create( dt_y, n, 1, 0, 0, &y_save );

		bli_randm( &x );
		bli_randm( &y );

		bli_setsc(  (0.9/1.0), 0.2, &alpha );
		bli_setsc( -(1.1/1.0), 0.3, &beta );

		/* Keep the initial y so that every repeat starts from the same data. */
		bli_copym( &y, &y_save );

		dtime_save = 1.0e9;

		for ( r = 0; r < n_repeats; ++r )
		{
			bli_copym( &y_save, &y );

			dtime = bli_clock();

#ifdef PRINT
			bli_printm( "alpha", &alpha, "%4.1f", "" );
			bli_printm( "beta" , &beta,  "%4.1f", "" );
			bli_printm( "x",     &x,     "%4.1f", "" );
			bli_printm( "y",     &y,     "%4.1f", "" );
#endif

#ifdef BLIS
			bli_axpbyv( &alpha,
			            &x,
			            &beta,
			            &y );
#else
			/* Length and increments are the same for every datatype. */
			f77_int nn   = bli_obj_length( &x );
			f77_int incx = bli_obj_vector_inc( &x );
			f77_int incy = bli_obj_vector_inc( &y );

			if ( bli_is_float( dt ) )
			{
				float  alphap = *(( float * )bli_obj_buffer( &alpha ));
				float  betap  = *(( float * )bli_obj_buffer( &beta ));
				float* xp     = bli_obj_buffer( &x );
				float* yp     = bli_obj_buffer( &y );
#ifdef CHECK_CBLAS
				cblas_saxpby( nn,
				              alphap,
				              xp, incx,
				              betap,
				              yp, incy );
#else
				saxpby_( &nn,
				         &alphap,
				         xp, &incx,
				         &betap,
				         yp, &incy );
#endif
			}
			else if ( bli_is_double( dt ) )
			{
				double  alphap = *(( double * )bli_obj_buffer( &alpha ));
				double  betap  = *(( double * )bli_obj_buffer( &beta ));
				double* xp     = bli_obj_buffer( &x );
				double* yp     = bli_obj_buffer( &y );
#ifdef CHECK_CBLAS
				cblas_daxpby( nn,
				              alphap,
				              xp, incx,
				              betap,
				              yp, incy );
#else
				daxpby_( &nn,
				         &alphap,
				         xp, &incx,
				         &betap,
				         yp, &incy );
#endif
			}
			else if ( bli_is_scomplex( dt ) )
			{
				/* Complex scalars are passed by address in both interfaces. */
				void* alphap = bli_obj_buffer( &alpha );
				void* betap  = bli_obj_buffer( &beta );
				void* xp     = bli_obj_buffer( &x );
				void* yp     = bli_obj_buffer( &y );
#ifdef CHECK_CBLAS
				cblas_caxpby( nn,
				              alphap,
				              xp, incx,
				              betap,
				              yp, incy );
#else
				caxpby_( &nn,
				         ( scomplex* )alphap,
				         ( scomplex* )xp, &incx,
				         ( scomplex* )betap,
				         ( scomplex* )yp, &incy );
#endif
			}
			else if ( bli_is_dcomplex( dt ))
			{
				void* alphap = bli_obj_buffer( &alpha );
				void* betap  = bli_obj_buffer( &beta );
				void* xp     = bli_obj_buffer( &x );
				void* yp     = bli_obj_buffer( &y );
#ifdef CHECK_CBLAS
				cblas_zaxpby( nn,
				              alphap,
				              xp, incx,
				              betap,
				              yp, incy );
#else
				zaxpby_( &nn,
				         ( dcomplex* )alphap,
				         ( dcomplex* )xp, &incx,
				         ( dcomplex* )betap,
				         ( dcomplex* )yp, &incy );
#endif
			}
#endif

#ifdef PRINT
			bli_printm( "y after", &y, "%4.1f", "" );
			exit(1);
#endif

			/* Track the fastest of the repeats. */
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		/* Rate is computed from a count of 3 operations per element. */
		gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 );

#ifdef BLIS
		printf( "data_axpbyv_blis" );
#else
		printf( "data_axpbyv_%s", BLAS );
#endif
		printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n",
		        ( unsigned long )( (p - p_begin)/p_inc + 1 ),
		        ( unsigned long )n, gflops );

		bli_obj_free( &alpha );
		bli_obj_free( &beta );

		bli_obj_free( &x );
		bli_obj_free( &y );
		bli_obj_free( &y_save );
	}

	bli_finalize();

	return 0;
}

584
test/test_gemm_batch.c Normal file
View File

@@ -0,0 +1,584 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name of The University of Texas nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifdef WIN32
#include <io.h>
#else
#include <unistd.h>
#endif
#include "blis.h"
//#define CHECK_CBLAS
#ifdef CHECK_CBLAS
#include "cblas.h"
#endif
/* Format for FILE input
* For each input set, first line contains 'storage scheme'
* and 'group count' separated by space.
* Following 'group_count' number of lines contains all the parameters of
* each group separated by space in each line in the following order:
* tA tB m n k lda ldb ldc alpha_r alpha_i beta_r beta_i group_size
*
* Example:
* c 2
* n n 4 8 4 4 4 4 1.1 0.0 0.9 0.0 2
* n n 3 3 6 3 6 3 1.0 0.0 2.0 0.0 2
*
*/
//#define FILE_IN_OUT
#ifndef FILE_IN_OUT
#define GRP_COUNT 2
#endif
//#define PRINT
/*
 * Timing driver for the batched gemm interfaces.
 *
 * A batch is described by GRP_COUNT groups; every matrix in a group shares
 * that group's transposition settings, dimensions, leading dimensions and
 * alpha/beta scalars, and group_size[i] gives the number of matrices in
 * group i.
 *
 * When FILE_IN_OUT is defined, batches are read one after another from the
 * file named by argv[1] and results are written to the file named by
 * argv[2]; otherwise a single built-in batch of two groups is run and the
 * results are printed to stdout.
 *
 * For each batch the driver creates random operands, calls the BLAS
 * ?gemm_batch_ routine (or the cblas_?gemm_batch routine when CHECK_CBLAS
 * is defined) n_repeats times, restoring C before each call, and reports
 * the rate derived from the smallest elapsed time.
 *
 * Malformed input (a non-positive group count, an unreadable group line, a
 * negative group size, an illegal trans character) terminates the program
 * with exit status 1. A batch with an unrecognized storage scheme or with
 * no matrices is skipped in file mode.
 *
 * Returns 0 on normal completion.
 */
int main( int argc, char** argv )
{
	num_t  dt;
	char   stor_scheme;
	dim_t  i, j, idx;
	dim_t  r, n_repeats;
	double dtime;
	double dtime_save;
	double gflops;
	dim_t  total_count = 0;

#if 1
	dt = BLIS_FLOAT;
	//dt = BLIS_DOUBLE;
#else
	dt = BLIS_SCOMPLEX;
	//dt = BLIS_DCOMPLEX;
#endif

	n_repeats = 1;

#ifdef FILE_IN_OUT
	FILE* fin  = NULL;
	FILE* fout = NULL;

	if(argc < 3)
	{
		printf("Usage: ./test_gemm_batch_XX.x input.csv output.csv\n");
		exit(1);
	}

	fin = fopen(argv[1], "r");
	if( fin == NULL )
	{
		printf("Error opening input file %s \n", argv[1]);
		exit(1);
	}

	fout = fopen(argv[2], "w");
	if(fout == NULL)
	{
		printf("Error opening output file %s\n",argv[2]);
		exit(1);
	}

	dim_t GRP_COUNT;

	fprintf(fout, "m\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t grp_size\n");

	/* Each iteration consumes one batch description from the input file. */
	while(fscanf(fin, "%c %ld\n", &stor_scheme, &GRP_COUNT) == 2)
	{
		/* GRP_COUNT sizes the variable-length arrays below, so it must be
		   positive before any of them is declared. */
		if( GRP_COUNT <= 0 )
		{
			printf("Illegal group count %ld\n", GRP_COUNT);
			exit(1);
		}

		char   transa[GRP_COUNT];
		char   transb[GRP_COUNT];

		dim_t  m[GRP_COUNT];
		dim_t  n[GRP_COUNT];
		dim_t  k[GRP_COUNT];

		dim_t  lda[GRP_COUNT];
		dim_t  ldb[GRP_COUNT];
		dim_t  ldc[GRP_COUNT];

		double alpha_real[GRP_COUNT];
		double alpha_imag[GRP_COUNT];
		double beta_real[GRP_COUNT];
		double beta_imag[GRP_COUNT];

		dim_t  group_size[GRP_COUNT];
		obj_t  alpha[GRP_COUNT], beta[GRP_COUNT];

		total_count = 0;
		for(i = 0; i < GRP_COUNT; i++)
		{
			/* All 13 fields must be converted; otherwise the arrays would
			   be used while holding indeterminate values. */
			if( fscanf(fin, "%c %c %ld %ld %ld %ld %ld %ld %lf %lf %lf %lf %ld\n", &transa[i], &transb[i], &m[i], &n[i], &k[i], &lda[i], &ldb[i], &ldc[i], &alpha_real[i], &alpha_imag[i], &beta_real[i], &beta_imag[i], &group_size[i]) != 13
			    || group_size[i] < 0 )
			{
				printf("Error reading parameters of group %ld\n", i);
				exit(1);
			}

			total_count += group_size[i];
		}

		/* total_count sizes further variable-length arrays below. */
		if( total_count == 0 )
		{
			printf("Input set contains no matrices; skipping\n");
			continue;
		}
#else
		printf("m\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t grp_size\n");

		/* Built-in batch: two column-stored groups of two matrices each. */
		stor_scheme = 'c';

		dim_t  m[GRP_COUNT]          = {4, 3};
		dim_t  n[GRP_COUNT]          = {8, 3};
		dim_t  k[GRP_COUNT]          = {4, 6};

		dim_t  lda[GRP_COUNT]        = {4, 3};
		dim_t  ldb[GRP_COUNT]        = {4, 6};
		dim_t  ldc[GRP_COUNT]        = {4, 3};

		char   transa[GRP_COUNT]     = {'N', 'N'};
		char   transb[GRP_COUNT]     = {'N', 'N'};

		double alpha_real[GRP_COUNT] = {1.1, 1.0};
		double alpha_imag[GRP_COUNT] = {0.0, 0.0};
		double beta_real[GRP_COUNT]  = {0.9, 2.0};
		double beta_imag[GRP_COUNT]  = {0.0, 0.0};

		dim_t  group_size[GRP_COUNT] = {2,2};

		obj_t  alpha[GRP_COUNT], beta[GRP_COUNT];

		total_count = 0;
		for(i = 0; i < GRP_COUNT; i++)
			total_count += group_size[i];
#endif

		/* Operands are only created for the 'c'/'C' and 'r'/'R' schemes, so
		   any other character must be rejected before they are used. */
		if( stor_scheme != 'c' && stor_scheme != 'C' &&
		    stor_scheme != 'r' && stor_scheme != 'R' )
		{
			printf("Illegal storage scheme %c\n", stor_scheme);
#ifdef FILE_IN_OUT
			continue;
#else
			exit(1);
#endif
		}

		obj_t a[total_count], b[total_count];
		obj_t c[total_count], c_save[total_count];

		f77_int f77_m[GRP_COUNT], f77_n[GRP_COUNT], f77_k[GRP_COUNT];
		f77_int f77_lda[GRP_COUNT], f77_ldb[GRP_COUNT], f77_ldc[GRP_COUNT];
		f77_int f77_group_size[GRP_COUNT];
		f77_int f77_group_count = GRP_COUNT;

#ifdef CHECK_CBLAS
		enum CBLAS_ORDER     cblas_order;
		enum CBLAS_TRANSPOSE cblas_transa[GRP_COUNT];
		enum CBLAS_TRANSPOSE cblas_transb[GRP_COUNT];

		if(stor_scheme == 'R' || stor_scheme == 'r')
			cblas_order = CblasRowMajor;
		else
			cblas_order = CblasColMajor;
#else
		f77_char f77_transa[GRP_COUNT];
		f77_char f77_transb[GRP_COUNT];

		if(stor_scheme == 'r' || stor_scheme == 'R' )
		{
			printf("BLAS Interface doesn't support row-major order\n");
#ifdef FILE_IN_OUT
			continue;
#else
			exit(1);
#endif
		}
#endif

		/* Create the scalars and operands of every group; idx runs over
		   all matrices of the batch. */
		idx = 0;
		for(i = 0; i < GRP_COUNT; i++)
		{
			bli_obj_create(dt, 1, 1, 0, 0, &alpha[i]);
			bli_obj_create(dt, 1, 1, 0, 0, &beta[i] );

			bli_setsc(alpha_real[i], alpha_imag[i], &alpha[i]);
			bli_setsc(beta_real[i],  beta_imag[i],  &beta[i] );

			trans_t blis_transa, blis_transb;

			if(transa[i] == 't' || transa[i] == 'T')
				blis_transa = BLIS_TRANSPOSE;
			else if (transa[i] == 'c' || transa[i] == 'C')
				blis_transa = BLIS_CONJ_TRANSPOSE;
			else if ( transa[i] == 'n' || transa[i] == 'N')
				blis_transa = BLIS_NO_TRANSPOSE;
			else
			{
				printf("Illegal transA setting %c for group %ld\n", transa[i], i);
				exit(1);
			}

			if(transb[i] == 't' || transb[i] == 'T')
				blis_transb = BLIS_TRANSPOSE;
			else if (transb[i] == 'c' || transb[i] == 'C')
				blis_transb = BLIS_CONJ_TRANSPOSE;
			else if (transb[i] == 'n' || transb[i] == 'N')
				blis_transb = BLIS_NO_TRANSPOSE;
			else
			{
				printf("Illegal transB setting %c for group %ld\n", transb[i], i);
				exit(1);
			}

#ifdef CHECK_CBLAS
			if(bli_is_trans( blis_transa ))
				cblas_transa[i] = CblasTrans;
			else if (bli_is_conjtrans( blis_transa ))
				cblas_transa[i] = CblasConjTrans;
			else
				cblas_transa[i] = CblasNoTrans;

			if(bli_is_trans( blis_transb ))
				cblas_transb[i] = CblasTrans;
			else if (bli_is_conjtrans( blis_transb ))
				cblas_transb[i] = CblasConjTrans;
			else
				cblas_transb[i] = CblasNoTrans;
#else
			bli_param_map_blis_to_netlib_trans( blis_transa, &f77_transa[i]);
			bli_param_map_blis_to_netlib_trans( blis_transb, &f77_transb[i]);
#endif

			/* Dimensions of A and B as stored, given their trans settings. */
			dim_t m0_a, n0_a;
			dim_t m0_b, n0_b;

			bli_set_dims_with_trans( blis_transa, m[i], k[i], &m0_a, &n0_a );
			bli_set_dims_with_trans( blis_transb, k[i], n[i], &m0_b, &n0_b );

			if(stor_scheme == 'C' || stor_scheme == 'c')
			{
				/* Column storage: row stride 1, column stride ld?. */
				for(j = 0; j < group_size[i]; j++)
				{
					bli_obj_create(dt, m0_a, n0_a, 1, lda[i], &a[idx]);
					bli_obj_create(dt, m0_b, n0_b, 1, ldb[i], &b[idx]);
					bli_obj_create(dt, m[i], n[i], 1, ldc[i], &c[idx]);
					bli_obj_create(dt, m[i], n[i], 1, ldc[i], &c_save[idx]);

					bli_randm( &a[idx] );
					bli_randm( &b[idx] );
					bli_randm( &c[idx] );

					bli_obj_set_conjtrans(blis_transa, &a[idx]);
					bli_obj_set_conjtrans(blis_transb, &b[idx]);
					idx++;
				}
			}
			else if(stor_scheme == 'R' || stor_scheme == 'r')
			{
				/* Row storage: row stride ld?, column stride 1. */
				for(j = 0; j < group_size[i]; j++)
				{
					bli_obj_create(dt, m0_a, n0_a, lda[i], 1, &a[idx]);
					bli_obj_create(dt, m0_b, n0_b, ldb[i], 1, &b[idx]);
					bli_obj_create(dt, m[i], n[i], ldc[i], 1, &c[idx]);
					bli_obj_create(dt, m[i], n[i], ldc[i], 1, &c_save[idx]);

					bli_randm( &a[idx] );
					bli_randm( &b[idx] );
					bli_randm( &c[idx] );

					bli_obj_set_conjtrans(blis_transa, &a[idx]);
					bli_obj_set_conjtrans(blis_transb, &b[idx]);
					idx++;
				}
			}

			/* Integer arguments in the width the BLAS/CBLAS interface takes. */
			f77_m[i]          = m[i];
			f77_n[i]          = n[i];
			f77_k[i]          = k[i];
			f77_lda[i]        = lda[i];
			f77_ldb[i]        = ldb[i];
			f77_ldc[i]        = ldc[i];
			f77_group_size[i] = group_size[i];
		}

		/* Keep the initial C matrices so every repeat starts from the same data. */
		idx = 0;
		for(i = 0; i < GRP_COUNT; i++)
			for(j = 0; j < group_size[i]; j++)
			{
				bli_copym(&c[idx], &c_save[idx]);
				idx++;
			}

		dtime_save = DBL_MAX;

		for( r = 0; r < n_repeats; ++r )
		{
			idx = 0;
			for(i = 0; i < GRP_COUNT; i++)
				for(j = 0; j < group_size[i]; j++)
				{
					bli_copym( &c_save[idx], &c[idx]);
					idx++;
				}

			dtime = bli_clock();

#ifdef PRINT
			idx = 0;
			for(i = 0; i < GRP_COUNT; i++)
				for(j = 0; j < group_size[i]; j++)
				{
					printf("Group: %ld Member: %ld\n", i, j);

					bli_printm("a", &a[idx], "%4.1f", "");
					bli_printm("b", &b[idx], "%4.1f", "");
					bli_printm("c", &c[idx], "%4.1f", "");

					idx++;
				}
#endif

			/* Each datatype branch gathers the raw buffer pointers and the
			   per-group scalars into plain arrays, then makes one batched call. */
			if(bli_is_float(dt))
			{
				const float *ap[total_count], *bp[total_count];
				float       *cp[total_count];
				float        alphap[GRP_COUNT], betap[GRP_COUNT];

				idx = 0;
				for(i = 0; i < GRP_COUNT; i++)
				{
					for(j = 0; j < group_size[i]; j++)
					{
						ap[idx] = bli_obj_buffer( &a[idx] );
						bp[idx] = bli_obj_buffer( &b[idx] );
						cp[idx] = bli_obj_buffer( &c[idx] );
						idx++;
					}

					alphap[i] = *(float*)bli_obj_buffer_for_1x1(dt, &alpha[i]);
					betap[i]  = *(float*)bli_obj_buffer_for_1x1(dt, &beta[i] );
				}

#ifdef CHECK_CBLAS
				cblas_sgemm_batch( cblas_order,
				                   cblas_transa,
				                   cblas_transb,
				                   f77_m, f77_n, f77_k,
				                   alphap, ap, f77_lda,
				                   bp, f77_ldb,
				                   betap, cp, f77_ldc,
				                   f77_group_count,
				                   f77_group_size
				                 );
#else
				sgemm_batch_( f77_transa,
				              f77_transb,
				              f77_m, f77_n, f77_k,
				              alphap, ap, f77_lda,
				              bp, f77_ldb,
				              betap, cp, f77_ldc,
				              &f77_group_count,
				              f77_group_size
				            );
#endif
			}
			else if(bli_is_double(dt))
			{
				const double *ap[total_count], *bp[total_count];
				double       *cp[total_count];
				double        alphap[GRP_COUNT], betap[GRP_COUNT];

				idx = 0;
				for(i = 0; i < GRP_COUNT; i++)
				{
					for(j = 0; j < group_size[i]; j++)
					{
						ap[idx] = bli_obj_buffer( &a[idx] );
						bp[idx] = bli_obj_buffer( &b[idx] );
						cp[idx] = bli_obj_buffer( &c[idx] );
						idx++;
					}

					alphap[i] = *(double*)bli_obj_buffer_for_1x1(dt, &alpha[i]);
					betap[i]  = *(double*)bli_obj_buffer_for_1x1(dt, &beta[i] );
				}

#ifdef CHECK_CBLAS
				cblas_dgemm_batch( cblas_order,
				                   cblas_transa,
				                   cblas_transb,
				                   f77_m, f77_n, f77_k,
				                   alphap, ap, f77_lda,
				                   bp, f77_ldb,
				                   betap, cp, f77_ldc,
				                   f77_group_count,
				                   f77_group_size
				                 );
#else
				dgemm_batch_( f77_transa,
				              f77_transb,
				              f77_m, f77_n, f77_k,
				              alphap, ap, f77_lda,
				              bp, f77_ldb,
				              betap, cp, f77_ldc,
				              &f77_group_count,
				              f77_group_size
				            );
#endif
			}
			else if(bli_is_scomplex(dt))
			{
				const scomplex *ap[total_count], *bp[total_count];
				scomplex       *cp[total_count];
				scomplex        alphap[GRP_COUNT], betap[GRP_COUNT];

				idx = 0;
				for(i = 0; i < GRP_COUNT; i++)
				{
					for(j = 0; j < group_size[i]; j++)
					{
						ap[idx] = bli_obj_buffer( &a[idx] );
						bp[idx] = bli_obj_buffer( &b[idx] );
						cp[idx] = bli_obj_buffer( &c[idx] );
						idx++;
					}

					alphap[i] = *(scomplex*)bli_obj_buffer_for_1x1(dt, &alpha[i]);
					betap[i]  = *(scomplex*)bli_obj_buffer_for_1x1(dt, &beta[i] );
				}

#ifdef CHECK_CBLAS
				cblas_cgemm_batch( cblas_order,
				                   cblas_transa,
				                   cblas_transb,
				                   f77_m, f77_n, f77_k,
				                   (const void*)alphap,
				                   (const void**)ap, f77_lda,
				                   (const void**)bp, f77_ldb,
				                   (const void*)betap, (void**)cp, f77_ldc,
				                   f77_group_count,
				                   f77_group_size
				                 );
#else
				cgemm_batch_( f77_transa,
				              f77_transb,
				              f77_m, f77_n, f77_k,
				              alphap, ap, f77_lda,
				              bp, f77_ldb,
				              betap, cp, f77_ldc,
				              &f77_group_count,
				              f77_group_size
				            );
#endif
			}
			else if(bli_is_dcomplex(dt))
			{
				const dcomplex *ap[total_count], *bp[total_count];
				dcomplex       *cp[total_count];
				dcomplex        alphap[GRP_COUNT], betap[GRP_COUNT];

				idx = 0;
				for(i = 0; i < GRP_COUNT; i++)
				{
					for(j = 0; j < group_size[i]; j++)
					{
						ap[idx] = bli_obj_buffer( &a[idx] );
						bp[idx] = bli_obj_buffer( &b[idx] );
						cp[idx] = bli_obj_buffer( &c[idx] );
						idx++;
					}

					alphap[i] = *(dcomplex*)bli_obj_buffer_for_1x1(dt, &alpha[i]);
					betap[i]  = *(dcomplex*)bli_obj_buffer_for_1x1(dt, &beta[i] );
				}

#ifdef CHECK_CBLAS
				cblas_zgemm_batch( cblas_order,
				                   cblas_transa,
				                   cblas_transb,
				                   f77_m, f77_n, f77_k,
				                   (const void*)alphap,
				                   (const void**)ap, f77_lda,
				                   (const void**)bp, f77_ldb,
				                   (const void*)betap, (void**)cp, f77_ldc,
				                   f77_group_count,
				                   f77_group_size
				                 );
#else
				zgemm_batch_( f77_transa,
				              f77_transb,
				              f77_m, f77_n, f77_k,
				              alphap, ap, f77_lda,
				              bp, f77_ldb,
				              betap, cp, f77_ldc,
				              &f77_group_count,
				              f77_group_size
				            );
#endif
			}

#ifdef PRINT
			idx = 0;
			for(i = 0; i < GRP_COUNT; i++)
				for(j = 0; j < group_size[i]; j++)
				{
					printf("Group: %ld Member: %ld\n", i, j);
					bli_printm("c after", &c[idx], "%4.1f", "");

					idx++;
				}
#endif

			/* Track the fastest of the repeats. */
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		/* Operation count of the batch: 2*m*k*n per matrix. Accumulated in
		   floating point, matching the type of the expression being summed. */
		double fp_ops = 0.0;
		for(i = 0; i < GRP_COUNT; i++)
			fp_ops += 2.0 * m[i] * k[i] * n[i] * group_size[i];

		gflops = fp_ops / (dtime_save * 1.0e9 );

		/* Complex problems are credited with four times the operation count. */
		if(bli_is_complex( dt ) ) gflops *= 4.0;

#ifdef FILE_IN_OUT
		fprintf(fout, "Stor_scheme = %c, group_count = %lu, gflops = %7.2f\n", stor_scheme, GRP_COUNT, gflops);
		for(i = 0; i < GRP_COUNT; i++)
			fprintf(fout, "%4lu \t %4lu\t %4lu\t %4lu\t %4lu\t %4lu\t %c\t %c\t %4lu\n", m[i], n[i], k[i], lda[i], ldb[i], ldc[i], transa[i], transb[i], group_size[i]);

		fflush(fout);
#else
		printf( "Stor_scheme = %c, group_count = %d, gflops = %7.2f\n", stor_scheme, GRP_COUNT, gflops);
		for(i = 0; i < GRP_COUNT; i++)
			printf("%4lu \t %4lu\t %4lu\t %4lu\t %4lu\t %4lu\t %c\t %c\t %4lu\n", m[i], n[i], k[i], lda[i], ldb[i], ldc[i], transa[i], transb[i], group_size[i]);
#endif

		/* Release every object created for this batch. */
		idx = 0;
		for(i = 0; i < GRP_COUNT; i++)
		{
			bli_obj_free( &alpha[i]);
			bli_obj_free( &beta[i] );

			for(j = 0; j < group_size[i]; j++ )
			{
				bli_obj_free( &a[idx]);
				bli_obj_free( &b[idx]);
				bli_obj_free( &c[idx]);
				bli_obj_free( &c_save[idx]);

				idx++;
			}
		}
#ifdef FILE_IN_OUT
	}

	fclose(fin);
	fclose(fout);
#endif
	return 0;
}