Files
blis/test/syrk_diagonal/complex_math.hpp
Devin Matthews 54fa28bd84 Move edge cases to gemm ukr; more user-custom mods. (#583)
Details:
- Moved edge-case handling into the gemm microkernel. This required
  changing the microkernel API to take m and n dimension parameters.
  This required updating all existing gemm microkernel function pointer
  types, function signatures, and related definitions to take m and n
  dimensions. We also updated all existing kernels in the 'kernels' 
  directory to take m and n dimensions, and implemented edge-case 
  handling within those microkernels via a collection of new C 
  preprocessor macros defined within bli_edge_case_macro_defs.h. Also
  removed the assembly code that formerly would handle general stride 
  IO on the microtile, since this can now be handled by the same code
  that does edge cases.
- Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and
  bli_trsm_cntl_create(), where this function pointer is used in lieu of 
  the default macrokernel when it is non-NULL, and ignored when it is
  NULL.
- Re-implemented macrokernel in bli_gemm_ker_var2.c to be a single
  function using byte pointers rather that one function for each
  floating-point datatype. Also, obtain the microkernel function pointer
  from the .ukr field of the params struct embedded within the obj_t
  for matrix C (assuming params is non-NULL and contains a non-NULL
  value in the .ukr field). Communicate both the gemm microkernel
  pointer to use as well as the params struct to the microkernel via
  the auxinfo_t struct.
- Defined gemm_ker_params_t type (for the aforementioned obj_t.params 
  struct) in bli_gemm_var.h.
- Retired the separate _md macrokernel for mixed datatype computation.
  We now use the reimplemented bli_gemm_ker_var2() instead.
- Updated gemmt macrokernels to pass m and n dimensions into microkernel
  calls.
- Removed edge-case handling from trmm and trsm macrokernels.
- Moved most of bli_packm_alloc() code into a new helper function,
  bli_packm_alloc_ex().
- Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c.
- Added test/syrk_diagonal and test/tensor_contraction directories with
  associated code to test those operations.
2021-12-24 08:00:33 -06:00

268 lines
6.3 KiB
C++

#include <cmath>
#include <algorithm>
#include <type_traits>
#include "blis.h"
template <typename T>
struct is_complex : std::false_type {};
template <>
struct is_complex<scomplex> : std::true_type {};
template <>
struct is_complex<dcomplex> : std::true_type {};
template <typename T>
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};
template <typename T> struct make_complex;
template <> struct make_complex<float > { using type = scomplex; };
template <> struct make_complex<double > { using type = dcomplex; };
template <> struct make_complex<scomplex> { using type = scomplex; };
template <> struct make_complex<dcomplex> { using type = dcomplex; };
template <typename T>
using make_complex_t = typename make_complex<T>::type;
template <typename T> struct make_real;
template <> struct make_real<float > { using type = float; };
template <> struct make_real<double > { using type = double; };
template <> struct make_real<scomplex> { using type = float; };
template <> struct make_real<dcomplex> { using type = double; };
template <typename T>
using make_real_t = typename make_real<T>::type;
template <typename T, bool Cond>
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};
template <typename T, bool Cond>
using make_complex_if_t = typename make_complex_if<T,Cond>::type;
template <typename T>
struct real_imag_part
{
real_imag_part& operator=(T) { return *this; }
operator T() const { return T(); }
};
template <typename T>
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }
template <typename T>
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }
inline float& real(scomplex& x) { return x.real; }
inline float& imag(scomplex& x) { return x.imag; }
inline double& real(dcomplex& x) { return x.real; }
inline double& imag(dcomplex& x) { return x.imag; }
inline const float& real(const scomplex& x) { return x.real; }
inline const float& imag(const scomplex& x) { return x.imag; }
inline const double& real(const dcomplex& x) { return x.real; }
inline const double& imag(const dcomplex& x) { return x.imag; }
template <typename T>
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
template <typename T>
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
template <typename T, typename U, typename=void>
struct convert_impl;
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
{
void operator()(T x, U& y) const { y = x; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
{
void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
{
void operator()(T x, U& y) const { y = x.real; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
{
void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
};
template <typename U, typename T>
U convert(T x)
{
U y;
convert_impl<T,U>{}(x,y);
return y;
}
template <typename U, typename T>
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
{
return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
}
#define COMPLEX_MATH_OPS(rtype, ctype) \
\
inline bool operator==(rtype x, ctype y) \
{ \
return x == y.real && y.imag == 0; \
} \
\
inline bool operator==(ctype x, rtype y) \
{ \
return y == x.real && x.imag == 0; \
} \
\
inline bool operator==(ctype x, ctype y) \
{ \
return x.real == y.real && \
x.imag == y.imag; \
} \
\
inline ctype operator-(ctype x) \
{ \
return {-x.real, -x.imag}; \
} \
\
inline ctype operator+(rtype x, ctype y) \
{ \
return {x+y.real, y.imag}; \
} \
\
inline ctype operator+(ctype x, rtype y) \
{ \
return {y+x.real, x.imag}; \
} \
\
inline ctype operator+(ctype x, ctype y) \
{ \
return {x.real+y.real, x.imag+y.imag}; \
} \
\
inline ctype operator-(rtype x, ctype y) \
{ \
return {x-y.real, -y.imag}; \
} \
\
inline ctype operator-(ctype x, rtype y) \
{ \
return {x.real-y, x.imag}; \
} \
\
inline ctype operator-(ctype x, ctype y) \
{ \
return {x.real-y.real, x.imag-y.imag}; \
} \
\
inline ctype operator*(rtype x, ctype y) \
{ \
return {x*y.real, x*y.imag}; \
} \
\
inline ctype operator*(ctype x, rtype y) \
{ \
return {y*x.real, y*x.imag}; \
} \
\
inline ctype operator*(ctype x, ctype y) \
{ \
return {x.real*y.real - x.imag*y.imag, \
x.real*y.imag + x.imag*y.real}; \
} \
\
inline ctype operator/(rtype x, ctype y) \
{ \
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
auto n = std::ilogb(scale); \
auto yrs = std::scalbn(y.real, -n); \
auto yis = std::scalbn(y.imag, -n); \
auto denom = y.real*yrs + y.imag*yis; \
return {x*yrs/denom, -x*yis/denom}; \
} \
\
inline ctype operator/(ctype x, rtype y) \
{ \
return {x.real/y, x.imag/y}; \
} \
\
inline ctype operator/(ctype x, ctype y) \
{ \
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
auto n = std::ilogb(scale); \
auto yrs = std::scalbn(y.real, -n); \
auto yis = std::scalbn(y.imag, -n); \
auto denom = y.real*yrs + y.imag*yis; \
return {(x.real*yrs + x.imag*yis)/denom, \
(x.imag*yrs - x.real*yis)/denom}; \
} \
\
inline ctype& operator+=(ctype& x, rtype y) \
{ \
x.real += y; \
return x; \
} \
\
inline ctype& operator+=(ctype& x, ctype y) \
{ \
x.real += y.real; x.imag += y.imag; \
return x; \
} \
\
inline ctype& operator-=(ctype& x, rtype y) \
{ \
x.real -= y; \
return x; \
} \
\
inline ctype& operator-=(ctype& x, ctype y) \
{ \
x.real -= y.real; x.imag -= y.imag; \
return x; \
} \
\
inline ctype& operator*=(ctype& x, rtype y) \
{ \
x.real *= y; x.imag *= y; \
return x; \
} \
\
inline ctype& operator*=(ctype& x, ctype y) \
{ \
x = x * y; \
return x; \
} \
\
inline ctype& operator/=(ctype& x, rtype y) \
{ \
x.real /= y; x.imag /= y; \
return x; \
} \
\
inline ctype& operator/=(ctype& x, ctype y) \
{ \
x = x / y; \
return x; \
}
COMPLEX_MATH_OPS(float, scomplex);
COMPLEX_MATH_OPS(double, dcomplex);