Code Cleanup done; Test code updated to add performance measurement

Change-Id: I639f22659c22226fbd81e1669e4372f200ab5129
This commit is contained in:
Chithra Sankar
2019-07-26 12:22:43 +05:30
parent 1223fcbcb8
commit 14c99492fe
13 changed files with 807 additions and 3964 deletions

View File

@@ -1,11 +1,10 @@
#ifndef BLAS_GEMM_HH
#define BLAS_GEMM_HH
#ifndef BLIS_HH
#define BLIS_HH
//#include "blis_util.hh"
#include "blis_util.hh"
#include "cblas.hh"
#include <limits>
#define blis_int int
namespace blis {
// =============================================================================
@@ -83,9 +82,9 @@ namespace blis {
template< typename TA, typename TB, typename TC >
void gemm(
blis::Layout layout,
blis::Op transA,
blis::Op transB,
CBLAS_LAYOUT layout,
CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int64_t m, int64_t n, int64_t k,
scalar_type<TA, TB, TC> alpha,
TA const *A, int64_t lda,
@@ -93,59 +92,11 @@ void gemm(
scalar_type<TA, TB, TC> beta,
TC *C, int64_t ldc )
{
#if 0
//throw std::exception(); // not yet implemented
printf("In gemm.cc\n");
cblis_gemm(cblis_layout_const(layout),
cblis_trans_const(transA),
cblis_trans_const(transB),
m, n, k, alpha, A,lda, B, ldb, beta, C, ldc);
#endif
// check arguments
blis_error_if( layout != Layout::ColMajor &&
layout != Layout::RowMajor );
blis_error_if( transA != Op::NoTrans &&
transA != Op::Trans &&
transA != Op::ConjTrans );
blis_error_if( transB != Op::NoTrans &&
transB != Op::Trans &&
transB != Op::ConjTrans );
blis_error_if( m < 0 );
blis_error_if( n < 0 );
blis_error_if( k < 0 );
if ((transA == Op::NoTrans) ^ (layout == Layout::RowMajor))
blis_error_if( lda < m );
else
blis_error_if( lda < k );
if ((transB == Op::NoTrans) ^ (layout == Layout::RowMajor))
blis_error_if( ldb < k );
else
blis_error_if( ldb < n );
if (layout == Layout::ColMajor)
blis_error_if( ldc < m );
else
blis_error_if( ldc < n );
// check for overflow in native BLAS integer type, if smaller than int64_t
if (sizeof(int64_t) > sizeof(blis_int)) {
blis_error_if( m > std::numeric_limits<blis_int>::max() );
blis_error_if( n > std::numeric_limits<blis_int>::max() );
blis_error_if( k > std::numeric_limits<blis_int>::max() );
blis_error_if( lda > std::numeric_limits<blis_int>::max() );
blis_error_if( ldb > std::numeric_limits<blis_int>::max() );
blis_error_if( ldc > std::numeric_limits<blis_int>::max() );
}
printf("In gemm.cpp\n");
cblas_gemm(cblas_layout_const(layout),
cblas_trans_const(transA),
cblas_trans_const(transB),
m, n, k, alpha, A,lda, B, ldb, beta, C, ldc);
// printf("In gemm.cpp\n");
cblas_gemm(layout, transA, transB, m, n, k, alpha, A,lda, B, ldb, beta, C, ldc);
};
} // namespace blis
#endif // #ifndef BLAS_GEMM_HH
#endif // #ifndef BLIS_HH

226
cpp/blis_util.hh Normal file
View File

@@ -0,0 +1,226 @@
#ifndef BLIS_UTIL_HH
#define BLIS_UTIL_HH
#include <complex>
#include <cstdarg>
namespace blis {
// -----------------------------------------------------------------------------
// Scalar versions of real, imag and conj, so generic code can apply them to
// any datatype: real and conj return the argument, imag returns zero.
template< typename T >
inline T real( T x )
{
    return x;
}

template< typename T >
inline T imag( T /* x */ )
{
    return 0;
}

template< typename T >
inline T conj( T x )
{
    return x;
}
// -----------------------------------------------------------------------------
// 1-norm absolute value: std::abs(x) for the generic overload,
// |Re(x)| + |Im(x)| for the std::complex overload.
template< typename T >
T abs1( T x )
{
    return std::abs( x );
}

template< typename T >
T abs1( std::complex<T> x )
{
    const auto re_mag = std::abs( real(x) );
    const auto im_mag = std::abs( imag(x) );
    return re_mag + im_mag;
}
// -----------------------------------------------------------------------------
// The standard library only provides common_type_t and decay_t from C++14 on;
// when compiling as an older standard, define equivalent aliases here.
#if __cplusplus < 201402L
template< typename... Ts >
using common_type_t = typename std::common_type< Ts... >::type;

template< typename... Ts >
using decay_t = typename std::decay< Ts... >::type;
#else
using std::common_type_t;
using std::decay_t;
#endif
//------------------------------------------------------------------------------
/// Type trait: is_complex<T>::value is true when T is std::complex<T2>
/// for some type T2, and false for every other T.
template <typename T>
struct is_complex : std::false_type {};

// partial specialization that matches std::complex
template <typename T>
struct is_complex< std::complex<T> > : std::true_type {};
// -----------------------------------------------------------------------------
// Based on C++14 common_type implementation from
// http://www.cplusplus.com/reference/type_traits/common_type/
// Adds promotion of complex types based on the common type of the associated
// real types. This fixes various cases:
//
// std::common_type_t< double, complex<float> > is complex<float> (wrong)
// scalar_type< double, complex<float> > is complex<double> (right)
//
// std::common_type_t< int, complex<long> > is not defined (compile error)
// scalar_type< int, complex<long> > is complex<long> (right)
// Primary template: declared only, so an empty type list has no scalar type.
template< typename... Types >
struct scalar_type_traits;

// scalar_type< Ts... > is shorthand for scalar_type_traits< Ts... >::type
template< typename... Types >
using scalar_type = typename scalar_type_traits< Types... >::type;

// One type: the decayed type itself.
template< typename Only >
struct scalar_type_traits< Only >
{
    using type = decay_t< Only >;
};

// Two types: the decayed type of a ?: expression over both, i.e. the
// common type the conditional operator picks for its two operands.
template< typename First, typename Second >
struct scalar_type_traits< First, Second >
{
    using type = decay_t<
        decltype( true ? std::declval< First >() : std::declval< Second >() ) >;
};

// Two types where the first, the second, or both are std::complex:
// take the common type of the underlying real types and wrap it in
// std::complex.
template< typename Real1, typename Second >
struct scalar_type_traits< std::complex< Real1 >, Second >
{
    using type = std::complex< common_type_t< Real1, Second > >;
};

template< typename First, typename Real2 >
struct scalar_type_traits< First, std::complex< Real2 > >
{
    using type = std::complex< common_type_t< First, Real2 > >;
};

template< typename Real1, typename Real2 >
struct scalar_type_traits< std::complex< Real1 >, std::complex< Real2 > >
{
    using type = std::complex< common_type_t< Real1, Real2 > >;
};

// Three or more types: combine the first two, then fold in the rest.
template< typename First, typename Second, typename... Rest >
struct scalar_type_traits< First, Second, Rest... >
{
    using type = typename scalar_type_traits<
        typename scalar_type_traits< First, Second >::type,
        Rest... >::type;
};
// -----------------------------------------------------------------------------
// for any combination of types, determine associated real, scalar,
// and complex types.
//
// real_type< float > is float
// real_type< float, double, complex<float> > is double
//
// scalar_type< float > is float
// scalar_type< float, complex<float> > is complex<float>
// scalar_type< float, double, complex<float> > is complex<double>
//
// complex_type< float > is complex<float>
// complex_type< float, double > is complex<double>
// complex_type< float, double, complex<float> > is complex<double>
// Primary template: declared only, so an empty type list has no real type.
template< typename... Types >
struct real_type_traits;

// real_type< Ts... > is shorthand for real_type_traits< Ts... >::real_t
template< typename... Types >
using real_type = typename real_type_traits< Types... >::real_t;

// complex_type< Ts... > is std::complex over real_type< Ts... >
template< typename... Types >
using complex_type = std::complex< real_type< Types... > >;

// One type: it is its own real type.
template< typename Only >
struct real_type_traits< Only >
{
    using real_t = Only;
};

// One std::complex type: the real type is the type std::complex wraps.
template< typename RealT >
struct real_type_traits< std::complex< RealT > >
{
    using real_t = RealT;
};

// Two or more types: scalar_type of the first type's real type and the
// real type of the remaining types.
template< typename Head, typename... Tail >
struct real_type_traits< Head, Tail... >
{
    using real_t = scalar_type<
        typename real_type_traits< Head >::real_t,
        typename real_type_traits< Tail... >::real_t >;
};
// -----------------------------------------------------------------------------
// max over arguments of possibly different types, e.g.
// int64_t = max( int, int64_t ), and over any number of arguments:
// max( a, b, c, d ). The result type is scalar_type of the argument types.

// one argument: returned unchanged
template< typename T >
T max( T x )
{
    return x;
}

// two arguments: x when x >= y, otherwise y
template< typename T1, typename T2 >
scalar_type< T1, T2 >
max( T1 x, T2 y )
{
    if (x >= y) {
        return x;
    }
    return y;
}

// three or more arguments: reduce the tail first, then compare with first
template< typename T1, typename... Types >
scalar_type< T1, Types... >
max( T1 first, Types... args )
{
    auto tail_max = max( args... );
    return max( first, tail_max );
}
// -----------------------------------------------------------------------------
// min over arguments of possibly different types, e.g.
// int64_t = min( int, int64_t ), and over any number of arguments:
// min( a, b, c, d ). The result type is scalar_type of the argument types.

// one argument: returned unchanged
template< typename T >
T min( T x )
{
    return x;
}

// two arguments: x when x <= y, otherwise y
template< typename T1, typename T2 >
scalar_type< T1, T2 >
min( T1 x, T2 y )
{
    if (x <= y) {
        return x;
    }
    return y;
}

// three or more arguments: reduce the tail first, then compare with first
template< typename T1, typename... Types >
scalar_type< T1, Types... >
min( T1 first, Types... args )
{
    auto tail_min = min( args... );
    return min( first, tail_min );
}
} // namespace blis
#endif // #ifndef BLIS_UTIL_HH

85
cpp/cblas.hh Normal file
View File

@@ -0,0 +1,85 @@
#ifndef CBLAS_HH
#define CBLAS_HH
extern "C" {
#include <cblas.h>
#include <blis.h>
}
typedef CBLAS_ORDER CBLAS_LAYOUT;
#include <complex>
namespace blis{
// =============================================================================
// Level 3 BLAS
// -----------------------------------------------------------------------------
inline void
cblas_gemm(
CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int m, int n, int k,
float alpha,
float const *A, int lda,
float const *B, int ldb,
float beta,
float* C, int ldc )
{
// printf("cblas_sgemm\n");
cblas_sgemm( layout, transA, transB, m, n, k,
alpha, A, lda, B, ldb,
beta, C, ldc );
}
inline void
cblas_gemm(
CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int m, int n, int k,
double alpha,
double const *A, int lda,
double const *B, int ldb,
double beta,
double* C, int ldc )
{
// printf("cblas_dgemm\n");
cblas_dgemm( layout, transA, transB, m, n, k,
alpha, A, lda, B, ldb,
beta, C, ldc );
}
inline void
cblas_gemm(
CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int m, int n, int k,
std::complex<float> alpha,
std::complex<float> const *A, int lda,
std::complex<float> const *B, int ldb,
std::complex<float> beta,
std::complex<float>* C, int ldc )
{
// printf("cblas_cgemm\n");
cblas_cgemm( layout, transA, transB, m, n, k,
&alpha, A, lda, B, ldb,
&beta, C, ldc );
}
inline void
cblas_gemm(
CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int m, int n, int k,
std::complex<double> alpha,
std::complex<double> const *A, int lda,
std::complex<double> const *B, int ldb,
std::complex<double> beta,
std::complex<double>* C, int ldc )
{
// printf("cblas_zgemm\n");
cblas_zgemm( layout, transA, transB, m, n, k,
&alpha, A, lda, B, ldb,
&beta, C, ldc );
}
}//namespace blis
#endif // #ifndef CBLAS_HH

View File

@@ -1,56 +0,0 @@
#ifndef BLIS_HH
#define BLIS_HH
#include "blis_wrappers.hh"
// =============================================================================
// Level 1 BLIS template implementations
/*#include "asum.hh"
#include "axpy.hh"
#include "copy.hh"
#include "dot.hh"
#include "iamax.hh"
#include "nrm2.hh"
#include "rot.hh"
#include "rotg.hh"
#include "rotm.hh"
#include "rotmg.hh"
#include "scal.hh"
#include "swap.hh"
*/
// =============================================================================
// Level 2 BLIS template implementations
/*
#include "gemv.hh"
#include "ger.hh"
#include "geru.hh"
#include "hemv.hh"
#include "her.hh"
#include "her2.hh"
#include "symv.hh"
#include "syr.hh"
#include "syr2.hh"
#include "trmv.hh"
#include "trsv.hh"
// =============================================================================
// Level 3 BLIS template implementations
*/
#include "gemm.hh"
/*#include "hemm.hh"
#include "herk.hh"
#include "her2k.hh"
#include "symm.hh"
#include "syrk.hh"
#include "syr2k.hh"
#include "trmm.hh"
#include "trsm.hh"
*/
// =============================================================================
// Device BLIS
/*#ifdef BLISPP_WITH_CUBLIS
#include "device_blis.hh"
#endif
*/
#endif // #ifndef BLIS_HH

View File

@@ -1,450 +0,0 @@
#ifndef BLIS_UTIL_HH
#define BLIS_UTIL_HH
#include <assert.h>
#include <cctype>
#include <complex>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <string>
namespace blis {
// -----------------------------------------------------------------------------
// Option enums. Each enumerator's underlying char value is an uppercase
// letter, so converting an enum value to char yields that letter.
enum class Layout : char {
    ColMajor = 'C',
    RowMajor = 'R'
};

enum class Op : char {
    NoTrans   = 'N',
    Trans     = 'T',
    ConjTrans = 'C'
};

enum class Uplo : char {
    Upper   = 'U',
    Lower   = 'L',
    General = 'G'
};

enum class Diag : char {
    NonUnit = 'N',
    Unit    = 'U'
};

enum class Side : char {
    Left  = 'L',
    Right = 'R'
};
// -----------------------------------------------------------------------------
// Convert an option enum to its LAPACK-style char, which is the enum's
// underlying char value.
inline char layout2char( Layout layout )
{
    return static_cast<char>( layout );
}

inline char op2char( Op op )
{
    return static_cast<char>( op );
}

inline char uplo2char( Uplo uplo )
{
    return static_cast<char>( uplo );
}

inline char diag2char( Diag diag )
{
    return static_cast<char>( diag );
}

inline char side2char( Side side )
{
    return static_cast<char>( side );
}
// -----------------------------------------------------------------------------
// Convert an option enum to its LAPACK-style string. A value that matches
// none of the enumerators yields the empty string.
inline const char* layout2str( Layout layout )
{
    const char* text = "";
    switch (layout) {
        case Layout::ColMajor: text = "col"; break;
        case Layout::RowMajor: text = "row"; break;
    }
    return text;
}

inline const char* op2str( Op op )
{
    const char* text = "";
    switch (op) {
        case Op::NoTrans:   text = "notrans"; break;
        case Op::Trans:     text = "trans";   break;
        case Op::ConjTrans: text = "conj";    break;
    }
    return text;
}

inline const char* uplo2str( Uplo uplo )
{
    const char* text = "";
    switch (uplo) {
        case Uplo::Lower:   text = "lower";   break;
        case Uplo::Upper:   text = "upper";   break;
        case Uplo::General: text = "general"; break;
    }
    return text;
}

inline const char* diag2str( Diag diag )
{
    const char* text = "";
    switch (diag) {
        case Diag::NonUnit: text = "nonunit"; break;
        case Diag::Unit:    text = "unit";    break;
    }
    return text;
}

inline const char* side2str( Side side )
{
    const char* text = "";
    switch (side) {
        case Side::Left:  text = "left";  break;
        case Side::Right: text = "right"; break;
    }
    return text;
}
// -----------------------------------------------------------------------------
// Convert a LAPACK-style char to the matching option enum. The char is
// upper-cased first, so lower-case letters are accepted; a letter that is
// not valid for the enum trips the assert (when asserts are enabled).
//
// The argument is converted to unsigned char before std::toupper because
// toupper is only defined for values representable as unsigned char (or
// EOF); passing a negative plain char is undefined behavior.
inline Layout char2layout( char layout )
{
    layout = static_cast<char>( std::toupper( static_cast<unsigned char>( layout ) ) );
    assert( layout == 'C' || layout == 'R' );
    return Layout( layout );
}

inline Op char2op( char op )
{
    op = static_cast<char>( std::toupper( static_cast<unsigned char>( op ) ) );
    assert( op == 'N' || op == 'T' || op == 'C' );
    return Op( op );
}

inline Uplo char2uplo( char uplo )
{
    uplo = static_cast<char>( std::toupper( static_cast<unsigned char>( uplo ) ) );
    assert( uplo == 'L' || uplo == 'U' || uplo == 'G' );
    return Uplo( uplo );
}

inline Diag char2diag( char diag )
{
    diag = static_cast<char>( std::toupper( static_cast<unsigned char>( diag ) ) );
    assert( diag == 'N' || diag == 'U' );
    return Diag( diag );
}

inline Side char2side( char side )
{
    side = static_cast<char>( std::toupper( static_cast<unsigned char>( side ) ) );
    assert( side == 'L' || side == 'R' );
    return Side( side );
}
// -----------------------------------------------------------------------------
/// Exception type for BLIS errors; what() returns the stored message.
class Error : public std::exception {
public:
    /// Creates an error with an empty message.
    Error() {}

    /// Creates an error carrying msg.
    Error( std::string const& msg )
        : msg_( msg )
    {}

    /// Creates an error whose message is "msg, in function func".
    Error( const char* msg, const char* func )
        : msg_( std::string(msg) + ", in function " + func )
    {}

    /// Returns the stored message as a C string.
    const char* what() const noexcept override
    {
        return msg_.c_str();
    }

private:
    std::string msg_;
};
// -----------------------------------------------------------------------------
// Scalar versions of real, imag and conj, so generic code can apply them to
// any datatype: real and conj return the argument, imag returns zero.
template< typename T >
inline T real( T x )
{
    return x;
}

template< typename T >
inline T imag( T /* x */ )
{
    return 0;
}

template< typename T >
inline T conj( T x )
{
    return x;
}
// -----------------------------------------------------------------------------
// 1-norm absolute value: std::abs(x) for the generic overload,
// |Re(x)| + |Im(x)| for the std::complex overload.
template< typename T >
T abs1( T x )
{
    return std::abs( x );
}

template< typename T >
T abs1( std::complex<T> x )
{
    const auto re_mag = std::abs( real(x) );
    const auto im_mag = std::abs( imag(x) );
    return re_mag + im_mag;
}
// -----------------------------------------------------------------------------
// The standard library only provides common_type_t and decay_t from C++14 on;
// when compiling as an older standard, define equivalent aliases here.
#if __cplusplus < 201402L
template< typename... Ts >
using common_type_t = typename std::common_type< Ts... >::type;

template< typename... Ts >
using decay_t = typename std::decay< Ts... >::type;
#else
using std::common_type_t;
using std::decay_t;
#endif
//------------------------------------------------------------------------------
/// Type trait: is_complex<T>::value is true when T is std::complex<T2>
/// for some type T2, and false for every other T.
template <typename T>
struct is_complex : std::false_type {};

// partial specialization that matches std::complex
template <typename T>
struct is_complex< std::complex<T> > : std::true_type {};
// -----------------------------------------------------------------------------
// Based on C++14 common_type implementation from
// http://www.cplusplus.com/reference/type_traits/common_type/
// Adds promotion of complex types based on the common type of the associated
// real types. This fixes various cases:
//
// std::common_type_t< double, complex<float> > is complex<float> (wrong)
// scalar_type< double, complex<float> > is complex<double> (right)
//
// std::common_type_t< int, complex<long> > is not defined (compile error)
// scalar_type< int, complex<long> > is complex<long> (right)
// Primary template: declared only, so an empty type list has no scalar type.
template< typename... Types >
struct scalar_type_traits;

// scalar_type< Ts... > is shorthand for scalar_type_traits< Ts... >::type
template< typename... Types >
using scalar_type = typename scalar_type_traits< Types... >::type;

// One type: the decayed type itself.
template< typename Only >
struct scalar_type_traits< Only >
{
    using type = decay_t< Only >;
};

// Two types: the decayed type of a ?: expression over both, i.e. the
// common type the conditional operator picks for its two operands.
template< typename First, typename Second >
struct scalar_type_traits< First, Second >
{
    using type = decay_t<
        decltype( true ? std::declval< First >() : std::declval< Second >() ) >;
};

// Two types where the first, the second, or both are std::complex:
// take the common type of the underlying real types and wrap it in
// std::complex.
template< typename Real1, typename Second >
struct scalar_type_traits< std::complex< Real1 >, Second >
{
    using type = std::complex< common_type_t< Real1, Second > >;
};

template< typename First, typename Real2 >
struct scalar_type_traits< First, std::complex< Real2 > >
{
    using type = std::complex< common_type_t< First, Real2 > >;
};

template< typename Real1, typename Real2 >
struct scalar_type_traits< std::complex< Real1 >, std::complex< Real2 > >
{
    using type = std::complex< common_type_t< Real1, Real2 > >;
};

// Three or more types: combine the first two, then fold in the rest.
template< typename First, typename Second, typename... Rest >
struct scalar_type_traits< First, Second, Rest... >
{
    using type = typename scalar_type_traits<
        typename scalar_type_traits< First, Second >::type,
        Rest... >::type;
};
// -----------------------------------------------------------------------------
// for any combination of types, determine associated real, scalar,
// and complex types.
//
// real_type< float > is float
// real_type< float, double, complex<float> > is double
//
// scalar_type< float > is float
// scalar_type< float, complex<float> > is complex<float>
// scalar_type< float, double, complex<float> > is complex<double>
//
// complex_type< float > is complex<float>
// complex_type< float, double > is complex<double>
// complex_type< float, double, complex<float> > is complex<double>
// Primary template: declared only, so an empty type list has no real type.
template< typename... Types >
struct real_type_traits;

// real_type< Ts... > is shorthand for real_type_traits< Ts... >::real_t
template< typename... Types >
using real_type = typename real_type_traits< Types... >::real_t;

// complex_type< Ts... > is std::complex over real_type< Ts... >
template< typename... Types >
using complex_type = std::complex< real_type< Types... > >;

// One type: it is its own real type.
template< typename Only >
struct real_type_traits< Only >
{
    using real_t = Only;
};

// One std::complex type: the real type is the type std::complex wraps.
template< typename RealT >
struct real_type_traits< std::complex< RealT > >
{
    using real_t = RealT;
};

// Two or more types: scalar_type of the first type's real type and the
// real type of the remaining types.
template< typename Head, typename... Tail >
struct real_type_traits< Head, Tail... >
{
    using real_t = scalar_type<
        typename real_type_traits< Head >::real_t,
        typename real_type_traits< Tail... >::real_t >;
};
// -----------------------------------------------------------------------------
// max over arguments of possibly different types, e.g.
// int64_t = max( int, int64_t ), and over any number of arguments:
// max( a, b, c, d ). The result type is scalar_type of the argument types.

// one argument: returned unchanged
template< typename T >
T max( T x )
{
    return x;
}

// two arguments: x when x >= y, otherwise y
template< typename T1, typename T2 >
scalar_type< T1, T2 >
max( T1 x, T2 y )
{
    if (x >= y) {
        return x;
    }
    return y;
}

// three or more arguments: reduce the tail first, then compare with first
template< typename T1, typename... Types >
scalar_type< T1, Types... >
max( T1 first, Types... args )
{
    auto tail_max = max( args... );
    return max( first, tail_max );
}
// -----------------------------------------------------------------------------
// min over arguments of possibly different types, e.g.
// int64_t = min( int, int64_t ), and over any number of arguments:
// min( a, b, c, d ). The result type is scalar_type of the argument types.

// one argument: returned unchanged
template< typename T >
T min( T x )
{
    return x;
}

// two arguments: x when x <= y, otherwise y
template< typename T1, typename T2 >
scalar_type< T1, T2 >
min( T1 x, T2 y )
{
    if (x <= y) {
        return x;
    }
    return y;
}

// three or more arguments: reduce the tail first, then compare with first
template< typename T1, typename... Types >
scalar_type< T1, Types... >
min( T1 first, Types... args )
{
    auto tail_min = min( args... );
    return min( first, tail_min );
}
namespace internal {
// -----------------------------------------------------------------------------
// Helper behind the blis_error_if macro: throws Error( condstr, func )
// when cond is true and does nothing otherwise.
inline void throw_if( bool cond, const char* condstr, const char* func )
{
    if (! cond) {
        return;
    }
    throw Error( condstr, func );
}
// -----------------------------------------------------------------------------
// Helper behind the blis_error_if_msg macro: when cond is true, formats the
// printf-style message (format, ...) and throws Error( message, func ).
// Does nothing when cond is false.
//
// condstr is ignored; it only distinguishes this overload from the
// three-argument throw_if. The message is formatted into an 80-byte buffer,
// so longer messages are truncated.
inline void throw_if( bool cond, const char* condstr, const char* func, const char* format, ... )
__attribute__((format( printf, 4, 5 )));
inline void throw_if( bool cond, const char* /* condstr */, const char* func, const char* format, ... )
{
    if (cond) {
        // zero-initialized so the message is an empty string if formatting fails
        char buf[80] = { 0 };
        va_list va;
        va_start( va, format );
        std::vsnprintf( buf, sizeof(buf), format, va );
        // every va_start must be paired with va_end before leaving the
        // function, including when leaving via throw
        va_end( va );
        throw Error( buf, func );
    }
}
// -----------------------------------------------------------------------------
// Helper behind the blis_error_if and blis_error_if_msg macros when
// BLIS_ERROR_ASSERT is defined: when cond is true, formats the printf-style
// message (format, ...), prints "Error: <message>, in function <func>" to
// stderr and calls abort(). Does nothing when cond is false.
//
// The message is formatted into an 80-byte buffer, so longer messages are
// truncated.
inline void abort_if( bool cond, const char* func, const char* format, ... )
__attribute__((format( printf, 3, 4 )));
inline void abort_if( bool cond, const char* func, const char* format, ... )
{
    if (cond) {
        // zero-initialized so the message is an empty string if formatting fails
        char buf[80] = { 0 };
        va_list va;
        va_start( va, format );
        std::vsnprintf( buf, sizeof(buf), format, va );
        // pair va_start with va_end before the argument list goes out of use
        va_end( va );
        std::fprintf( stderr, "Error: %s, in function %s\n", buf, func );
        std::abort();
    }
}
} // namespace internal
// -----------------------------------------------------------------------------
// Internal error-check macros. The policy is selected at compile time:
//   - BLIS_ERROR_NDEBUG, or BLIS_ERROR_ASSERT together with NDEBUG:
//     the macros expand to nothing
//   - BLIS_ERROR_ASSERT alone: abort via blis::internal::abort_if
//   - otherwise (default): throw via blis::internal::throw_if
#if defined(BLIS_ERROR_NDEBUG) || (defined(BLIS_ERROR_ASSERT) && defined(NDEBUG))

// blispp does no error checking;
// lower level BLIS may still handle errors via xerbla
#define blis_error_if( cond ) ((void)0)
#define blis_error_if_msg( cond, ... ) ((void)0)

#elif defined(BLIS_ERROR_ASSERT)

// blispp aborts on error; the message is the text of cond, or the given
// printf-style format and arguments
#define blis_error_if( cond ) blis::internal::abort_if( cond, __func__, "%s", #cond )
#define blis_error_if_msg( cond, ... ) blis::internal::abort_if( cond, __func__, __VA_ARGS__ )

#else

// blispp throws errors (default)

// Throws Error if cond is true; the message is the text of cond.
// ex: blis_error_if( a < b );
#define blis_error_if( cond ) blis::internal::throw_if( cond, #cond, __func__ )

// Throws Error if cond is true; the message is built from the given
// printf-style format and arguments.
// ex: blis_error_if_msg( a < b, "a %d < b %d", a, b );
#define blis_error_if_msg( cond, ... ) blis::internal::throw_if( cond, #cond, __func__, __VA_ARGS__ )

#endif
} // namespace blis
#endif // #ifndef BLIS_UTIL_HH

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -144,10 +144,10 @@ TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \
CINCFLAGS := -I$(INC_PATH)
# Use the CFLAGS for the configuration family.
CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME))
override CFLAGS += $(call get-user-cflags-for,$(CONFIG_NAME))
# Add local header paths to CFLAGS
CFLAGS += -I$(TEST_SRC_PATH)
override CFLAGS += -I$(TEST_SRC_PATH)
# Locate the libblis library to which we will link.
LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L)
@@ -165,22 +165,7 @@ LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L)
#all: blis openblas atlas mkl
all: blis openblas mkl
blis: test_dotv_blis.x \
test_axpyv_blis.x \
test_gemv_blis.x \
test_ger_blis.x \
test_hemv_blis.x \
test_her_blis.x \
test_her2_blis.x \
test_trmv_blis.x \
test_trsv_blis.x \
\
test_gemm_blis.x \
test_hemm_blis.x \
test_herk_blis.x \
test_her2k_blis.x \
test_trmm_blis.x \
test_trsm_blis.x
blis: test_gemm_blis.x
openblas: \
test_dotv_openblas.x \

View File

@@ -39,7 +39,7 @@
//#define FILE_IN_OUT // File based input matrix dimensions
//#define PRINT
#define PRINT
int main( int argc, char** argv )
{
@@ -73,7 +73,7 @@ int main( int argc, char** argv )
//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
n_repeats = 10;
n_repeats = 100;
#ifndef PRINT
p_begin = 200;
@@ -93,13 +93,23 @@ int main( int argc, char** argv )
n_input = 4;
#endif
#if 1
//dt = BLIS_FLOAT;
dt = BLIS_DOUBLE;
#if 0
// dt = BLIS_FLOAT;
// dt = BLIS_DOUBLE;
#else
//dt = BLIS_SCOMPLEX;
dt = BLIS_DCOMPLEX;
// dt = BLIS_SCOMPLEX;
// dt = BLIS_DCOMPLEX;
#endif
#ifdef FLOAT
dt = BLIS_FLOAT;
#elif defined DOUBLE
dt = BLIS_DOUBLE;
#elif defined SCOMPLEX
dt = BLIS_SCOMPLEX;
#elif defined DCOMPLEX
dt = BLIS_DCOMPLEX;
#endif
transa = BLIS_NO_TRANSPOSE;
transb = BLIS_NO_TRANSPOSE;
@@ -205,7 +215,7 @@ int main( int argc, char** argv )
bli_printm( "c", &c, "%4.1f", "" );
#endif
#if 0 //def BLIS
#if BLIS
bli_gemm( &alpha,
&a,

View File

@@ -131,7 +131,7 @@ MAC_LIB := -framework Accelerate
#
TEST_SRC_PATH := .
CPP_SRC_PATH := ../srccpp/
CPP_SRC_PATH := ../cpp/
TEST_OBJ_PATH := .
# Gather all local object files.
@@ -147,13 +147,13 @@ CINCFLAGS := -I$(INC_PATH)
CXX = g++
# Use the CFLAGS for the configuration family.
CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME))
override CFLAGS += $(call get-user-cflags-for,$(CONFIG_NAME))
# Add local header paths to CFLAGS
#CFLAGS = -O0 -g -Wall
#CFLAGS += -I$(INC_PATH)
CFLAGS += -I$(TEST_SRC_PATH)
CFLAGS += -I$(CPP_SRC_PATH)
override CFLAGS += -I$(TEST_SRC_PATH)
override CFLAGS += -I$(CPP_SRC_PATH)
LINKER = $(CXX)
@@ -173,7 +173,8 @@ LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L)
#all: blis openblas atlas mkl
all: blis openblas mkl
blis: test_gemm_blis.x
blis: test_gemm_blis.x \
test_gemm1_blis.x
openblas: \
test_dotv_openblas.x \
@@ -285,7 +286,7 @@ test_%_mac.o: test_%.c
$(CC) $(CFLAGS) -DBLAS=\"mac\" -c $< -o $@
test_%_blis.o: test_%.cc
$(CXX) $(CFLAGS) -DBLIS -c $< -o $@
$(CXX) $(CFLAGS) -c $< -o $@
# -- Executable file rules --

32
testcpp/test.sh Executable file
View File

@@ -0,0 +1,32 @@
#!/bin/sh
# Builds and runs the GEMM test once per datatype macro (FLOAT, DOUBLE,
# SCOMPLEX, DCOMPLEX): first test_gemm1_blis.x in the current directory,
# then test_gemm_blis.x in ../test/.

# run_all_types EXE
# Prints the working directory, then for each datatype macro cleans,
# rebuilds the "blis" target with -D<TYPE> appended to CFLAGS, and runs
# ./EXE under "numactl -C 1". The executable is run only when the build
# succeeded, so a failed build is not followed by a missing or stale binary.
run_all_types() {
    exe=$1
    CWD=$(pwd)
    echo "$CWD"
    for dtype in FLOAT DOUBLE SCOMPLEX DCOMPLEX; do
        make clean
        make blis CFLAGS+="-D${dtype}" && numactl -C 1 "./${exe}"
    done
}

run_all_types test_gemm1_blis.x

# Stop here if the directory is missing, rather than cleaning and
# rebuilding a second time in the current directory.
cd ../test/ || exit 1
run_all_types test_gemm_blis.x

View File

@@ -4,7 +4,7 @@
#include <iostream>
#include <string.h>
#include <unistd.h>
#include "gemm.hh"
#include "blis.hh"
using namespace std;
@@ -60,7 +60,7 @@ int main( int argc, char** argv )
print_matrix<float>(a_f , M , K);
cout<<"b_f= \n";
print_matrix<float>(b_f , K , N);
blis::gemm(blis::Layout::RowMajor, blis::Op::NoTrans, blis::Op::NoTrans, M, N, K, alpha_f, a_f,
blis::gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha_f, a_f,
lda, b_f, ldb, beta_f, c_f, ldc);
cout<<"c_f= \n";
print_matrix<float>(c_f , M , N);
@@ -71,7 +71,7 @@ int main( int argc, char** argv )
print_matrix<double>(a_d , M , K);
printf("b_d = \n");
print_matrix<double>(b_d , K , N);
blis::gemm(blis::Layout::RowMajor, blis::Op::NoTrans, blis::Op::NoTrans, M, N, K, alpha_d, a_d,
blis::gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha_d, a_d,
lda, b_d, ldb, beta_d, c_d, ldc);
printf("c_d = \n");
print_matrix<double>(c_d , M , N);
@@ -82,7 +82,7 @@ int main( int argc, char** argv )
print_matrix<std::complex<float>>(a_c , M , K);
printf("b_c = \n");
print_matrix<std::complex<float>>(b_c , K , N);
blis::gemm(blis::Layout::RowMajor, blis::Op::NoTrans, blis::Op::NoTrans, M, N, K, alpha_c, a_c,
blis::gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha_c, a_c,
lda, b_c, ldb, beta_c, c_c, ldc);
printf("c_c = \n");
print_matrix<std::complex<float>>(c_c , M , N);
@@ -93,7 +93,7 @@ int main( int argc, char** argv )
print_matrix<std::complex<double>>(a_z , M , K);
printf("b_z = \n");
print_matrix<std::complex<double>>(b_z , K , N);
blis::gemm(blis::Layout::RowMajor, blis::Op::NoTrans, blis::Op::NoTrans, M, N, K, alpha_z, a_z,
blis::gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha_z, a_z,
lda, b_z, ldb, beta_z, c_z, ldc);
printf("c_z = \n");
print_matrix<std::complex<double>>(c_z , M , N);

421
testcpp/test_gemm1.cc Normal file
View File

@@ -0,0 +1,421 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2014, The University of Texas at Austin
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name of The University of Texas nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <complex>
#include <stdio.h>
#include <iostream>
#include <string.h>
#include <unistd.h>
#include "blis.hh"
using namespace std;
//#define FILE_IN_OUT // File based input matrix dimensions
#define PRINT
int main( int argc, char** argv )
{
obj_t a, b, c;
obj_t c_save;
obj_t alpha, beta;
dim_t m, n, k;
dim_t p;
dim_t p_begin, p_end, p_inc;
int m_input, n_input, k_input;
num_t dt;
int r, n_repeats;
trans_t transa;
trans_t transb;
// f77_char f77_transa;
// f77_char f77_transb;
double dtime;
double dtime_save;
double gflops;
#ifdef FILE_IN_OUT
FILE* fin = NULL;
FILE* fout = NULL;
char gemm = 's';
#endif
//bli_init();
//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
n_repeats = 100;
#ifndef PRINT
p_begin = 200;
p_end = 2000;
p_inc = 200;
m_input = -1;
n_input = -1;
k_input = -1;
#else
p_begin = 16;
p_end = 16;
p_inc = 1;
m_input = 5;
k_input = 6;
n_input = 4;
#endif
#ifdef FLOAT
dt = BLIS_FLOAT;
#elif defined DOUBLE
dt = BLIS_DOUBLE;
#elif defined SCOMPLEX
dt = BLIS_SCOMPLEX;
#elif defined DCOMPLEX
dt = BLIS_DCOMPLEX;
#endif
transa = BLIS_NO_TRANSPOSE;
transb = BLIS_NO_TRANSPOSE;
// bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
// bli_param_map_blis_to_netlib_trans( transb, &f77_transb );
#ifdef FILE_IN_OUT
if (argc < 3)
{
printf("Usage: ./test_gemm_XX.x input.csv output.csv\n");
exit(1);
}
fin = fopen(argv[1], "r");
if (fin == NULL)
{
printf("Error opening the file %s\n", argv[1]);
exit(1);
}
fout = fopen(argv[2], "w");
if (fout == NULL)
{
printf("Error opening output file %s\n", argv[2]);
exit(1);
}
fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\t GEMM_Algo\n");
printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\t GEMM_Algo\n");
inc_t cs_a;
inc_t cs_b;
inc_t cs_c;
while (fscanf(fin, "%ld %ld %ld %ld %ld %ld\n", &m, &k, &n, &cs_a, &cs_b, &cs_c) == 6)
{
if ((m > cs_a) || (k > cs_b) || (m > cs_c)) continue; // leading dimension should be greater than number of rows
bli_obj_create( dt, 1, 1, 0, 0, &alpha);
bli_obj_create( dt, 1, 1, 0, 0, &beta );
bli_obj_create( dt, m, k, 1, cs_a, &a );
bli_obj_create( dt, k, n, 1, cs_b, &b );
bli_obj_create( dt, m, n, 1, cs_c, &c );
bli_obj_create( dt, m, n, 1, cs_c, &c_save );
bli_obj_set_conjtrans( transa, &a);
bli_obj_set_conjtrans( transb, &b);
//bli_setsc( 0.0, -1, &alpha );
//bli_setsc( 0.0, 1, &beta );
bli_setsc( -1, 0.0, &alpha );
bli_setsc( 1, 0.0, &beta );
// printf("%1.1f %1.1f\n", *((double *)bli_obj_buffer_for_const(BLIS_FLOAT, &alpha)), *((double *)bli_obj_buffer_for_const(BLIS_FLOAT, &beta)));
#else
for ( p = p_begin; p <= p_end; p += p_inc )
{
if ( m_input < 0 ) m = p * ( dim_t )abs(m_input);
else m = ( dim_t ) m_input;
if ( n_input < 0 ) n = p * ( dim_t )abs(n_input);
else n = ( dim_t ) n_input;
if ( k_input < 0 ) k = p * ( dim_t )abs(k_input);
else k = ( dim_t ) k_input;
bli_obj_create( dt, 1, 1, 0, 0, &alpha );
bli_obj_create( dt, 1, 1, 0, 0, &beta );
bli_obj_create( dt, m, k, 0, 0, &a );
bli_obj_create( dt, k, n, 0, 0, &b );
bli_obj_create( dt, m, n, 0, 0, &c );
bli_obj_create( dt, m, n, 0, 0, &c_save );
bli_randm( &a );
bli_randm( &b );
bli_randm( &c );
bli_obj_set_conjtrans( transa, &a );
bli_obj_set_conjtrans( transb, &b );
bli_setsc( (0.9/1.0), 0.2, &alpha );
bli_setsc( -(1.1/1.0), 0.3, &beta );
#endif
bli_copym( &c, &c_save );
dtime_save = DBL_MAX;
for ( r = 0; r < n_repeats; ++r )
{
bli_copym( &c_save, &c );
dtime = bli_clock();
#ifdef PRINT
bli_printm( "a", &a, "%4.1f", "" );
bli_printm( "b", &b, "%4.1f", "" );
bli_printm( "c", &c, "%4.1f", "" );
#endif
#if 0
bli_gemm( &alpha,
&a,
&b,
&beta,
&c );
#else
if ( bli_is_float( dt ) )
{
int M = bli_obj_length( &c );
int K = bli_obj_width_after_trans( &a );
int N = bli_obj_width( &c );
int lda = bli_obj_col_stride( &a );
int ldb = bli_obj_col_stride( &b );
int ldc = bli_obj_col_stride( &c );
float* alphap = (float *)bli_obj_buffer( &alpha );
float* ap = (float *)bli_obj_buffer( &a );
float* bp = (float *)bli_obj_buffer( &b );
float* betap = (float *)bli_obj_buffer( &beta );
float* cp = (float *)bli_obj_buffer( &c );
blis::gemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, *alphap, ap,
lda, bp, ldb, *betap, cp, ldc);
#if 0
sgemm_( &f77_transa,
&f77_transb,
&mm,
&nn,
&kk,
alphap,
ap, &lda,
bp, &ldb,
betap,
cp, &ldc );
#endif
}
else if ( bli_is_double( dt ) )
{
int M = bli_obj_length( &c );
int K = bli_obj_width_after_trans( &a );
int N = bli_obj_width( &c );
int lda = bli_obj_col_stride( &a );
int ldb = bli_obj_col_stride( &b );
int ldc = bli_obj_col_stride( &c );
double* alphap = (double*)bli_obj_buffer( &alpha );
double* ap = (double*)bli_obj_buffer( &a );
double* bp = (double*)bli_obj_buffer( &b );
double* betap = (double*)bli_obj_buffer( &beta );
double* cp = (double*)bli_obj_buffer( &c );
blis::gemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, *alphap, ap,
lda, bp, ldb, *betap, cp, ldc);
#if 0
dgemm_( &f77_transa,
&f77_transb,
&mm,
&nn,
&kk,
alphap,
ap, &lda,
bp, &ldb,
betap,
cp, &ldc );
#endif
}
else if ( bli_is_scomplex( dt ) )
{
int M = bli_obj_length( &c );
int K = bli_obj_width_after_trans( &a );
int N = bli_obj_width( &c );
int lda = bli_obj_col_stride( &a );
int ldb = bli_obj_col_stride( &b );
int ldc = bli_obj_col_stride( &c );
std::complex<float>* alphap = (std::complex<float>*)bli_obj_buffer( &alpha );
std::complex<float>* ap = (std::complex<float>*)bli_obj_buffer( &a );
std::complex<float>* bp = (std::complex<float>*)bli_obj_buffer( &b );
std::complex<float>* betap = (std::complex<float>*)bli_obj_buffer( &beta );
std::complex<float>* cp = (std::complex<float>*)bli_obj_buffer( &c );
blis::gemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, *alphap, ap,
lda, bp, ldb, *betap, cp, ldc);
#if 0
cgemm_( &f77_transa,
&f77_transb,
&mm,
&nn,
&kk,
alphap,
ap, &lda,
bp, &ldb,
betap,
cp, &ldc );
#endif
}
else if ( bli_is_dcomplex( dt ) )
{
f77_int M = bli_obj_length( &c );
f77_int K = bli_obj_width_after_trans( &a );
f77_int N = bli_obj_width( &c );
f77_int lda = bli_obj_col_stride( &a );
f77_int ldb = bli_obj_col_stride( &b );
f77_int ldc = bli_obj_col_stride( &c );
std::complex<double>* alphap = (std::complex<double>*)bli_obj_buffer( &alpha );
std::complex<double>* ap = (std::complex<double>*)bli_obj_buffer( &a );
std::complex<double>* bp = (std::complex<double>*)bli_obj_buffer( &b );
std::complex<double>* betap = (std::complex<double>*)bli_obj_buffer( &beta );
std::complex<double>* cp = (std::complex<double>*)bli_obj_buffer( &c );
blis::gemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K, *alphap, ap,
lda, bp, ldb, *betap, cp, ldc);
#if 0
zgemm_( &f77_transa,
&f77_transb,
&mm,
&nn,
&kk,
alphap,
ap, &lda,
bp, &ldb,
betap,
cp, &ldc );
#endif
}
#endif
#ifdef PRINT
bli_printm( "c after", &c, "%4.1f", "" );
exit(1);
#endif
dtime_save = bli_clock_min_diff( dtime_save, dtime );
}
gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 );
if ( bli_is_complex( dt ) ) gflops *= 4.0;
#ifdef BLIS
printf( "data_gemm_blis" );
#else
//printf( "data_gemm_%s", BLAS );
#endif
#ifdef FILE_IN_OUT
if ( bli_is_double( dt ) ) {
if (((m * n) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES/4)) || ((m < (BLIS_SMALL_M_RECT_MATRIX_THRES/2) ) && (k < (BLIS_SMALL_K_RECT_MATRIX_THRES/2) )))
gemm = 'S'; // small gemm
else gemm = 'N'; // Normal blis gemm
}
else if (bli_is_float( dt )) {
if (((m * n) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) || ((m < BLIS_SMALL_M_RECT_MATRIX_THRES) && (k < BLIS_SMALL_K_RECT_MATRIX_THRES)))
gemm = 'S'; // small gemm
else gemm = 'N'; // normal blis gemm
}
printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f \t %c\n", \
( unsigned long )m,
( unsigned long )k,
( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops, gemm );
fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f \t %c\n", \
( unsigned long )m,
( unsigned long )k,
( unsigned long )n, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops, gemm );
fflush(fout);
#else
printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n",
( unsigned long )(p - p_begin + 1)/p_inc + 1,
( unsigned long )m,
( unsigned long )k,
( unsigned long )n, gflops );
#endif
bli_obj_free( &alpha );
bli_obj_free( &beta );
bli_obj_free( &a );
bli_obj_free( &b );
bli_obj_free( &c );
bli_obj_free( &c_save );
}
//bli_finalize();
#ifdef FILE_IN_OUT
fclose(fin);
fclose(fout);
#endif
return 0;
}