CPP Implementtaion of dsdot included. Test application refactored to include review comments

Change-Id: Iec0b973c23a2825e61f2ec9da236b3aea327d98a
This commit is contained in:
Chithra Sankar
2019-09-20 11:52:55 +05:30
parent ce0b1caa7f
commit be25ec0065
3 changed files with 42 additions and 40 deletions

View File

@@ -39,19 +39,35 @@
using namespace blis;
using namespace std;
#define PRINT
#define ALPHA 1.0
#define BETA 0.0
#define M 5
#define N 6
#define K 4
/*
* Test application assumes matrices to be column major, non-transposed
*/
void ref_gemm(num_t dt , int64_t m, int64_t n, int64_t k,
void * alpha,
void *A,
void *B,
void * beta,
void *C )
template< typename T >
void ref_gemm(int64_t m, int64_t n, int64_t k,
T * alpha,
T *A,
T *B,
T * beta,
T *C )
{
obj_t obj_a, obj_b, obj_c;
obj_t obj_alpha, obj_beta;
num_t dt;
if(is_same<T, float>::value)
dt = BLIS_FLOAT;
else if(is_same<T, double>::value)
dt = BLIS_DOUBLE;
else if(is_same<T, complex<float>>::value)
dt = BLIS_SCOMPLEX;
else if(is_same<T, complex<double>>::value)
dt = BLIS_DCOMPLEX;
bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha );
bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta );
bli_obj_create_with_attached_buffer( dt, m, k, A, 1,m,&obj_a );
@@ -76,11 +92,11 @@ void test_gemm( )
int m,n,k;
int lda, ldb, ldc, ldc_ref;
alpha = 1.0;
beta = 0.0;
m = 5;
k = 4;
n = 6;
alpha = ALPHA;
beta = BETA;
m = M;
k = K;
n = N;
A = new T[m * k];
B = new T[k * n];
@@ -137,22 +153,15 @@ void test_gemm( )
#ifdef PRINT
bl_dgemm_printmatrix(C, ldc ,m,n);
#endif
if(is_same<T, float>::value)
ref_gemm(BLIS_FLOAT , m, n, k, &alpha, A, B, &beta, C_ref);
else if(is_same<T, double>::value)
ref_gemm(BLIS_DOUBLE , m, n, k, &alpha, A, B, &beta, C_ref);
else if(is_same<T, complex<float>>::value)
ref_gemm(BLIS_SCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref);
else if(is_same<T, complex<double>>::value)
ref_gemm(BLIS_DCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref);
ref_gemm(m, n, k, &alpha, A, B, &beta, C_ref);
#ifdef PRINT
bl_dgemm_printmatrix(C_ref, ldc_ref ,m,n);
#endif
if(computeError(ldc, ldc_ref, m, n, C, C_ref )==1)
printf("%s TEST FAIL\n" ,__func__);
printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ );
else
printf("%s TEST PASS\n" , __func__);
printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ );