mirror of
https://github.com/amd/blis.git
synced 2026-05-14 03:02:08 +00:00
Details: - Moved edge-case handling into the gemm microkernel. This required changing the microkernel API to take m and n dimension parameters. This in turn required updating all existing gemm microkernel function pointer types, function signatures, and related definitions to take m and n dimensions. We also updated all existing kernels in the 'kernels' directory to take m and n dimensions, and implemented edge-case handling within those microkernels via a collection of new C preprocessor macros defined within bli_edge_case_macro_defs.h. Also removed the assembly code that formerly would handle general stride IO on the microtile, since this can now be handled by the same code that does edge cases. - Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and bli_trsm_cntl_create(), where this function pointer is used in lieu of the default macrokernel when it is non-NULL, and ignored when it is NULL. - Re-implemented the macrokernel in bli_gemm_ker_var2.c to be a single function using byte pointers rather than one function for each floating-point datatype. Also, obtain the microkernel function pointer from the .ukr field of the params struct embedded within the obj_t for matrix C (assuming params is non-NULL and contains a non-NULL value in the .ukr field). Communicate both the gemm microkernel pointer to use and the params struct to the microkernel via the auxinfo_t struct. - Defined gemm_ker_params_t type (for the aforementioned obj_t.params struct) in bli_gemm_var.h. - Retired the separate _md macrokernel for mixed datatype computation. We now use the reimplemented bli_gemm_ker_var2() instead. - Updated gemmt macrokernels to pass m and n dimensions into microkernel calls. - Removed edge-case handling from trmm and trsm macrokernels. - Moved most of bli_packm_alloc() code into a new helper function, bli_packm_alloc_ex(). - Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c. 
- Added test/syrk_diagonal and test/tensor_contraction directories with associated code to test those operations.
101 lines
2.7 KiB
C++
101 lines
2.7 KiB
C++
#include "blis.h"
|
|
#include "complex_math.hpp"
|
|
|
|
#include <vector>
|
|
#include <array>
|
|
#include <cassert>
|
|
|
|
// Terminal case of the variadic increment() recursion: once every
// (offset, stride-array) pair has been consumed there is nothing to update.
inline void increment(inc_t /* n */, gint_t /* i */) {}
|
|
|
|
template <typename T, typename... Args>
|
|
void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args)
|
|
{
|
|
off += s[i]*n;
|
|
increment(n, i, args...);
|
|
}
|
|
|
|
template <typename Body, typename... Args>
|
|
void for_each_impl(gint_t ndim, const dim_t* n,
|
|
dim_t off, dim_t len,
|
|
Body& body,
|
|
Args&... args)
|
|
{
|
|
std::array<dim_t,8> i = {};
|
|
assert( ndim <= i.size() );
|
|
|
|
if ( off )
|
|
{
|
|
for ( gint_t k = 0; k < ndim; k++ )
|
|
{
|
|
i[k] = off % n[k];
|
|
off /= n[k];
|
|
increment(i[k], k, args...);
|
|
}
|
|
}
|
|
|
|
for ( dim_t pos = 0; pos < len; pos++ )
|
|
{
|
|
body();
|
|
|
|
for ( gint_t k = 0; k < ndim; k++ )
|
|
{
|
|
if ( i[k] == n[k]-1 )
|
|
{
|
|
increment(-i[k], k, args...);
|
|
i[k] = 0;
|
|
}
|
|
else
|
|
{
|
|
increment(1, k, args...);
|
|
i[k]++;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T, typename Body>
|
|
void for_each(gint_t ndim, const dim_t* n,
|
|
dim_t off, dim_t len,
|
|
T& a, const inc_t* s_a,
|
|
Body&& body)
|
|
{
|
|
for_each_impl( ndim, n, off, len, body, a, s_a );
|
|
}
|
|
|
|
template <typename T, typename Body>
|
|
void for_each(gint_t ndim, const dim_t* n,
|
|
dim_t off, dim_t len,
|
|
T& a, const inc_t* s_a,
|
|
T& b, const inc_t* s_b,
|
|
Body&& body)
|
|
{
|
|
for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b );
|
|
}
|
|
|
|
template <typename T, typename Body>
|
|
void for_each(gint_t ndim, const dim_t* n,
|
|
T& a, const inc_t* s_a,
|
|
Body&& body)
|
|
{
|
|
dim_t len = 1;
|
|
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
|
|
for_each_impl( ndim, n, 0, len, body, a, s_a );
|
|
}
|
|
|
|
template <typename T, typename Body>
|
|
void for_each(gint_t ndim, const dim_t* n,
|
|
T& a, const inc_t* s_a,
|
|
T& b, const inc_t* s_b,
|
|
Body&& body)
|
|
{
|
|
dim_t len = 1;
|
|
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
|
|
for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b );
|
|
}
|
|
|
|
// NOTE(review): declared here, defined elsewhere. Judging from the parameter
// names, this appears to be a reference tensor contraction of the form
// C := beta*C + alpha*A*B with datatype dt. Here m, n and k would list the
// extents of the row, column and contracted index groups, and rs_* / cs_*
// the per-dimension strides of each operand. Confirm against the definition.
void tcontract_ref( num_t dt,
                    const std::vector<dim_t>& m,
                    const std::vector<dim_t>& n,
                    const std::vector<dim_t>& k,
                    const void* alpha,
                    const void* a,
                    const std::vector<inc_t>& rs_a,
                    const std::vector<inc_t>& cs_a,
                    const void* b,
                    const std::vector<inc_t>& rs_b,
                    const std::vector<inc_t>& cs_b,
                    const void* beta,
                    void* c,
                    const std::vector<inc_t>& rs_c,
                    const std::vector<inc_t>& cs_c );
|