Implement proposed new function pointer fields for obj_t.

The added fields:
1. `pack_t schema`: storing the pack schema on the object allows the macrokernel to act accordingly without side-channel information from the rntm_t and cntx_t. The pack schema and "pack_[ab]" fields could be removed from those structs.
2. `void* user_data`: this field can be used to store any sort of additional information provided by the user. The pointer is propagated to submatrix objects and copies, but is otherwise ignored by the framework and the default implementations of the following three fields. User-specified pack, kernel, or ukr functions can do whatever they want with the data, and the user is 100% responsible for allocating, assigning, and freeing this buffer.
3. `obj_pack_fn_t pack`: the function called when a matrix is packed. This functions receives the expected arguments, as well as a mdim_t and mem_t* as memory must be allocated inside this function, and behavior may differ based on which matrix is being backed (i.e. transposition for B). This could also be achieved by passing a desired pack schema, but this would require additional information to travel down the control tree.
4. `obj_ker_fn_t ker`: the function called when we get to the "second loop", or the macro-kernel. Behavior may depend on the pack schemas of the input matrices. The default implementation would perform the inner two loops around the ukr, and then call either the default ukr or a user-supplied one (next field).
5. `obj_ukr_fn_t ukr`: the function called by the default macrokernel. This would replace the various current "virtual" microkernels, and could also be used to supply user-defined behavior. Users could supply both a custom kernel (above) and microkernel, although the user-specified kernel does **not** necessarily have to call the ukr function specified on the obj_t.

Note that no macros or functions for accessing these new fields have been defined yet. That is next once these are finalized. Addresses https://github.com/flame/blis/projects/1#card-62357687.
This commit is contained in:
Devin Matthews
2021-08-11 17:53:12 -05:00
parent 84f9dcd449
commit 64a1f786d5

View File

@@ -150,7 +150,7 @@ typedef uint32_t objbits_t; // object information bit field
// interoperability with BLIS.
#ifndef _DEFINED_SCOMPLEX
#define _DEFINED_SCOMPLEX
typedef struct
typedef struct scomplex
{
float real;
float imag;
@@ -161,7 +161,7 @@ typedef uint32_t objbits_t; // object information bit field
// interoperability with BLIS.
#ifndef _DEFINED_DCOMPLEX
#define _DEFINED_DCOMPLEX
typedef struct
typedef struct dcomplex
{
double real;
double imag;
@@ -1232,6 +1232,47 @@ typedef struct constdata_s
// -- BLIS object type definitions ---------------------------------------------
//
// Forward declarations for function pointer types
struct obj_s;
struct cntx_s;
struct rntm_s;
struct thrinfo_s;
typedef void (*obj_pack_fn_t)
(
mdim_t mat,
mem_t* mem,
struct obj_s* a,
struct obj_s* ap,
struct cntx_s* cntx,
struct rntm_s* rntm,
struct thrinfo_s* thread
);
typedef void (*obj_ker_fn_t)
(
struct obj_s* a,
struct obj_s* b,
struct obj_s* c,
struct cntx_s* cntx,
struct rntm_s* rntm,
struct thrinfo_s* thread
);
typedef void (*obj_ukr_fn_t)
(
dim_t m,
dim_t n,
dim_t k,
void* restrict alpha,
void* restrict a,
void* restrict b,
void* restrict beta,
void* restrict c, inc_t rs_c, inc_t cs_c,
auxinfo_t* restrict data,
struct cntx_s* restrict cntx
);
typedef struct obj_s
{
// Basic fields
@@ -1261,6 +1302,15 @@ typedef struct obj_s
// usually MR or NR)
dim_t m_panel; // m dimension of a "full" panel
dim_t n_panel; // n dimension of a "full" panel
pack_t schema; // pack schema, which may be unpacked
// User data pointer
void* user_data;
// Function pointers
obj_pack_fn_t pack;
obj_ker_fn_t ker;
obj_ukr_fn_t ukr;
} obj_t;
// Pre-initializors. Things that must be set afterwards:
@@ -1297,7 +1347,14 @@ typedef struct obj_s
.ps = 0, \
.pd = 0, \
.m_panel = 0, \
.n_panel = 0 \
.n_panel = 0, \
.schema = BLIS_NOT_PACKED, \
\
.user_data = NULL, \
\
.pack = NULL, \
.ker = NULL, \
.ukr = NULL \
}
#define BLIS_OBJECT_INITIALIZER_1X1 \
@@ -1325,7 +1382,14 @@ typedef struct obj_s
.ps = 0, \
.pd = 0, \
.m_panel = 0, \
.n_panel = 0 \
.n_panel = 0, \
.schema = BLIS_NOT_PACKED, \
\
.user_data = NULL, \
\
.pack = NULL, \
.ker = NULL, \
.ukr = NULL \
}
// Define these macros here since they must be updated if contents of
@@ -1359,6 +1423,13 @@ BLIS_INLINE void bli_obj_init_full_shallow_copy_of( obj_t* a, obj_t* b )
b->pd = a->pd;
b->m_panel = a->m_panel;
b->n_panel = a->n_panel;
b->schema = a->schema;
b->user_data = a->user_data;
b->pack = a->pack;
b->ker = a->ker;
b->ukr = a->ukr;
}
BLIS_INLINE void bli_obj_init_subpart_from( obj_t* a, obj_t* b )
@@ -1392,6 +1463,13 @@ BLIS_INLINE void bli_obj_init_subpart_from( obj_t* a, obj_t* b )
b->pd = a->pd;
b->m_panel = a->m_panel;
b->n_panel = a->n_panel;
b->schema = a->schema;
b->user_data = a->user_data;
b->pack = a->pack;
b->ker = a->ker;
b->ukr = a->ukr;
}
// Initializors for global scalar constants.