mirror of
https://github.com/amd/blis.git
synced 2026-05-13 18:52:14 +00:00
Added Doxygen Comment to all functions; Fixed Review comments; Modified test application to use template functions
Change-Id: I920c335776bc4597af1c988b538e8dda706195fa
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
/*
|
||||
|
||||
BLISPP
|
||||
C++ test driver for BLIS CPP gemm routine and reference cblas gemm routine.
|
||||
C++ test driver for BLIS CPP gemm routine and reference blis gemm routine.
|
||||
|
||||
Copyright (C) 2019, Advanced Micro Devices, Inc.
|
||||
|
||||
@@ -39,55 +39,77 @@
|
||||
using namespace blis;
|
||||
using namespace std;
|
||||
#define PRINT
|
||||
/*
|
||||
* Test application assumes matrices to be column major, non-transposed
|
||||
*/
|
||||
void ref_gemm(num_t dt , int64_t m, int64_t n, int64_t k,
|
||||
void * alpha,
|
||||
void *A,
|
||||
void *B,
|
||||
void * beta,
|
||||
void *C )
|
||||
|
||||
void test_dgemm( )
|
||||
{
|
||||
obj_t obj_a, obj_b, obj_c;
|
||||
obj_t obj_alpha, obj_beta;
|
||||
bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha );
|
||||
bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta );
|
||||
bli_obj_create_with_attached_buffer( dt, m, k, A, 1,m,&obj_a );
|
||||
bli_obj_create_with_attached_buffer( dt, k, n, B,1,k,&obj_b );
|
||||
bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c );
|
||||
|
||||
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a );
|
||||
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_b );
|
||||
bli_gemm( &obj_alpha,
|
||||
&obj_a,
|
||||
&obj_b,
|
||||
&obj_beta,
|
||||
&obj_c );
|
||||
|
||||
}
|
||||
template< typename T >
|
||||
void test_gemm( )
|
||||
{
|
||||
int i, j, p;
|
||||
double *A, *B, *C, *C_ref;
|
||||
double alpha, beta;
|
||||
double flops;
|
||||
double ref_beg, ref_time, bl_dgemm_beg, bl_dgemm_time;
|
||||
int nrepeats;
|
||||
T *A, *B, *C, *C_ref;
|
||||
T alpha, beta;
|
||||
int m,n,k;
|
||||
int lda, ldb, ldc, ldc_ref;
|
||||
double ref_rectime, bl_dgemm_rectime;
|
||||
|
||||
alpha = 1.0;
|
||||
beta = 0.0;
|
||||
m = 5;
|
||||
k = 6;
|
||||
n = 4;
|
||||
k = 4;
|
||||
n = 6;
|
||||
|
||||
A = new double[m * k];
|
||||
B = new double[k * n];
|
||||
A = new T[m * k];
|
||||
B = new T[k * n];
|
||||
|
||||
lda = m;
|
||||
ldb = k;
|
||||
ldc = m;
|
||||
ldc_ref = m;
|
||||
C = new double[ldc * n];
|
||||
C_ref= new double[m * n];
|
||||
|
||||
nrepeats = 3;
|
||||
C = new T[ldc * n];
|
||||
C_ref= new T[m * n];
|
||||
|
||||
srand48 (time(NULL));
|
||||
|
||||
// Randonly generate points in [ 0, 1 ].
|
||||
for ( p = 0; p < k; p ++ ) {
|
||||
for ( i = 0; i < m; i ++ ) {
|
||||
A( i, p ) = (double)( drand48() );
|
||||
A( i, p ) = (T)( drand48() );
|
||||
}
|
||||
}
|
||||
for ( j = 0; j < n; j ++ ) {
|
||||
for ( p = 0; p < k; p ++ ) {
|
||||
B( p, j ) = (double)( drand48() );
|
||||
B( p, j ) = (T)( drand48() );
|
||||
}
|
||||
}
|
||||
|
||||
for ( j = 0; j < n; j ++ ) {
|
||||
for ( i = 0; i < m; i ++ ) {
|
||||
C_ref( i, j ) = (double)( 0.0 );
|
||||
C( i, j ) = (double)( 0.0 );
|
||||
C_ref( i, j ) = (T)( 0.0 );
|
||||
C( i, j ) = (T)( 0.0 );
|
||||
}
|
||||
}
|
||||
#ifdef PRINT
|
||||
@@ -95,8 +117,6 @@ void test_dgemm( )
|
||||
bl_dgemm_printmatrix(B, ldb ,k,n);
|
||||
bl_dgemm_printmatrix(C, ldc ,m,n);
|
||||
#endif
|
||||
for ( i = 0; i < nrepeats; i ++ ) {
|
||||
bl_dgemm_beg = bl_clock();
|
||||
blis::gemm(
|
||||
CblasColMajor,
|
||||
CblasNoTrans,
|
||||
@@ -113,43 +133,18 @@ void test_dgemm( )
|
||||
C,
|
||||
ldc
|
||||
);
|
||||
bl_dgemm_time = bl_clock() - bl_dgemm_beg;
|
||||
|
||||
if ( i == 0 ) {
|
||||
bl_dgemm_rectime = bl_dgemm_time;
|
||||
} else {
|
||||
bl_dgemm_rectime = bl_dgemm_time < bl_dgemm_rectime ? bl_dgemm_time : bl_dgemm_rectime;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef PRINT
|
||||
bl_dgemm_printmatrix(C, ldc ,m,n);
|
||||
#endif
|
||||
for ( i = 0; i < nrepeats; i ++ ) {
|
||||
ref_beg = bl_clock();
|
||||
cblas_dgemm(
|
||||
CblasColMajor,
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
A,
|
||||
lda,
|
||||
B,
|
||||
ldb,
|
||||
beta,
|
||||
C_ref,
|
||||
ldc_ref);
|
||||
ref_time = bl_clock() - ref_beg;
|
||||
|
||||
if ( i == 0 ) {
|
||||
ref_rectime = ref_time;
|
||||
} else {
|
||||
ref_rectime = ref_time < ref_rectime ? ref_time : ref_rectime;
|
||||
}
|
||||
}
|
||||
if(is_same<T, float>::value)
|
||||
ref_gemm(BLIS_FLOAT , m, n, k, &alpha, A, B, &beta, C_ref);
|
||||
else if(is_same<T, double>::value)
|
||||
ref_gemm(BLIS_DOUBLE , m, n, k, &alpha, A, B, &beta, C_ref);
|
||||
else if(is_same<T, complex<float>>::value)
|
||||
ref_gemm(BLIS_SCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref);
|
||||
else if(is_same<T, complex<double>>::value)
|
||||
ref_gemm(BLIS_DCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref);
|
||||
|
||||
#ifdef PRINT
|
||||
bl_dgemm_printmatrix(C_ref, ldc_ref ,m,n);
|
||||
@@ -160,155 +155,20 @@ void test_dgemm( )
|
||||
printf("%s TEST PASS\n" , __func__);
|
||||
|
||||
|
||||
// Compute overall floating point operations.
|
||||
flops = ( m * n / ( 1000.0 * 1000.0 * 1000.0 ) ) * ( 2 * k );
|
||||
|
||||
printf( "%5d\t %5d\t %5d\t %5.2lf\t %5.2lf\n",
|
||||
m, n, k, flops / bl_dgemm_rectime, flops / ref_rectime );
|
||||
|
||||
free( A );
|
||||
free( B );
|
||||
free( C );
|
||||
free( C_ref );
|
||||
}
|
||||
void test_zgemm( )
|
||||
{
|
||||
int i, j, p;
|
||||
std::complex<double> *A, *B, *C, *C_ref;
|
||||
std::complex<double> alpha, beta;
|
||||
double flops;
|
||||
double ref_beg, ref_time, bl_dgemm_beg, bl_dgemm_time;
|
||||
int nrepeats;
|
||||
int m,n,k;
|
||||
int lda, ldb, ldc, ldc_ref;
|
||||
double ref_rectime, bl_dgemm_rectime;
|
||||
|
||||
alpha = 1.0;
|
||||
beta = 0.0;
|
||||
m = 5;
|
||||
k = 6;
|
||||
n = 4;
|
||||
|
||||
A = new complex<double>[m * k];
|
||||
B = new complex<double>[k * n];
|
||||
|
||||
lda = m;
|
||||
ldb = k;
|
||||
ldc = m;
|
||||
ldc_ref = m;
|
||||
C = new complex<double>[ldc * n];
|
||||
C_ref= new complex<double>[m * n];
|
||||
nrepeats = 3;
|
||||
|
||||
srand48 (time(NULL));
|
||||
|
||||
// Randonly generate points in [ 0, 1 ].
|
||||
for ( p = 0; p < k; p ++ ) {
|
||||
for ( i = 0; i < m; i ++ ) {
|
||||
A( i, p ) = (complex<double>)( drand48() );
|
||||
}
|
||||
}
|
||||
for ( j = 0; j < n; j ++ ) {
|
||||
for ( p = 0; p < k; p ++ ) {
|
||||
B( p, j ) = (complex<double>)( drand48() );
|
||||
}
|
||||
}
|
||||
|
||||
for ( j = 0; j < n; j ++ ) {
|
||||
for ( i = 0; i < m; i ++ ) {
|
||||
C_ref( i, j ) = (complex<double>)( 0.0 );
|
||||
C( i, j ) = (complex<double>)( 0.0 );
|
||||
}
|
||||
}
|
||||
#ifdef PRINT
|
||||
bl_dgemm_printmatrix(A, lda ,m,k);
|
||||
bl_dgemm_printmatrix(B, ldb ,k,n);
|
||||
bl_dgemm_printmatrix(C, ldc ,m,n);
|
||||
#endif
|
||||
for ( i = 0; i < nrepeats; i ++ ) {
|
||||
bl_dgemm_beg = bl_clock();
|
||||
blis::gemm(
|
||||
CblasColMajor,
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
A,
|
||||
lda,
|
||||
B,
|
||||
ldb,
|
||||
beta,
|
||||
C,
|
||||
ldc
|
||||
);
|
||||
|
||||
bl_dgemm_time = bl_clock() - bl_dgemm_beg;
|
||||
|
||||
|
||||
if ( i == 0 ) {
|
||||
bl_dgemm_rectime = bl_dgemm_time;
|
||||
} else {
|
||||
bl_dgemm_rectime = bl_dgemm_time < bl_dgemm_rectime ? bl_dgemm_time : bl_dgemm_rectime;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef PRINT
|
||||
bl_dgemm_printmatrix(C, ldc ,m,n);
|
||||
#endif
|
||||
for ( i = 0; i < nrepeats; i ++ ) {
|
||||
ref_beg = bl_clock();
|
||||
cblas_zgemm(
|
||||
CblasColMajor,
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
&alpha,
|
||||
A,
|
||||
lda,
|
||||
B,
|
||||
ldb,
|
||||
&beta,
|
||||
C_ref,
|
||||
ldc_ref);
|
||||
ref_time = bl_clock() - ref_beg;
|
||||
|
||||
if ( i == 0 ) {
|
||||
ref_rectime = ref_time;
|
||||
} else {
|
||||
ref_rectime = ref_time < ref_rectime ? ref_time : ref_rectime;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef PRINT
|
||||
bl_dgemm_printmatrix(C_ref, ldc_ref ,m,n);
|
||||
#endif
|
||||
if(computeError(ldc, ldc_ref, m, n, C, C_ref )==1)
|
||||
printf("%s TEST FAIL\n" ,__func__);
|
||||
else
|
||||
printf("%s TEST PASS\n" , __func__);
|
||||
|
||||
|
||||
// Compute overall floating point operations.
|
||||
flops = ( m * n / ( 1000.0 * 1000.0 * 1000.0 ) ) * ( 2 * k );
|
||||
|
||||
printf( "%5d\t %5d\t %5d\t %5.2lf\t %5.2lf\n",
|
||||
m, n, k, flops / bl_dgemm_rectime, flops / ref_rectime );
|
||||
|
||||
free( A );
|
||||
free( B );
|
||||
free( C );
|
||||
free( C_ref );
|
||||
delete[]( A );
|
||||
delete[]( B );
|
||||
delete[]( C );
|
||||
delete[]( C_ref );
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
int main( int argc, char** argv )
|
||||
{
|
||||
test_dgemm( );
|
||||
test_zgemm( );
|
||||
test_gemm<double>( );
|
||||
test_gemm<float>( );
|
||||
test_gemm<complex<float>>( );
|
||||
test_gemm<complex<double>>( );
|
||||
return 0;
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user