From be25ec0065eff77b136fca07b94e9049b5efc70f Mon Sep 17 00:00:00 2001 From: Chithra Sankar Date: Fri, 20 Sep 2019 11:52:55 +0530 Subject: [PATCH] CPP Implementtaion of dsdot included. Test application refactored to include review comments Change-Id: Iec0b973c23a2825e61f2ec9da236b3aea327d98a --- cpp/blis.hh | 21 ++++++++++-------- cpp/cblas.hh | 10 --------- testcpp/test_gemm.cc | 51 ++++++++++++++++++++++++++------------------ 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/cpp/blis.hh b/cpp/blis.hh index f4d28798a..00a69a1eb 100644 --- a/cpp/blis.hh +++ b/cpp/blis.hh @@ -366,7 +366,7 @@ void axpy( cblas_axpy( n, alpha, x, incx, y, incy ); } -/*! \brief Performs forms the dot product of two vectors for arbitrary data types +/*! \brief Performs the dot product of two vectors for arbitrary data types \b Purpose: @@ -400,16 +400,19 @@ void axpy( \return Unconjugated dot product, x^T * y. REAL/DOUBLE PRECISION */ -template< typename TX, typename TY > -TY dot( +template< typename T, typename TR > +TR dot( int64_t n, - TX const *x, int64_t incx, - TX const *y, int64_t incy ) + T const *x, int64_t incx, + T const *y, int64_t incy ) { - return cblas_dot( n, x, incx, y, incy ); + if((std::is_same::value)&(std::is_same::value)) + return cblas_dsdot( n, x, incx, y, incy ); + else + return cblas_dot( n, x, incx, y, incy ); } -/*! \brief Performs forms the dot product of two complex vectors +/*! \brief Performs the dot product of two complex vectors \b Purpose: @@ -451,7 +454,7 @@ T dotu( return cblas_dotu( n, x, incx, y, incy ); } -/*! \brief Performs forms the dot product of two complex vectors +/*! \brief Performs the dot product of two complex vectors \b Purpose: @@ -2982,7 +2985,7 @@ void gemm( cblas_gemm(layout, transA, transB, m, n, k, alpha, A,lda, B, ldb, beta, C, ldc); } -/*! \brief Solve the triangular matrix-vector equation for arbitrary data types +/*! \brief Solve the triangular matrix-matrix equation for arbitrary data types \b Purpose: diff --git a/cpp/cblas.hh b/cpp/cblas.hh index f7bf16647..337ae6878 100644 --- a/cpp/cblas.hh +++ b/cpp/cblas.hh @@ -290,16 +290,6 @@ cblas_dot( { return cblas_ddot( n, x, incx, y, incy ); } -#if 0 -inline double -cblas_dot( - int n, - float const *x, int incx, - float const *y, int incy ) -{ - return cblas_dsdot( n, x, incx, y, incy ); -} -#endif // ----------------------------------------------------------------------------- inline std::complex cblas_dotu( diff --git a/testcpp/test_gemm.cc b/testcpp/test_gemm.cc index 4f24c0f44..db0f5155e 100644 --- a/testcpp/test_gemm.cc +++ b/testcpp/test_gemm.cc @@ -39,19 +39,35 @@ using namespace blis; using namespace std; #define PRINT +#define ALPHA 1.0 +#define BETA 0.0 +#define M 5 +#define N 6 +#define K 4 + /* * Test application assumes matrices to be column major, non-transposed */ -void ref_gemm(num_t dt , int64_t m, int64_t n, int64_t k, - void * alpha, - void *A, - void *B, - void * beta, - void *C ) +template< typename T > +void ref_gemm(int64_t m, int64_t n, int64_t k, + T * alpha, + T *A, + T *B, + T * beta, + T *C ) { obj_t obj_a, obj_b, obj_c; obj_t obj_alpha, obj_beta; + num_t dt; + if(is_same::value) + dt = BLIS_FLOAT; + else if(is_same::value) + dt = BLIS_DOUBLE; + else if(is_same>::value) + dt = BLIS_SCOMPLEX; + else if(is_same>::value) + dt = BLIS_DCOMPLEX; bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); bli_obj_create_with_attached_buffer( dt, m, k, A, 1,m,&obj_a ); @@ -76,11 +92,11 @@ void test_gemm( ) int m,n,k; int lda, ldb, ldc, ldc_ref; - alpha = 1.0; - beta = 0.0; - m = 5; - k = 4; - n = 6; + alpha = ALPHA; + beta = BETA; + m = M; + k = K; + n = N; A = new T[m * k]; B = new T[k * n]; @@ -137,22 +153,15 @@ void test_gemm( ) #ifdef PRINT bl_dgemm_printmatrix(C, ldc ,m,n); #endif - if(is_same::value) - ref_gemm(BLIS_FLOAT , m, n, k, &alpha, A, B, &beta, C_ref); - else if(is_same::value) - ref_gemm(BLIS_DOUBLE , m, n, k, &alpha, A, B, &beta, C_ref); - else if(is_same>::value) - ref_gemm(BLIS_SCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref); - else if(is_same>::value) - ref_gemm(BLIS_DCOMPLEX , m, n, k, &alpha, A, B, &beta, C_ref); + ref_gemm(m, n, k, &alpha, A, B, &beta, C_ref); #ifdef PRINT bl_dgemm_printmatrix(C_ref, ldc_ref ,m,n); #endif if(computeError(ldc, ldc_ref, m, n, C, C_ref )==1) - printf("%s TEST FAIL\n" ,__func__); + printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); else - printf("%s TEST PASS\n" , __func__); + printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ );