Added Doxygen Comment to all functions; Fixed Review comments; Modified test application to use template functions

Change-Id: I920c335776bc4597af1c988b538e8dda706195fa
This commit is contained in:
Chithra Sankar
2019-09-05 12:17:45 +05:30
parent c195d9a576
commit ce0b1caa7f
5 changed files with 2495 additions and 443 deletions

View File

@@ -1,7 +1,7 @@
/*
BLISPP
C++ test driver for BLIS CPP gemm routine and reference cblas gemm routine.
C++ test driver for BLIS CPP gemm routine and reference blis gemm routine.
Copyright (C) 2019, Advanced Micro Devices, Inc.
@@ -39,55 +39,77 @@
using namespace blis;
using namespace std;
#define PRINT
/*
* Test application assumes matrices to be column major, non-transposed
*/
void ref_gemm(num_t dt , int64_t m, int64_t n, int64_t k,
void * alpha,
void *A,
void *B,
void * beta,
void *C )
void test_dgemm( )
{
obj_t obj_a, obj_b, obj_c;
obj_t obj_alpha, obj_beta;
bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha );
bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta );
bli_obj_create_with_attached_buffer( dt, m, k, A, 1,m,&obj_a );
bli_obj_create_with_attached_buffer( dt, k, n, B,1,k,&obj_b );
bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c );
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a );
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_b );
bli_gemm( &obj_alpha,
&obj_a,
&obj_b,
&obj_beta,
&obj_c );
}
template< typename T >
void test_gemm( )
{
int i, j, p;
double *A, *B, *C, *C_ref;
double alpha, beta;
double flops;
double ref_beg, ref_time, bl_dgemm_beg, bl_dgemm_time;
int nrepeats;
T *A, *B, *C, *C_ref;
T alpha, beta;
int m,n,k;
int lda, ldb, ldc, ldc_ref;
double ref_rectime, bl_dgemm_rectime;
alpha = 1.0;
beta = 0.0;
m = 5;
k = 6;
n = 4;
k = 4;
n = 6;
A = new double[m * k];
B = new double[k * n];
A = new T[m * k];
B = new T[k * n];
lda = m;
ldb = k;
ldc = m;
ldc_ref = m;
C = new double[ldc * n];
C_ref= new double[m * n];
nrepeats = 3;
C = new T[ldc * n];
C_ref= new T[m * n];
srand48 (time(NULL));
// Randonly generate points in [ 0, 1 ].
for ( p = 0; p < k; p ++ ) {
for ( i = 0; i < m; i ++ ) {
A( i, p ) = (double)( drand48() );
A( i, p ) = (T)( drand48() );
}
}
for ( j = 0; j < n; j ++ ) {
for ( p = 0; p < k; p ++ ) {
B( p, j ) = (double)( drand48() );
B( p, j ) = (T)( drand48() );
}
}
for ( j = 0; j < n; j ++ ) {
for ( i = 0; i < m; i ++ ) {
C_ref( i, j ) = (double)( 0.0 );
C( i, j ) = (double)( 0.0 );
C_ref( i, j ) = (T)( 0.0 );
C( i, j ) = (T)( 0.0 );
}
}
#ifdef PRINT
@@ -95,8 +117,6 @@ void test_dgemm( )
bl_dgemm_printmatrix(B, ldb ,k,n);
bl_dgemm_printmatrix(C, ldc ,m,n);
#endif
for ( i = 0; i < nrepeats; i ++ ) {
bl_dgemm_beg = bl_clock();
blis::gemm(
CblasColMajor,
CblasNoTrans,
@@ -113,43 +133,18 @@ void test_dgemm( )
C,
ldc
);
bl_dgemm_time = bl_clock() - bl_dgemm_beg;
if ( i == 0 ) {
bl_dgemm_rectime = bl_dgemm_time;
} else {
bl_dgemm_rectime = bl_dgemm_time < bl_dgemm_rectime ? bl_dgemm_time : bl_dgemm_rectime;
}
}
#ifdef PRINT
bl_dgemm_printmatrix(C, ldc ,m,n);
#endif
for ( i = 0; i < nrepeats; i ++ ) {
ref_beg = bl_clock();
cblas_dgemm(
CblasColMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C_ref,
ldc_ref);
ref_time = bl_clock() - ref_beg;
if ( i == 0 ) {
ref_rectime = ref_time;
} else {
ref_rectime = ref_time < ref_rectime ? ref_time : ref_rectime;
}
}
if(is_same<T, float>::value)
ref_gemm(BLIS_FLOAT , m, n, k, &alpha, A, B, &beta, C_ref);
else if(is_same<T, double>::value)
ref_gemm(BLIS_DOUBLE , m, n, k, &alpha, A, B, &beta, C_ref);
else if(is_same<T, complex<float>>::value)
ref_gemm(BLIS_SCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref);
else if(is_same<T, complex<double>>::value)
ref_gemm(BLIS_DCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref);
#ifdef PRINT
bl_dgemm_printmatrix(C_ref, ldc_ref ,m,n);
@@ -160,155 +155,20 @@ void test_dgemm( )
printf("%s TEST PASS\n" , __func__);
// Compute overall floating point operations.
flops = ( m * n / ( 1000.0 * 1000.0 * 1000.0 ) ) * ( 2 * k );
printf( "%5d\t %5d\t %5d\t %5.2lf\t %5.2lf\n",
m, n, k, flops / bl_dgemm_rectime, flops / ref_rectime );
free( A );
free( B );
free( C );
free( C_ref );
}
void test_zgemm( )
{
int i, j, p;
std::complex<double> *A, *B, *C, *C_ref;
std::complex<double> alpha, beta;
double flops;
double ref_beg, ref_time, bl_dgemm_beg, bl_dgemm_time;
int nrepeats;
int m,n,k;
int lda, ldb, ldc, ldc_ref;
double ref_rectime, bl_dgemm_rectime;
alpha = 1.0;
beta = 0.0;
m = 5;
k = 6;
n = 4;
A = new complex<double>[m * k];
B = new complex<double>[k * n];
lda = m;
ldb = k;
ldc = m;
ldc_ref = m;
C = new complex<double>[ldc * n];
C_ref= new complex<double>[m * n];
nrepeats = 3;
srand48 (time(NULL));
// Randonly generate points in [ 0, 1 ].
for ( p = 0; p < k; p ++ ) {
for ( i = 0; i < m; i ++ ) {
A( i, p ) = (complex<double>)( drand48() );
}
}
for ( j = 0; j < n; j ++ ) {
for ( p = 0; p < k; p ++ ) {
B( p, j ) = (complex<double>)( drand48() );
}
}
for ( j = 0; j < n; j ++ ) {
for ( i = 0; i < m; i ++ ) {
C_ref( i, j ) = (complex<double>)( 0.0 );
C( i, j ) = (complex<double>)( 0.0 );
}
}
#ifdef PRINT
bl_dgemm_printmatrix(A, lda ,m,k);
bl_dgemm_printmatrix(B, ldb ,k,n);
bl_dgemm_printmatrix(C, ldc ,m,n);
#endif
for ( i = 0; i < nrepeats; i ++ ) {
bl_dgemm_beg = bl_clock();
blis::gemm(
CblasColMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc
);
bl_dgemm_time = bl_clock() - bl_dgemm_beg;
if ( i == 0 ) {
bl_dgemm_rectime = bl_dgemm_time;
} else {
bl_dgemm_rectime = bl_dgemm_time < bl_dgemm_rectime ? bl_dgemm_time : bl_dgemm_rectime;
}
}
#ifdef PRINT
bl_dgemm_printmatrix(C, ldc ,m,n);
#endif
for ( i = 0; i < nrepeats; i ++ ) {
ref_beg = bl_clock();
cblas_zgemm(
CblasColMajor,
CblasNoTrans,
CblasNoTrans,
m,
n,
k,
&alpha,
A,
lda,
B,
ldb,
&beta,
C_ref,
ldc_ref);
ref_time = bl_clock() - ref_beg;
if ( i == 0 ) {
ref_rectime = ref_time;
} else {
ref_rectime = ref_time < ref_rectime ? ref_time : ref_rectime;
}
}
#ifdef PRINT
bl_dgemm_printmatrix(C_ref, ldc_ref ,m,n);
#endif
if(computeError(ldc, ldc_ref, m, n, C, C_ref )==1)
printf("%s TEST FAIL\n" ,__func__);
else
printf("%s TEST PASS\n" , __func__);
// Compute overall floating point operations.
flops = ( m * n / ( 1000.0 * 1000.0 * 1000.0 ) ) * ( 2 * k );
printf( "%5d\t %5d\t %5d\t %5.2lf\t %5.2lf\n",
m, n, k, flops / bl_dgemm_rectime, flops / ref_rectime );
free( A );
free( B );
free( C );
free( C_ref );
delete[]( A );
delete[]( B );
delete[]( C );
delete[]( C_ref );
}
// -----------------------------------------------------------------------------
int main( int argc, char** argv )
{
test_dgemm( );
test_zgemm( );
test_gemm<double>( );
test_gemm<float>( );
test_gemm<complex<float>>( );
test_gemm<complex<double>>( );
return 0;
}