Files
blis/test/tensor_contraction/tcontract_ref.hpp
Devin Matthews 54fa28bd84 Move edge cases to gemm ukr; more user-custom mods. (#583)
Details:
- Moved edge-case handling into the gemm microkernel. This required
  changing the microkernel API to take m and n dimension parameters.
  This required updating all existing gemm microkernel function pointer
  types, function signatures, and related definitions to take m and n
  dimensions. We also updated all existing kernels in the 'kernels' 
  directory to take m and n dimensions, and implemented edge-case 
  handling within those microkernels via a collection of new C 
  preprocessor macros defined within bli_edge_case_macro_defs.h. Also
  removed the assembly code that formerly would handle general stride 
  IO on the microtile, since this can now be handled by the same code
  that does edge cases.
- Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and
  bli_trsm_cntl_create(), where this function pointer is used in lieu of 
  the default macrokernel when it is non-NULL, and ignored when it is
  NULL.
- Re-implemented macrokernel in bli_gemm_ker_var2.c to be a single
  function using byte pointers rather than one function for each
  floating-point datatype. Also, obtain the microkernel function pointer
  from the .ukr field of the params struct embedded within the obj_t
  for matrix C (assuming params is non-NULL and contains a non-NULL
  value in the .ukr field). Communicate both the gemm microkernel
  pointer to use as well as the params struct to the microkernel via
  the auxinfo_t struct.
- Defined gemm_ker_params_t type (for the aforementioned obj_t.params 
  struct) in bli_gemm_var.h.
- Retired the separate _md macrokernel for mixed datatype computation.
  We now use the reimplemented bli_gemm_ker_var2() instead.
- Updated gemmt macrokernels to pass m and n dimensions into microkernel
  calls.
- Removed edge-case handling from trmm and trsm macrokernels.
- Moved most of bli_packm_alloc() code into a new helper function,
  bli_packm_alloc_ex().
- Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c.
- Added test/syrk_diagonal and test/tensor_contraction directories with
  associated code to test those operations.
2021-12-24 08:00:33 -06:00

101 lines
2.7 KiB
C++

#include "blis.h"
#include "complex_math.hpp"
#include <vector>
#include <array>
#include <cassert>
// Recursion terminator for increment(): once every (offset, stride-array)
// pair has been consumed there is nothing left to advance.
inline void increment(inc_t, gint_t) {}

// Advances each offset in a list of (offset, stride-array) pairs along
// dimension i by n steps.
//
//   n    - number of steps to move (may be negative to rewind).
//   i    - index of the dimension whose stride is applied.
//   off  - offset (anything supporting `+=` with an integer) updated in place.
//   s    - stride array paired with off; only s[i] is read.
//   args - remaining (offset, stride-array) pairs, handled recursively.
template <typename T, typename... Args>
void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args)
{
    // Displacement for this operand: n steps of its stride along dimension i.
    const auto delta = n * s[i];
    off += delta;

    // Peel off the pair just handled and continue with the rest.
    increment(n, i, args...);
}
// Core driver behind the for_each() overloads: invokes body() once for each
// of `len` consecutive positions of an ndim-dimensional index space, keeping
// a set of operand offsets in step with the current multi-index.
//
//   ndim - number of dimensions; at most 8 (checked with assert).
//   n    - extent of each dimension; dimension 0 varies fastest.
//   off  - linear position at which iteration starts.
//   len  - number of positions to visit (number of body() calls).
//   body - nullary callable invoked at each position.
//   args - (offset, stride-array) pairs forwarded to increment(); every
//          offset is moved by its stride whenever an index changes.
//
// Note that the multi-index and the offsets are advanced after every call to
// body(), including the last one, so on return the offsets refer to the
// position following the last one visited (wrapping around to the origin if
// the end of the index space was reached).
template <typename Body, typename... Args>
void for_each_impl(gint_t ndim, const dim_t* n,
                   dim_t off, dim_t len,
                   Body& body,
                   Args&... args)
{
    // Current multi-index, zero-initialized.
    std::array<dim_t,8> i = {};
    assert( ndim <= i.size() );

    // Translate a non-zero starting position into a multi-index and move
    // every offset to the matching location.
    if ( off != 0 )
    {
        for ( gint_t d = 0; d < ndim; ++d )
        {
            const dim_t extent = n[d];
            i[d] = off % extent;
            off /= extent;
            increment(i[d], d, args...);
        }
    }

    for ( dim_t count = 0; count < len; ++count )
    {
        body();

        // Odometer-style advance: every leading dimension that is already at
        // its last index is rewound to zero ...
        gint_t d = 0;
        while ( d < ndim && i[d] == n[d]-1 )
        {
            increment(-i[d], d, args...);
            i[d] = 0;
            ++d;
        }

        // ... and the first dimension that is not yet exhausted (if any)
        // moves forward by one.
        if ( d < ndim )
        {
            increment(1, d, args...);
            ++i[d];
        }
    }
}
// Visits `len` consecutive positions of an ndim-dimensional index space,
// starting at linear position `off`, and calls body() at each one while
// keeping the single operand offset `a` (with strides s_a) in step.
// This is a thin forwarder to for_each_impl().
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
              dim_t off, dim_t len,
              T& a, const inc_t* s_a,
              Body&& body)
{
    // body is passed on as an lvalue; for_each_impl takes it by reference.
    for_each_impl( ndim, n,
                   off, len,
                   body,
                   a, s_a );
}
// Visits `len` consecutive positions of an ndim-dimensional index space,
// starting at linear position `off`, and calls body() at each one while
// keeping two operand offsets, `a` (strides s_a) and `b` (strides s_b), in
// step. Both offsets must have the same type T.
// This is a thin forwarder to for_each_impl().
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
              dim_t off, dim_t len,
              T& a, const inc_t* s_a,
              T& b, const inc_t* s_b,
              Body&& body)
{
    // body is passed on as an lvalue; for_each_impl takes it by reference.
    for_each_impl( ndim, n,
                   off, len,
                   body,
                   a, s_a,
                   b, s_b );
}
// Visits every position of an ndim-dimensional index space with extents n,
// calling body() at each one while keeping the single operand offset `a`
// (with strides s_a) in step. With ndim == 0 the body runs exactly once.
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
              T& a, const inc_t* s_a,
              Body&& body)
{
    // Total number of positions is the product of all extents.
    dim_t total = 1;
    for ( gint_t d = 0; d < ndim; ++d )
    {
        total *= n[d];
    }

    // Start at the origin and cover the whole space.
    for_each_impl( ndim, n, 0, total, body, a, s_a );
}
// Visits every position of an ndim-dimensional index space with extents n,
// calling body() at each one while keeping two operand offsets, `a`
// (strides s_a) and `b` (strides s_b), in step. Both offsets must have the
// same type T. With ndim == 0 the body runs exactly once.
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
              T& a, const inc_t* s_a,
              T& b, const inc_t* s_b,
              Body&& body)
{
    // Total number of positions is the product of all extents.
    dim_t total = 1;
    for ( gint_t d = 0; d < ndim; ++d )
    {
        total *= n[d];
    }

    // Start at the origin and cover the whole space.
    for_each_impl( ndim, n, 0, total, body, a, s_a, b, s_b );
}
// Declaration of the reference tensor-contraction routine; the definition is
// not part of this header.
//
//   dt          - BLIS datatype selector; presumably the element type behind
//                 the type-erased pointers below -- TODO confirm.
//   m, n, k     - lists of dimension extents. NOTE(review): presumably m and
//                 n are the uncontracted index groups of A and B and k is the
//                 contracted group -- confirm against the definition.
//   alpha, beta - type-erased, read-only scalars.
//   a, b        - type-erased, read-only input operands.
//   c           - type-erased, writable operand.
//   rs_*, cs_*  - stride lists for the corresponding operand. NOTE(review):
//                 presumably one stride per dimension of the "row" and
//                 "column" index groups respectively -- confirm.
//
// NOTE(review): looks like this computes C := alpha * A * B + beta * C as a
// contraction over the k indices -- confirm against the definition.
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c );